Skip to content

Commit ce80bf8

Browse files
Enhance load_method to support optional planned memory allocator
- Updated the load_method signature to accept an optional runtime::HierarchicalAllocator* parameter (defaulting to nullptr, so existing callers are unaffected). When no allocator is supplied, load_method creates the planned memory buffers and HierarchicalAllocator itself, as before; when one is supplied, it is used directly for the method's MemoryManager.
1 parent 19a3002 commit ce80bf8

File tree

2 files changed

+32
-21
lines changed

2 files changed

+32
-21
lines changed

extension/module/module.cpp

Lines changed: 30 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212
#include <executorch/extension/data_loader/mmap_data_loader.h>
1313
#include <executorch/extension/flat_tensor/flat_tensor_data_map.h>
1414
#include <executorch/extension/memory_allocator/malloc_memory_allocator.h>
15+
#include <executorch/runtime/core/hierarchical_allocator.h>
1516
#include <executorch/runtime/platform/runtime.h>
17+
#include <memory>
1618

1719
/**
1820
* Unwrap a Result to obtain its value (direct object, not a pointer).
@@ -178,34 +180,42 @@ runtime::Result<std::unordered_set<std::string>> Module::method_names() {
178180

179181
runtime::Error Module::load_method(
180182
const std::string& method_name,
181-
torch::executor::EventTracer* event_tracer) {
183+
torch::executor::EventTracer* event_tracer,
184+
runtime::HierarchicalAllocator* planned_memory_allocator) {
182185
if (!is_method_loaded(method_name)) {
183186
ET_CHECK_OK_OR_RETURN_ERROR(load());
184187

185188
MethodHolder method_holder;
189+
runtime::HierarchicalAllocator* planned_memory = nullptr;
186190

187-
const auto method_metadata =
188-
ET_UNWRAP(program_->method_meta(method_name.c_str()));
189-
const auto planned_buffersCount =
190-
method_metadata.num_memory_planned_buffers();
191-
method_holder.planned_buffers.reserve(planned_buffersCount);
192-
method_holder.planned_spans.reserve(planned_buffersCount);
191+
// we were not given a planned memory allocator, so we need to create one:
192+
if (planned_memory_allocator == nullptr) {
193+
const auto method_metadata =
194+
ET_UNWRAP(program_->method_meta(method_name.c_str()));
195+
const auto planned_buffersCount =
196+
method_metadata.num_memory_planned_buffers();
197+
method_holder.planned_buffers.reserve(planned_buffersCount);
198+
method_holder.planned_spans.reserve(planned_buffersCount);
193199

194-
for (auto index = 0; index < planned_buffersCount; ++index) {
195-
const auto buffer_size =
196-
method_metadata.memory_planned_buffer_size(index).get();
197-
method_holder.planned_buffers.emplace_back(buffer_size);
198-
method_holder.planned_spans.emplace_back(
199-
method_holder.planned_buffers.back().data(), buffer_size);
200+
for (auto index = 0; index < planned_buffersCount; ++index) {
201+
const auto buffer_size =
202+
method_metadata.memory_planned_buffer_size(index).get();
203+
method_holder.planned_buffers.emplace_back(buffer_size);
204+
method_holder.planned_spans.emplace_back(
205+
method_holder.planned_buffers.back().data(), buffer_size);
206+
}
207+
method_holder.planned_memory =
208+
std::make_unique<runtime::HierarchicalAllocator>(runtime::Span(
209+
method_holder.planned_spans.data(),
210+
method_holder.planned_spans.size()));
211+
planned_memory = method_holder.planned_memory.get();
212+
} else {
213+
// we were given a planned memory allocator, so we use it:
214+
planned_memory = planned_memory_allocator;
200215
}
201-
method_holder.planned_memory =
202-
std::make_unique<runtime::HierarchicalAllocator>(runtime::Span(
203-
method_holder.planned_spans.data(),
204-
method_holder.planned_spans.size()));
216+
205217
method_holder.memory_manager = std::make_unique<runtime::MemoryManager>(
206-
memory_allocator_.get(),
207-
method_holder.planned_memory.get(),
208-
temp_allocator_.get());
218+
memory_allocator_.get(), planned_memory, temp_allocator_.get());
209219
method_holder.method = ET_UNWRAP_UNIQUE(program_->load_method(
210220
method_name.c_str(),
211221
method_holder.memory_manager.get(),

extension/module/module.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,8 @@ class Module {
162162
ET_NODISCARD
163163
runtime::Error load_method(
164164
const std::string& method_name,
165-
torch::executor::EventTracer* event_tracer = nullptr);
165+
torch::executor::EventTracer* event_tracer = nullptr,
166+
runtime::HierarchicalAllocator* planned_memory_allocator = nullptr);
166167

167168
/**
168169
* Load the 'forward' method from the program and set up memory management if

0 commit comments

Comments
 (0)