@@ -125,32 +125,38 @@ runtime::Result<std::unordered_set<std::string>> Module::method_names() {
125125
126126runtime::Error Module::load_method (
127127 const std::string& method_name,
128+ runtime::HierarchicalAllocator* planned_memory,
128129 torch::executor::EventTracer* event_tracer) {
129130 if (!is_method_loaded (method_name)) {
130131 ET_CHECK_OK_OR_RETURN_ERROR (load ());
131132
132133 MethodHolder method_holder;
133- const auto method_metadata =
134+
135+ if (!planned_memory) {
136+ const auto method_metadata =
134137 ET_UNWRAP (program_->method_meta (method_name.c_str ()));
135- const auto planned_buffersCount =
136- method_metadata.num_memory_planned_buffers ();
137- method_holder.planned_buffers .reserve (planned_buffersCount);
138- method_holder.planned_spans .reserve (planned_buffersCount);
139-
140- for (auto index = 0 ; index < planned_buffersCount; ++index) {
141- const auto buffer_size =
142- method_metadata.memory_planned_buffer_size (index).get ();
143- method_holder.planned_buffers .emplace_back (buffer_size);
144- method_holder.planned_spans .emplace_back (
145- method_holder.planned_buffers .back ().data (), buffer_size);
138+ const auto planned_buffers_count =
139+ method_metadata.num_memory_planned_buffers ();
140+ method_holder.planned_buffers .reserve (planned_buffers_count);
141+ method_holder.planned_spans .reserve (planned_buffers_count);
142+
143+ for (auto index = 0 ; index < planned_buffers_count; ++index) {
144+ const auto buffer_size =
145+ method_metadata.memory_planned_buffer_size (index).get ();
146+ method_holder.planned_buffers .emplace_back (buffer_size);
147+ method_holder.planned_spans .emplace_back (
148+ method_holder.planned_buffers .back ().data (), buffer_size);
149+ }
150+ method_holder.planned_memory =
151+ std::make_unique<runtime::HierarchicalAllocator>(
152+ runtime::Span (
153+ method_holder.planned_spans .data (),
154+ method_holder.planned_spans .size ()));
155+ planned_memory = method_holder.planned_memory .get ();
146156 }
147- method_holder.planned_memory =
148- std::make_unique<runtime::HierarchicalAllocator>(runtime::Span (
149- method_holder.planned_spans .data (),
150- method_holder.planned_spans .size ()));
151157 method_holder.memory_manager = std::make_unique<runtime::MemoryManager>(
152158 memory_allocator_.get (),
153- method_holder. planned_memory . get () ,
159+ planned_memory,
154160 temp_allocator_.get ());
155161 method_holder.method = ET_UNWRAP_UNIQUE (program_->load_method (
156162 method_name.c_str (),
0 commit comments