Skip to content

Commit 8888c0d

Browse files
meta-emilianfacebook-github-bot
authored andcommitted
Adding per-method tracers to the module utility. Changing set_output_data_ptr to take in a method name. (#5279)
Summary: Pull Request resolved: #5279 * Adding per-method tracers to the executorch module utilty to be able to profile/trace methods individually * Enabling per-method output data pointers to be able to use per-method input/output memory planning through the module. Reviewed By: tarun292 Differential Revision: D62520386 fbshipit-source-id: 6287701183e664d68435c48d6ed2b566e3d10d93
1 parent 4053a18 commit 8888c0d

File tree

2 files changed

+13
-9
lines changed

2 files changed

+13
-9
lines changed

extension/module/module.cpp

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,9 @@ runtime::Result<std::unordered_set<std::string>> Module::method_names() {
123123
return result;
124124
}
125125

126-
runtime::Error Module::load_method(const std::string& method_name) {
126+
runtime::Error Module::load_method(
127+
const std::string& method_name,
128+
torch::executor::EventTracer* tracer) {
127129
if (!is_method_loaded(method_name)) {
128130
ET_CHECK_OK_OR_RETURN_ERROR(load());
129131

@@ -151,9 +153,7 @@ runtime::Error Module::load_method(const std::string& method_name) {
151153
method_holder.planned_memory.get(),
152154
temp_allocator_.get());
153155
method_holder.method = ET_UNWRAP_UNIQUE(program_->load_method(
154-
method_name.c_str(),
155-
method_holder.memory_manager.get(),
156-
event_tracer_.get()));
156+
method_name.c_str(), method_holder.memory_manager.get(), tracer));
157157
methods_.emplace(method_name, std::move(method_holder));
158158
}
159159
return runtime::Error::Ok;
@@ -185,10 +185,11 @@ runtime::Result<std::vector<runtime::EValue>> Module::execute(
185185

186186
runtime::Error Module::set_output_data_ptr(
187187
runtime::EValue output_value,
188-
size_t output_index) {
189-
ET_CHECK_OK_OR_RETURN_ERROR(load_method("forward"));
188+
size_t output_index,
189+
const std::string& method_name) {
190+
ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name));
190191
auto& output_tensor = output_value.toTensor();
191-
auto& method = methods_.at("forward").method;
192+
auto& method = methods_.at(method_name).method;
192193
return method->set_output_data_ptr(
193194
output_tensor.mutable_data_ptr(), output_tensor.nbytes(), output_index);
194195
}

extension/module/module.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,9 @@ class Module {
136136
* @returns An Error to indicate success or failure.
137137
*/
138138
ET_NODISCARD
139-
runtime::Error load_method(const std::string& method_name);
139+
runtime::Error load_method(
140+
const std::string& method_name,
141+
torch::executor::EventTracer* tracer = nullptr);
140142

141143
/**
142144
* Checks if a specific method is loaded.
@@ -318,7 +320,8 @@ class Module {
318320
*/
319321
runtime::Error set_output_data_ptr(
320322
runtime::EValue output_value,
321-
size_t output_index);
323+
size_t output_index,
324+
const std::string& method_name = "forward");
322325

323326
private:
324327
struct MethodHolder {

0 commit comments

Comments
 (0)