diff --git a/extension/pybindings/pybindings.cpp b/extension/pybindings/pybindings.cpp
index dbb5bf0345b..a2a65787cb4 100644
--- a/extension/pybindings/pybindings.cpp
+++ b/extension/pybindings/pybindings.cpp
@@ -147,19 +147,19 @@ void setup_output_storage(
   }
   for (size_t i = 0; i < output_storages.size(); ++i) {
     if (output_storages[i].size() == 0) {
-      // Skip empty output storages, this would happen for non-tensor outputs.
+      // Skip empty output storages, this would happen for non-tensor outputs
+      // and memory planned outputs.
       continue;
     }
     Error output_status = method.set_output_data_ptr(
         output_storages[i].data(), output_storages[i].size(), i);
-    // InvalidState can be the status if outputs are already memory planned.
-    // That's fine and we don't need to alert the user to that error.
-    if (output_status != Error::Ok && output_status != Error::InvalidState) {
-      ET_LOG(
-          Error,
-          "Cannot set_output_data_ptr(): 0x%" PRIx32,
-          static_cast<uint32_t>(output_status));
-    }
+    // We already should be skipping non-tensor outputs, and memory planned
+    // outputs so any error is real.
+    THROW_IF_ERROR(
+        output_status,
+        "set_output_data_ptr failed for output %zu with error 0x%" PRIx32,
+        i,
+        static_cast<uint32_t>(output_status));
   }
 }
 
@@ -890,26 +890,34 @@ struct PyModule final {
 
   std::vector<std::vector<uint8_t>> make_output_storages(const Method& method) {
     const auto num_outputs = method.outputs_size();
-    // These output storages will not be used if the ExecuTorch program already
-    // pre-allocated output space. That is represented by an error from
-    // set_output_data_ptr.
-    std::vector<std::vector<uint8_t>> output_storages(num_outputs);
+    // Create a buffer for each output tensor. Memory planned outputs and non
+    // tensor outputs get an empty buffer in this list which is ignored later.
+    std::vector<std::vector<uint8_t>> output_storages;
+    output_storages.reserve(num_outputs);
+    auto meta = method.method_meta();
     for (size_t i = 0; i < num_outputs; ++i) {
+      auto output_type = meta.output_tag(i);
+      THROW_IF_ERROR(
+          output_type.error(), "Failed to get output type for output %zu", i);
+      if (output_type.get() != Tag::Tensor) {
+        // Skip allocating storage for non-tensor outputs.
+        output_storages.emplace_back();
+        continue;
+      }
       const auto& output_tensor_meta =
           method.method_meta().output_tensor_meta(i);
-      if (!output_tensor_meta.ok()) {
-        // If the output isn't a tensor it won't have a tensor meta.
-        ET_LOG(
-            Error,
-            "Tensor meta doesn't exist for output %zu, error is 0x%" PRIx32
-            ", skipping allocating storage",
-            i,
-            static_cast<uint32_t>(output_tensor_meta.error()));
-        output_storages[i] = std::vector<uint8_t>();
+      THROW_IF_ERROR(
+          output_tensor_meta.error(),
+          "Failed to get output tensor meta for output %zu",
+          i);
+      if (output_tensor_meta.get().is_memory_planned()) {
+        // Skip allocating storage for planned memory outputs.
+        output_storages.emplace_back();
         continue;
       }
+      // Allocate storage for the output tensor.
       const size_t output_size = output_tensor_meta.get().nbytes();
-      output_storages[i] = std::vector<uint8_t>(output_size);
+      output_storages.emplace_back(output_size);
     }
     return output_storages;
   }