From ce80bf869a66dd9634246b1f43a5c1b9042ded29 Mon Sep 17 00:00:00 2001 From: cptspacemanspiff <19273992+cptspacemanspiff@users.noreply.github.com> Date: Tue, 7 Jan 2025 11:33:47 -0800 Subject: [PATCH 1/3] Enhance load_method to support optional planned memory allocator - Updated the load_method signature to accept an optional runtime::HierarchicalAllocator parameter. --- extension/module/module.cpp | 50 ++++++++++++++++++++++--------------- extension/module/module.h | 3 ++- 2 files changed, 32 insertions(+), 21 deletions(-) diff --git a/extension/module/module.cpp b/extension/module/module.cpp index aa750e2691e..be84b713129 100644 --- a/extension/module/module.cpp +++ b/extension/module/module.cpp @@ -12,7 +12,9 @@ #include #include #include +#include #include +#include /** * Unwrap a Result to obtain its value (direct object, not a pointer). @@ -178,34 +180,42 @@ runtime::Result> Module::method_names() { runtime::Error Module::load_method( const std::string& method_name, - torch::executor::EventTracer* event_tracer) { + torch::executor::EventTracer* event_tracer, + runtime::HierarchicalAllocator* planned_memory_allocator) { if (!is_method_loaded(method_name)) { ET_CHECK_OK_OR_RETURN_ERROR(load()); MethodHolder method_holder; + runtime::HierarchicalAllocator* planned_memory = nullptr; - const auto method_metadata = - ET_UNWRAP(program_->method_meta(method_name.c_str())); - const auto planned_buffersCount = - method_metadata.num_memory_planned_buffers(); - method_holder.planned_buffers.reserve(planned_buffersCount); - method_holder.planned_spans.reserve(planned_buffersCount); + // we were not given a planned memory allocator, so we need to create one: + if (planned_memory_allocator == nullptr) { + const auto method_metadata = + ET_UNWRAP(program_->method_meta(method_name.c_str())); + const auto planned_buffersCount = + method_metadata.num_memory_planned_buffers(); + method_holder.planned_buffers.reserve(planned_buffersCount); + method_holder.planned_spans.reserve(planned_buffersCount); - for (auto index = 0; index < planned_buffersCount; ++index) { - const auto buffer_size = - method_metadata.memory_planned_buffer_size(index).get(); - method_holder.planned_buffers.emplace_back(buffer_size); - method_holder.planned_spans.emplace_back( - method_holder.planned_buffers.back().data(), buffer_size); + for (auto index = 0; index < planned_buffersCount; ++index) { + const auto buffer_size = + method_metadata.memory_planned_buffer_size(index).get(); + method_holder.planned_buffers.emplace_back(buffer_size); + method_holder.planned_spans.emplace_back( + method_holder.planned_buffers.back().data(), buffer_size); + } + method_holder.planned_memory = + std::make_unique(runtime::Span( + method_holder.planned_spans.data(), + method_holder.planned_spans.size())); + planned_memory = method_holder.planned_memory.get(); + } else { + // we were given a planned memory allocator, so we use it: + planned_memory = planned_memory_allocator; } - method_holder.planned_memory = - std::make_unique(runtime::Span( - method_holder.planned_spans.data(), - method_holder.planned_spans.size())); + method_holder.memory_manager = std::make_unique( - memory_allocator_.get(), - method_holder.planned_memory.get(), - temp_allocator_.get()); + memory_allocator_.get(), planned_memory, temp_allocator_.get()); method_holder.method = ET_UNWRAP_UNIQUE(program_->load_method( method_name.c_str(), method_holder.memory_manager.get(), diff --git a/extension/module/module.h b/extension/module/module.h index dc7c930d7c6..11fe73c8d54 100644 --- a/extension/module/module.h +++ b/extension/module/module.h @@ -162,7 +162,8 @@ class Module { ET_NODISCARD runtime::Error load_method( const std::string& method_name, - torch::executor::EventTracer* event_tracer = nullptr); + torch::executor::EventTracer* event_tracer = nullptr, + runtime::HierarchicalAllocator* planned_memory_allocator = nullptr); /** * Load the 'forward' method from the program and set up memory management if From 3e88cc21e84df7831dfa16099c5186896305b70e Mon Sep 17 00:00:00 2001 From: Nicholas Long <19273992+cptspacemanspiff@users.noreply.github.com> Date: Sun, 9 Feb 2025 05:28:24 -0800 Subject: [PATCH 2/3] Cleaned up methods, deprecated old interfaces. --- extension/module/module.cpp | 35 +++++++++++++++-------------------- extension/module/module.h | 23 ++++++++++++++++++++--- 2 files changed, 35 insertions(+), 23 deletions(-) diff --git a/extension/module/module.cpp b/extension/module/module.cpp index be84b713129..e5879bbcb1b 100644 --- a/extension/module/module.cpp +++ b/extension/module/module.cpp @@ -12,9 +12,7 @@ #include #include #include -#include #include -#include /** * Unwrap a Result to obtain its value (direct object, not a pointer). @@ -180,24 +178,22 @@ runtime::Result> Module::method_names() { runtime::Error Module::load_method( const std::string& method_name, - torch::executor::EventTracer* event_tracer, - runtime::HierarchicalAllocator* planned_memory_allocator) { + runtime::HierarchicalAllocator* planned_memory, + torch::executor::EventTracer* event_tracer) { if (!is_method_loaded(method_name)) { ET_CHECK_OK_OR_RETURN_ERROR(load()); MethodHolder method_holder; - runtime::HierarchicalAllocator* planned_memory = nullptr; - // we were not given a planned memory allocator, so we need to create one: - if (planned_memory_allocator == nullptr) { + if (!planned_memory) { const auto method_metadata = - ET_UNWRAP(program_->method_meta(method_name.c_str())); - const auto planned_buffersCount = + ET_UNWRAP(program_->method_meta(method_name.c_str())); + const auto planned_buffers_count = method_metadata.num_memory_planned_buffers(); - method_holder.planned_buffers.reserve(planned_buffersCount); - method_holder.planned_spans.reserve(planned_buffersCount); + method_holder.planned_buffers.reserve(planned_buffers_count); + method_holder.planned_spans.reserve(planned_buffers_count); - for (auto index = 0; index < planned_buffersCount; ++index) { + for (auto index = 0; index < planned_buffers_count; ++index) { const auto buffer_size = method_metadata.memory_planned_buffer_size(index).get(); method_holder.planned_buffers.emplace_back(buffer_size); @@ -205,17 +201,16 @@ runtime::Error Module::load_method( method_holder.planned_buffers.back().data(), buffer_size); } method_holder.planned_memory = - std::make_unique(runtime::Span( - method_holder.planned_spans.data(), - method_holder.planned_spans.size())); + std::make_unique( + runtime::Span( + method_holder.planned_spans.data(), + method_holder.planned_spans.size())); planned_memory = method_holder.planned_memory.get(); - } else { - // we were given a planned memory allocator, so we use it: - planned_memory = planned_memory_allocator; } - method_holder.memory_manager = std::make_unique( - memory_allocator_.get(), planned_memory, temp_allocator_.get()); + memory_allocator_.get(), + planned_memory, + temp_allocator_.get()); method_holder.method = ET_UNWRAP_UNIQUE(program_->load_method( method_name.c_str(), method_holder.memory_manager.get(), diff --git a/extension/module/module.h b/extension/module/module.h index 11fe73c8d54..74254229cae 100644 --- a/extension/module/module.h +++ b/extension/module/module.h @@ -152,6 +152,8 @@ class Module { * needed. The loaded method is cached to reuse the next time it's executed. * * @param[in] method_name The name of the method to load. + * @param[in] planned_memory The memory-planned buffers to use for mutable + * tensor data when executing a method. * @param[in] event_tracer Per-method event tracer to profile/trace methods * individually. When not given, the event tracer passed to the Module * constructor is used. Otherwise, this per-method event tracer takes @@ -162,21 +164,36 @@ class Module { ET_NODISCARD runtime::Error load_method( const std::string& method_name, - torch::executor::EventTracer* event_tracer = nullptr, - runtime::HierarchicalAllocator* planned_memory_allocator = nullptr); + runtime::HierarchicalAllocator* planned_memory = nullptr, + torch::executor::EventTracer* event_tracer = nullptr); + + ET_DEPRECATED ET_NODISCARD + runtime::Error inline load_method( + const std::string& method_name, + torch::executor::EventTracer* event_tracer) { + return load_method(method_name, nullptr, event_tracer); + } /** * Load the 'forward' method from the program and set up memory management if * needed. The loaded method is cached to reuse the next time it's executed. * + * @param[in] planned_memory The memory-planned buffers to use for mutable + * tensor data when executing the 'forward' method. * @param[in] event_tracer An event tracer used for tracking and logging * events. * * @returns An Error to indicate success or failure. */ ET_NODISCARD inline runtime::Error load_forward( + runtime::HierarchicalAllocator* planned_memory = nullptr, torch::executor::EventTracer* event_tracer = nullptr) { - return load_method("forward", event_tracer); + return load_method("forward", planned_memory, event_tracer); + } + + ET_DEPRECATED ET_NODISCARD inline runtime::Error load_forward( + torch::executor::EventTracer* event_tracer) { + return load_forward(nullptr, event_tracer); } /** From b55c86f4d2b5749dec86326bc12d8781da7c8927 Mon Sep 17 00:00:00 2001 From: Nicholas Long <19273992+cptspacemanspiff@users.noreply.github.com> Date: Sun, 9 Feb 2025 05:34:40 -0800 Subject: [PATCH 3/3] Fixed linter errors. --- extension/module/module.cpp | 13 +++++-------- extension/module/module.h | 7 +++---- 2 files changed, 8 insertions(+), 12 deletions(-) diff --git a/extension/module/module.cpp b/extension/module/module.cpp index e5879bbcb1b..26e74e84364 100644 --- a/extension/module/module.cpp +++ b/extension/module/module.cpp @@ -187,7 +187,7 @@ runtime::Error Module::load_method( if (!planned_memory) { const auto method_metadata = - ET_UNWRAP(program_->method_meta(method_name.c_str())); + ET_UNWRAP(program_->method_meta(method_name.c_str())); const auto planned_buffers_count = method_metadata.num_memory_planned_buffers(); method_holder.planned_buffers.reserve(planned_buffers_count); @@ -201,16 +201,13 @@ runtime::Error Module::load_method( method_holder.planned_buffers.back().data(), buffer_size); } method_holder.planned_memory = - std::make_unique( - runtime::Span( - method_holder.planned_spans.data(), - method_holder.planned_spans.size())); + std::make_unique(runtime::Span( + method_holder.planned_spans.data(), + method_holder.planned_spans.size())); planned_memory = method_holder.planned_memory.get(); } method_holder.memory_manager = std::make_unique( - memory_allocator_.get(), - planned_memory, - temp_allocator_.get()); + memory_allocator_.get(), planned_memory, temp_allocator_.get()); method_holder.method = ET_UNWRAP_UNIQUE(program_->load_method( method_name.c_str(), method_holder.memory_manager.get(), diff --git a/extension/module/module.h b/extension/module/module.h index 74254229cae..d58a447fdba 100644 --- a/extension/module/module.h +++ b/extension/module/module.h @@ -167,10 +167,9 @@ class Module { runtime::HierarchicalAllocator* planned_memory = nullptr, torch::executor::EventTracer* event_tracer = nullptr); - ET_DEPRECATED ET_NODISCARD - runtime::Error inline load_method( - const std::string& method_name, - torch::executor::EventTracer* event_tracer) { + ET_DEPRECATED ET_NODISCARD runtime::Error inline load_method( + const std::string& method_name, + torch::executor::EventTracer* event_tracer) { return load_method(method_name, nullptr, event_tracer); }