
Commit 06d1584

shoumikhin authored and facebook-github-bot committed
Enhance load_method to support optional planned memory allocator.
Differential Revision: D68875417
1 parent 4796da7

File tree: 2 files changed, +42 −18 lines

extension/module/module.cpp

Lines changed: 23 additions & 17 deletions
@@ -125,32 +125,38 @@ runtime::Result<std::unordered_set<std::string>> Module::method_names() {
 
 runtime::Error Module::load_method(
     const std::string& method_name,
+    runtime::HierarchicalAllocator* planned_memory,
     torch::executor::EventTracer* event_tracer) {
   if (!is_method_loaded(method_name)) {
     ET_CHECK_OK_OR_RETURN_ERROR(load());
 
     MethodHolder method_holder;
-    const auto method_metadata =
+
+    if (!planned_memory) {
+      const auto method_metadata =
         ET_UNWRAP(program_->method_meta(method_name.c_str()));
-    const auto planned_buffersCount =
-        method_metadata.num_memory_planned_buffers();
-    method_holder.planned_buffers.reserve(planned_buffersCount);
-    method_holder.planned_spans.reserve(planned_buffersCount);
-
-    for (auto index = 0; index < planned_buffersCount; ++index) {
-      const auto buffer_size =
-          method_metadata.memory_planned_buffer_size(index).get();
-      method_holder.planned_buffers.emplace_back(buffer_size);
-      method_holder.planned_spans.emplace_back(
-          method_holder.planned_buffers.back().data(), buffer_size);
+      const auto planned_buffers_count =
+          method_metadata.num_memory_planned_buffers();
+      method_holder.planned_buffers.reserve(planned_buffers_count);
+      method_holder.planned_spans.reserve(planned_buffers_count);
+
+      for (auto index = 0; index < planned_buffers_count; ++index) {
+        const auto buffer_size =
+            method_metadata.memory_planned_buffer_size(index).get();
+        method_holder.planned_buffers.emplace_back(buffer_size);
+        method_holder.planned_spans.emplace_back(
+            method_holder.planned_buffers.back().data(), buffer_size);
+      }
+      method_holder.planned_memory =
+          std::make_unique<runtime::HierarchicalAllocator>(
+              runtime::Span(
+                  method_holder.planned_spans.data(),
+                  method_holder.planned_spans.size()));
+      planned_memory = method_holder.planned_memory.get();
     }
-    method_holder.planned_memory =
-        std::make_unique<runtime::HierarchicalAllocator>(runtime::Span(
-            method_holder.planned_spans.data(),
-            method_holder.planned_spans.size()));
     method_holder.memory_manager = std::make_unique<runtime::MemoryManager>(
         memory_allocator_.get(),
-        method_holder.planned_memory.get(),
+        planned_memory,
         temp_allocator_.get());
     method_holder.method = ET_UNWRAP_UNIQUE(program_->load_method(
         method_name.c_str(),
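
For context, the fallback in the diff above builds a HierarchicalAllocator over internally allocated buffers only when the caller passes nullptr, so a caller can now construct the allocator over its own memory and hand it in. Below is a minimal sketch of such a call site: the include paths and namespaces are assumptions based on the ExecuTorch tree around this commit, the model path is a placeholder, and the buffer sizes are hypothetical and must match what MethodMeta::memory_planned_buffer_size() reports for the method.

#include <array>
#include <cstdint>

#include <executorch/extension/module/module.h>
#include <executorch/runtime/core/hierarchical_allocator.h>
#include <executorch/runtime/core/span.h>

using executorch::extension::Module;
using executorch::runtime::Error;
using executorch::runtime::HierarchicalAllocator;
using executorch::runtime::Span;

int main() {
  // Hypothetical sizes: they must match the planned-buffer sizes the
  // program's MethodMeta reports for the method being loaded.
  static std::array<uint8_t, 1024> buffer0;
  static std::array<uint8_t, 2048> buffer1;
  std::array<Span<uint8_t>, 2> spans = {
      Span<uint8_t>(buffer0.data(), buffer0.size()),
      Span<uint8_t>(buffer1.data(), buffer1.size())};

  // Caller-owned planned memory; it must outlive the loaded method,
  // since the runtime keeps a pointer to it rather than a copy.
  HierarchicalAllocator planned_memory(
      Span<Span<uint8_t>>(spans.data(), spans.size()));

  Module module("/path/to/model.pte"); // placeholder path
  // New optional argument: omitting it (or passing nullptr) keeps the old
  // behavior of allocating the planned buffers inside load_method.
  const auto error = module.load_method("forward", &planned_memory);
  return error == Error::Ok ? 0 : 1;
}

Because MethodHolder does not own the spans in this path, the caller is responsible for keeping both the buffers and the allocator alive for as long as the method stays loaded.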

extension/module/module.h

Lines changed: 19 additions & 1 deletion
@@ -133,6 +133,8 @@ class Module {
    * needed. The loaded method is cached to reuse the next time it's executed.
    *
    * @param[in] method_name The name of the method to load.
+   * @param[in] planned_memory The memory-planned buffers to use for mutable
+   * tensor data when executing a method.
    * @param[in] event_tracer Per-method event tracer to profile/trace methods
    * individually. When not given, the event tracer passed to the Module
    * constructor is used. Otherwise, this per-method event tracer takes
@@ -143,20 +145,36 @@
   ET_NODISCARD
   runtime::Error load_method(
       const std::string& method_name,
+      runtime::HierarchicalAllocator* planned_memory = nullptr,
       torch::executor::EventTracer* event_tracer = nullptr);
 
+  ET_DEPRECATED ET_NODISCARD
+  runtime::Error inline load_method(
+      const std::string& method_name,
+      torch::executor::EventTracer* event_tracer) {
+    return load_method(method_name, nullptr, event_tracer);
+  }
+
   /**
    * Load the 'forward' method from the program and set up memory management if
    * needed. The loaded method is cached to reuse the next time it's executed.
    *
+   * @param[in] planned_memory The memory-planned buffers to use for mutable
+   * tensor data when executing the 'forward' method.
    * @param[in] event_tracer An event tracer used for tracking and logging
    * events.
    *
    * @returns An Error to indicate success or failure.
    */
   ET_NODISCARD inline runtime::Error load_forward(
+      runtime::HierarchicalAllocator* planned_memory = nullptr,
       torch::executor::EventTracer* event_tracer = nullptr) {
-    return load_method("forward", event_tracer);
+    return load_method("forward", planned_memory, event_tracer);
+  }
+
+  ET_DEPRECATED ET_NODISCARD inline runtime::Error load_forward(
+      torch::executor::EventTracer* event_tracer) {
+    return load_forward(nullptr, event_tracer);
   }
 
   /**
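
A note on the header change: because planned_memory is inserted before event_tracer, an existing call that passed an event tracer positionally would no longer match the new signature; the ET_DEPRECATED two-argument overloads keep such call sites compiling (with a deprecation warning) instead of breaking them. A hypothetical call site, assuming a Module instance named module and an event tracer named tracer:

// Old-style call: still compiles, but resolves to the ET_DEPRECATED
// overload, because an EventTracer* does not convert to a
// HierarchicalAllocator*.
const auto status = module.load_method("forward", &tracer);

// Preferred spelling after this change: pass the planned-memory slot
// explicitly (nullptr for the internal default) before the event tracer.
const auto status2 =
    module.load_method("forward", /*planned_memory=*/nullptr, &tracer);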
