Skip to content

Commit 3f670a2

Browse files
Cleaned up methods, deprecated old interfaces.
1 parent 864f2cd commit 3f670a2

File tree

2 files changed

+35
-23
lines changed

2 files changed

+35
-23
lines changed

extension/module/module.cpp

Lines changed: 15 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,7 @@
1111
#include <executorch/extension/data_loader/file_data_loader.h>
1212
#include <executorch/extension/data_loader/mmap_data_loader.h>
1313
#include <executorch/extension/memory_allocator/malloc_memory_allocator.h>
14-
#include <executorch/runtime/core/hierarchical_allocator.h>
1514
#include <executorch/runtime/platform/runtime.h>
16-
#include <memory>
1715

1816
/**
1917
* Unwrap a Result to obtain its value (direct object, not a pointer).
@@ -127,42 +125,39 @@ runtime::Result<std::unordered_set<std::string>> Module::method_names() {
127125

128126
runtime::Error Module::load_method(
129127
const std::string& method_name,
130-
torch::executor::EventTracer* event_tracer,
131-
runtime::HierarchicalAllocator* planned_memory_allocator) {
128+
runtime::HierarchicalAllocator* planned_memory,
129+
torch::executor::EventTracer* event_tracer) {
132130
if (!is_method_loaded(method_name)) {
133131
ET_CHECK_OK_OR_RETURN_ERROR(load());
134132

135133
MethodHolder method_holder;
136-
runtime::HierarchicalAllocator* planned_memory = nullptr;
137134

138-
// we were not given a planned memory allocator, so we need to create one:
139-
if (planned_memory_allocator == nullptr) {
135+
if (!planned_memory) {
140136
const auto method_metadata =
141-
ET_UNWRAP(program_->method_meta(method_name.c_str()));
142-
const auto planned_buffersCount =
137+
ET_UNWRAP(program_->method_meta(method_name.c_str()));
138+
const auto planned_buffers_count =
143139
method_metadata.num_memory_planned_buffers();
144-
method_holder.planned_buffers.reserve(planned_buffersCount);
145-
method_holder.planned_spans.reserve(planned_buffersCount);
140+
method_holder.planned_buffers.reserve(planned_buffers_count);
141+
method_holder.planned_spans.reserve(planned_buffers_count);
146142

147-
for (auto index = 0; index < planned_buffersCount; ++index) {
143+
for (auto index = 0; index < planned_buffers_count; ++index) {
148144
const auto buffer_size =
149145
method_metadata.memory_planned_buffer_size(index).get();
150146
method_holder.planned_buffers.emplace_back(buffer_size);
151147
method_holder.planned_spans.emplace_back(
152148
method_holder.planned_buffers.back().data(), buffer_size);
153149
}
154150
method_holder.planned_memory =
155-
std::make_unique<runtime::HierarchicalAllocator>(runtime::Span(
156-
method_holder.planned_spans.data(),
157-
method_holder.planned_spans.size()));
151+
std::make_unique<runtime::HierarchicalAllocator>(
152+
runtime::Span(
153+
method_holder.planned_spans.data(),
154+
method_holder.planned_spans.size()));
158155
planned_memory = method_holder.planned_memory.get();
159-
} else {
160-
// we were given a planned memory allocator, so we use it:
161-
planned_memory = planned_memory_allocator;
162156
}
163-
164157
method_holder.memory_manager = std::make_unique<runtime::MemoryManager>(
165-
memory_allocator_.get(), planned_memory, temp_allocator_.get());
158+
memory_allocator_.get(),
159+
planned_memory,
160+
temp_allocator_.get());
166161
method_holder.method = ET_UNWRAP_UNIQUE(program_->load_method(
167162
method_name.c_str(),
168163
method_holder.memory_manager.get(),

extension/module/module.h

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,8 @@ class Module {
133133
* needed. The loaded method is cached to reuse the next time it's executed.
134134
*
135135
* @param[in] method_name The name of the method to load.
136+
* @param[in] planned_memory The memory-planned buffers to use for mutable
137+
* tensor data when executing a method.
136138
* @param[in] event_tracer Per-method event tracer to profile/trace methods
137139
* individually. When not given, the event tracer passed to the Module
138140
* constructor is used. Otherwise, this per-method event tracer takes
@@ -143,21 +145,36 @@ class Module {
143145
ET_NODISCARD
144146
runtime::Error load_method(
145147
const std::string& method_name,
146-
torch::executor::EventTracer* event_tracer = nullptr,
147-
runtime::HierarchicalAllocator* planned_memory_allocator = nullptr);
148+
runtime::HierarchicalAllocator* planned_memory = nullptr,
149+
torch::executor::EventTracer* event_tracer = nullptr);
150+
151+
ET_DEPRECATED ET_NODISCARD
152+
runtime::Error inline load_method(
153+
const std::string& method_name,
154+
torch::executor::EventTracer* event_tracer) {
155+
return load_method(method_name, nullptr, event_tracer);
156+
}
148157

149158
/**
150159
* Load the 'forward' method from the program and set up memory management if
151160
* needed. The loaded method is cached to reuse the next time it's executed.
152161
*
162+
* @param[in] planned_memory The memory-planned buffers to use for mutable
163+
* tensor data when executing the 'forward' method.
153164
* @param[in] event_tracer An event tracer used for tracking and logging
154165
* events.
155166
*
156167
* @returns An Error to indicate success or failure.
157168
*/
158169
ET_NODISCARD inline runtime::Error load_forward(
170+
runtime::HierarchicalAllocator* planned_memory = nullptr,
159171
torch::executor::EventTracer* event_tracer = nullptr) {
160-
return load_method("forward", event_tracer);
172+
return load_method("forward", planned_memory, event_tracer);
173+
}
174+
175+
ET_DEPRECATED ET_NODISCARD inline runtime::Error load_forward(
176+
torch::executor::EventTracer* event_tracer) {
177+
return load_forward(nullptr, event_tracer);
161178
}
162179

163180
/**

0 commit comments

Comments (0)