|
11 | 11 | #include <executorch/extension/data_loader/file_data_loader.h> |
12 | 12 | #include <executorch/extension/data_loader/mmap_data_loader.h> |
13 | 13 | #include <executorch/extension/memory_allocator/malloc_memory_allocator.h> |
14 | | -#include <executorch/runtime/core/hierarchical_allocator.h> |
15 | 14 | #include <executorch/runtime/platform/runtime.h> |
16 | | -#include <memory> |
17 | 15 |
|
18 | 16 | /** |
19 | 17 | * Unwrap a Result to obtain its value (direct object, not a pointer). |
@@ -127,42 +125,39 @@ runtime::Result<std::unordered_set<std::string>> Module::method_names() { |
127 | 125 |
|
128 | 126 | runtime::Error Module::load_method( |
129 | 127 | const std::string& method_name, |
130 | | - torch::executor::EventTracer* event_tracer, |
131 | | - runtime::HierarchicalAllocator* planned_memory_allocator) { |
| 128 | + runtime::HierarchicalAllocator* planned_memory, |
| 129 | + torch::executor::EventTracer* event_tracer) { |
132 | 130 | if (!is_method_loaded(method_name)) { |
133 | 131 | ET_CHECK_OK_OR_RETURN_ERROR(load()); |
134 | 132 |
|
135 | 133 | MethodHolder method_holder; |
136 | | - runtime::HierarchicalAllocator* planned_memory = nullptr; |
137 | 134 |
|
138 | | - // we were not given a planned memory allocator, so we need to create one: |
139 | | - if (planned_memory_allocator == nullptr) { |
| 135 | + if (!planned_memory) { |
140 | 136 | const auto method_metadata = |
141 | | - ET_UNWRAP(program_->method_meta(method_name.c_str())); |
142 | | - const auto planned_buffersCount = |
| 137 | + ET_UNWRAP(program_->method_meta(method_name.c_str())); |
| 138 | + const auto planned_buffers_count = |
143 | 139 | method_metadata.num_memory_planned_buffers(); |
144 | | - method_holder.planned_buffers.reserve(planned_buffersCount); |
145 | | - method_holder.planned_spans.reserve(planned_buffersCount); |
| 140 | + method_holder.planned_buffers.reserve(planned_buffers_count); |
| 141 | + method_holder.planned_spans.reserve(planned_buffers_count); |
146 | 142 |
|
147 | | - for (auto index = 0; index < planned_buffersCount; ++index) { |
| 143 | + for (auto index = 0; index < planned_buffers_count; ++index) { |
148 | 144 | const auto buffer_size = |
149 | 145 | method_metadata.memory_planned_buffer_size(index).get(); |
150 | 146 | method_holder.planned_buffers.emplace_back(buffer_size); |
151 | 147 | method_holder.planned_spans.emplace_back( |
152 | 148 | method_holder.planned_buffers.back().data(), buffer_size); |
153 | 149 | } |
154 | 150 | method_holder.planned_memory = |
155 | | - std::make_unique<runtime::HierarchicalAllocator>(runtime::Span( |
156 | | - method_holder.planned_spans.data(), |
157 | | - method_holder.planned_spans.size())); |
| 151 | + std::make_unique<runtime::HierarchicalAllocator>( |
| 152 | + runtime::Span( |
| 153 | + method_holder.planned_spans.data(), |
| 154 | + method_holder.planned_spans.size())); |
158 | 155 | planned_memory = method_holder.planned_memory.get(); |
159 | | - } else { |
160 | | - // we were given a planned memory allocator, so we use it: |
161 | | - planned_memory = planned_memory_allocator; |
162 | 156 | } |
163 | | - |
164 | 157 | method_holder.memory_manager = std::make_unique<runtime::MemoryManager>( |
165 | | - memory_allocator_.get(), planned_memory, temp_allocator_.get()); |
| 158 | + memory_allocator_.get(), |
| 159 | + planned_memory, |
| 160 | + temp_allocator_.get()); |
166 | 161 | method_holder.method = ET_UNWRAP_UNIQUE(program_->load_method( |
167 | 162 | method_name.c_str(), |
168 | 163 | method_holder.memory_manager.get(), |
|
0 commit comments