|
12 | 12 | #include <executorch/extension/data_loader/mmap_data_loader.h> |
13 | 13 | #include <executorch/extension/flat_tensor/flat_tensor_data_map.h> |
14 | 14 | #include <executorch/extension/memory_allocator/malloc_memory_allocator.h> |
| 15 | +#include <executorch/runtime/core/hierarchical_allocator.h> |
15 | 16 | #include <executorch/runtime/platform/runtime.h> |
| 17 | +#include <memory> |
16 | 18 |
|
17 | 19 | /** |
18 | 20 | * Unwrap a Result to obtain its value (direct object, not a pointer). |
@@ -178,34 +180,42 @@ runtime::Result<std::unordered_set<std::string>> Module::method_names() { |
178 | 180 |
|
179 | 181 | runtime::Error Module::load_method( |
180 | 182 | const std::string& method_name, |
181 | | - torch::executor::EventTracer* event_tracer) { |
| 183 | + torch::executor::EventTracer* event_tracer, |
| 184 | + runtime::HierarchicalAllocator* planned_memory_allocator) { |
182 | 185 | if (!is_method_loaded(method_name)) { |
183 | 186 | ET_CHECK_OK_OR_RETURN_ERROR(load()); |
184 | 187 |
|
185 | 188 | MethodHolder method_holder; |
| 189 | + runtime::HierarchicalAllocator* planned_memory = nullptr; |
186 | 190 |
|
187 | | - const auto method_metadata = |
188 | | - ET_UNWRAP(program_->method_meta(method_name.c_str())); |
189 | | - const auto planned_buffersCount = |
190 | | - method_metadata.num_memory_planned_buffers(); |
191 | | - method_holder.planned_buffers.reserve(planned_buffersCount); |
192 | | - method_holder.planned_spans.reserve(planned_buffersCount); |
| 191 | + // we were not given a planned memory allocator, so we need to create one: |
| 192 | + if (planned_memory_allocator == nullptr) { |
| 193 | + const auto method_metadata = |
| 194 | + ET_UNWRAP(program_->method_meta(method_name.c_str())); |
| 195 | + const auto planned_buffersCount = |
| 196 | + method_metadata.num_memory_planned_buffers(); |
| 197 | + method_holder.planned_buffers.reserve(planned_buffersCount); |
| 198 | + method_holder.planned_spans.reserve(planned_buffersCount); |
193 | 199 |
|
194 | | - for (auto index = 0; index < planned_buffersCount; ++index) { |
195 | | - const auto buffer_size = |
196 | | - method_metadata.memory_planned_buffer_size(index).get(); |
197 | | - method_holder.planned_buffers.emplace_back(buffer_size); |
198 | | - method_holder.planned_spans.emplace_back( |
199 | | - method_holder.planned_buffers.back().data(), buffer_size); |
| 200 | + for (auto index = 0; index < planned_buffersCount; ++index) { |
| 201 | + const auto buffer_size = |
| 202 | + method_metadata.memory_planned_buffer_size(index).get(); |
| 203 | + method_holder.planned_buffers.emplace_back(buffer_size); |
| 204 | + method_holder.planned_spans.emplace_back( |
| 205 | + method_holder.planned_buffers.back().data(), buffer_size); |
| 206 | + } |
| 207 | + method_holder.planned_memory = |
| 208 | + std::make_unique<runtime::HierarchicalAllocator>(runtime::Span( |
| 209 | + method_holder.planned_spans.data(), |
| 210 | + method_holder.planned_spans.size())); |
| 211 | + planned_memory = method_holder.planned_memory.get(); |
| 212 | + } else { |
| 213 | + // we were given a planned memory allocator, so we use it: |
| 214 | + planned_memory = planned_memory_allocator; |
200 | 215 | } |
201 | | - method_holder.planned_memory = |
202 | | - std::make_unique<runtime::HierarchicalAllocator>(runtime::Span( |
203 | | - method_holder.planned_spans.data(), |
204 | | - method_holder.planned_spans.size())); |
| 216 | + |
205 | 217 | method_holder.memory_manager = std::make_unique<runtime::MemoryManager>( |
206 | | - memory_allocator_.get(), |
207 | | - method_holder.planned_memory.get(), |
208 | | - temp_allocator_.get()); |
| 218 | + memory_allocator_.get(), planned_memory, temp_allocator_.get()); |
209 | 219 | method_holder.method = ET_UNWRAP_UNIQUE(program_->load_method( |
210 | 220 | method_name.c_str(), |
211 | 221 | method_holder.memory_manager.get(), |
|
0 commit comments