From 62df7e9c9b93a0e6268fc5a819bf6a4ecef55748 Mon Sep 17 00:00:00 2001 From: lucylq Date: Tue, 26 Aug 2025 11:50:08 -0700 Subject: [PATCH] Add output buffers to Module ^ Differential Revision: [D80996079](https://our.internmc.facebook.com/intern/diff/D80996079/) [ghstack-poisoned] --- extension/module/module.cpp | 40 ++++++++++++++++++++++++++++++++++++- extension/module/module.h | 3 +++ 2 files changed, 42 insertions(+), 1 deletion(-) diff --git a/extension/module/module.cpp b/extension/module/module.cpp index 4b82dbf4954..b3f3ba26962 100644 --- a/extension/module/module.cpp +++ b/extension/module/module.cpp @@ -235,8 +235,46 @@ runtime::Result> Module::execute( for (auto index = 0; index < input_values.size(); ++index) { ET_CHECK_OK_OR_RETURN_ERROR(method->set_input(input_values[index], index)); } - ET_CHECK_OK_OR_RETURN_ERROR(method->execute()); + + // Set up output storage for non-memory-planned outputs. const auto outputs_size = method->outputs_size(); + auto& method_holder = methods_.at(method_name); + auto& output_storages = method_holder.output_storages; + output_storages.clear(); + output_storages.reserve(outputs_size); + auto meta = method->method_meta(); + for (size_t i = 0; i < outputs_size; ++i) { + auto output_type = meta.output_tag(i); + if (!output_type.ok()) { + ET_LOG(Error, "Failed to get output type for output %zu", i); + return output_type.error(); + } + if (output_type.get() != executorch::runtime::Tag::Tensor) { + // Skip allocating storage for non-tensor outputs. + output_storages.emplace_back(); + continue; + } + const auto& output_tensor_meta = meta.output_tensor_meta(i); + if (!output_tensor_meta.ok()) { + ET_LOG(Error, "Failed to get output tensor meta for output %zu", i); + return output_tensor_meta.error(); + } + if (output_tensor_meta.get().is_memory_planned()) { + // Skip allocating storage for planned memory outputs. + output_storages.emplace_back(); + continue; + } + // Allocate storage for non memory planned output tensor. + const size_t output_size = output_tensor_meta.get().nbytes(); + output_storages.emplace_back(output_size); + auto output_status = method->set_output_data_ptr( + output_storages[i].data(), output_storages[i].size(), i); + if (output_status != executorch::runtime::Error::Ok) { + ET_LOG(Error, "Failed to set output data ptr for output %zu", i); + return output_status; + } + } + ET_CHECK_OK_OR_RETURN_ERROR(method->execute()); std::vector outputs(outputs_size); ET_CHECK_OK_OR_RETURN_ERROR( method->get_outputs(outputs.data(), outputs_size)); diff --git a/extension/module/module.h b/extension/module/module.h index 37fd78f6fdd..4d8936e8af0 100644 --- a/extension/module/module.h +++ b/extension/module/module.h @@ -607,6 +607,9 @@ class Module { std::unique_ptr planned_memory; std::unique_ptr memory_manager; std::unique_ptr method; + // Keep output storages alive until they can be retrieved. + // Used when output storages are not memory planned. + std::vector> output_storages; }; std::string file_path_;