diff --git a/extension/module/module.cpp b/extension/module/module.cpp
index 99cc7e38bd6..598c941bae9 100644
--- a/extension/module/module.cpp
+++ b/extension/module/module.cpp
@@ -125,32 +125,38 @@ runtime::Result<std::unordered_set<std::string>> Module::method_names() {
 
 runtime::Error Module::load_method(
     const std::string& method_name,
+    runtime::HierarchicalAllocator* planned_memory,
     torch::executor::EventTracer* event_tracer) {
   if (!is_method_loaded(method_name)) {
     ET_CHECK_OK_OR_RETURN_ERROR(load());
 
     MethodHolder method_holder;
-    const auto method_metadata =
+
+    if (!planned_memory) {
+      const auto method_metadata =
         ET_UNWRAP(program_->method_meta(method_name.c_str()));
-    const auto planned_buffersCount =
-        method_metadata.num_memory_planned_buffers();
-    method_holder.planned_buffers.reserve(planned_buffersCount);
-    method_holder.planned_spans.reserve(planned_buffersCount);
-
-    for (auto index = 0; index < planned_buffersCount; ++index) {
-      const auto buffer_size =
-          method_metadata.memory_planned_buffer_size(index).get();
-      method_holder.planned_buffers.emplace_back(buffer_size);
-      method_holder.planned_spans.emplace_back(
-          method_holder.planned_buffers.back().data(), buffer_size);
+      const auto planned_buffers_count =
+          method_metadata.num_memory_planned_buffers();
+      method_holder.planned_buffers.reserve(planned_buffers_count);
+      method_holder.planned_spans.reserve(planned_buffers_count);
+
+      for (auto index = 0; index < planned_buffers_count; ++index) {
+        const auto buffer_size =
+            method_metadata.memory_planned_buffer_size(index).get();
+        method_holder.planned_buffers.emplace_back(buffer_size);
+        method_holder.planned_spans.emplace_back(
+            method_holder.planned_buffers.back().data(), buffer_size);
+      }
+      method_holder.planned_memory =
+          std::make_unique<runtime::HierarchicalAllocator>(
+              runtime::Span(
+                  method_holder.planned_spans.data(),
+                  method_holder.planned_spans.size()));
+      planned_memory = method_holder.planned_memory.get();
     }
-    method_holder.planned_memory =
-        std::make_unique<runtime::HierarchicalAllocator>(runtime::Span(
-            method_holder.planned_spans.data(),
-            method_holder.planned_spans.size()));
     method_holder.memory_manager = std::make_unique<runtime::MemoryManager>(
         memory_allocator_.get(),
-        method_holder.planned_memory.get(),
+        planned_memory,
         temp_allocator_.get());
     method_holder.method = ET_UNWRAP_UNIQUE(program_->load_method(
         method_name.c_str(),
diff --git a/extension/module/module.h b/extension/module/module.h
index 45ed38a7ff2..737a8969e67 100644
--- a/extension/module/module.h
+++ b/extension/module/module.h
@@ -133,6 +133,8 @@ class Module {
    * needed. The loaded method is cached to reuse the next time it's executed.
    *
    * @param[in] method_name The name of the method to load.
+   * @param[in] planned_memory The memory-planned buffers to use for mutable
+   * tensor data when executing a method.
    * @param[in] event_tracer Per-method event tracer to profile/trace methods
    * individually. When not given, the event tracer passed to the Module
    * constructor is used. Otherwise, this per-method event tracer takes
@@ -143,20 +145,36 @@ class Module {
   ET_NODISCARD
   runtime::Error load_method(
       const std::string& method_name,
+      runtime::HierarchicalAllocator* planned_memory = nullptr,
       torch::executor::EventTracer* event_tracer = nullptr);
 
+  ET_DEPRECATED ET_NODISCARD
+  runtime::Error inline load_method(
+      const std::string& method_name,
+      torch::executor::EventTracer* event_tracer) {
+    return load_method(method_name, nullptr, event_tracer);
+  }
+
   /**
    * Load the 'forward' method from the program and set up memory management if
    * needed. The loaded method is cached to reuse the next time it's executed.
    *
+   * @param[in] planned_memory The memory-planned buffers to use for mutable
+   * tensor data when executing the 'forward' method.
    * @param[in] event_tracer An event tracer used for tracking and logging
    * events.
    *
    * @returns An Error to indicate success or failure.
    */
   ET_NODISCARD inline runtime::Error load_forward(
+      runtime::HierarchicalAllocator* planned_memory = nullptr,
       torch::executor::EventTracer* event_tracer = nullptr) {
-    return load_method("forward", event_tracer);
+    return load_method("forward", planned_memory, event_tracer);
+  }
+
+  ET_DEPRECATED ET_NODISCARD inline runtime::Error load_forward(
+      torch::executor::EventTracer* event_tracer) {
+    return load_forward(nullptr, event_tracer);
   }
 
   /**
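A minimal usage sketch of the new planned_memory parameter: a caller constructs its own HierarchicalAllocator over caller-owned buffers and hands it to load_forward(). The executorch::extension / executorch::runtime namespace spellings, the model path, and the single 4 MiB buffer below are assumptions for illustration only; real planned-buffer sizes should be taken from the method's MethodMeta (num_memory_planned_buffers() / memory_planned_buffer_size(i)).

// Sketch only: caller-provided planned memory for Module::load_forward().
#include <executorch/extension/module/module.h>
#include <executorch/runtime/core/hierarchical_allocator.h>
#include <executorch/runtime/core/span.h>

#include <array>
#include <cstdint>

using executorch::extension::Module;
using executorch::runtime::Error;
using executorch::runtime::HierarchicalAllocator;
using executorch::runtime::Span;

int main() {
  Module module("model.pte");  // hypothetical model file

  // Caller-owned planned buffer; the size here is a placeholder, not the
  // real memory-planned size reported by MethodMeta.
  static std::array<uint8_t, 4 * 1024 * 1024> buffer;
  std::array<Span<uint8_t>, 1> spans = {
      Span<uint8_t>(buffer.data(), buffer.size())};
  HierarchicalAllocator planned_memory(
      Span<Span<uint8_t>>(spans.data(), spans.size()));

  // New overload from this change: supply the planned memory instead of
  // letting the Module heap-allocate its own planned buffers.
  const Error error = module.load_forward(&planned_memory);
  return error == Error::Ok ? 0 : 1;
}

Note that the caller-owned allocator and the spans it wraps must outlive the loaded method: the Module's MemoryManager only keeps the raw planned_memory pointer, and only the internally created allocator (the !planned_memory path) is owned by the MethodHolder.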