Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 39 additions & 1 deletion extension/module/module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,46 @@ runtime::Result<std::vector<runtime::EValue>> Module::execute(
for (auto index = 0; index < input_values.size(); ++index) {
ET_CHECK_OK_OR_RETURN_ERROR(method->set_input(input_values[index], index));
}
ET_CHECK_OK_OR_RETURN_ERROR(method->execute());

// Set up output storage for non-memory-planned outputs.
const auto outputs_size = method->outputs_size();
auto& method_holder = methods_.at(method_name);
auto& output_storages = method_holder.output_storages;
output_storages.clear();
output_storages.reserve(outputs_size);
auto meta = method->method_meta();
for (size_t i = 0; i < outputs_size; ++i) {
auto output_type = meta.output_tag(i);
if (!output_type.ok()) {
ET_LOG(Error, "Failed to get output type for output %zu", i);
return output_type.error();
}
if (output_type.get() != executorch::runtime::Tag::Tensor) {
// Skip allocating storage for non-tensor outputs.
output_storages.emplace_back();
continue;
}
const auto& output_tensor_meta = meta.output_tensor_meta(i);
if (!output_tensor_meta.ok()) {
ET_LOG(Error, "Failed to get output tensor meta for output %zu", i);
return output_tensor_meta.error();
}
if (output_tensor_meta.get().is_memory_planned()) {
// Skip allocating storage for planned memory outputs.
output_storages.emplace_back();
continue;
}
// Allocate storage for non memory planned output tensor.
const size_t output_size = output_tensor_meta.get().nbytes();
output_storages.emplace_back(output_size);
auto output_status = method->set_output_data_ptr(
output_storages[i].data(), output_storages[i].size(), i);
if (output_status != executorch::runtime::Error::Ok) {
ET_LOG(Error, "Failed to set output data ptr for output %zu", i);
return output_status;
}
}
ET_CHECK_OK_OR_RETURN_ERROR(method->execute());
std::vector<runtime::EValue> outputs(outputs_size);
ET_CHECK_OK_OR_RETURN_ERROR(
method->get_outputs(outputs.data(), outputs_size));
Expand Down
3 changes: 3 additions & 0 deletions extension/module/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -607,6 +607,9 @@ class Module {
std::unique_ptr<runtime::HierarchicalAllocator> planned_memory;
std::unique_ptr<runtime::MemoryManager> memory_manager;
std::unique_ptr<Method> method;
// Keep output storages alive until they can be retrieved.
// Used when output storages are not memory planned.
std::vector<std::vector<uint8_t>> output_storages;
};

std::string file_path_;
Expand Down
Loading