Skip to content

Commit f90a836

Browse files
authored
Add set_outputs() API. (#13609)
Summary: . Differential Revision: D80845634
1 parent 87512b8 commit f90a836

File tree

3 files changed

+72
-0
lines changed

3 files changed

+72
-0
lines changed

extension/module/module.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,25 @@ runtime::Error Module::set_output(
278278
output_tensor.mutable_data_ptr(), output_tensor.nbytes(), output_index);
279279
}
280280

281+
runtime::Error Module::set_outputs(
282+
const std::string& method_name,
283+
const std::vector<runtime::EValue>& output_values) {
284+
ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name));
285+
auto& method = methods_.at(method_name).method;
286+
const auto outputs_size = method->outputs_size();
287+
ET_CHECK_OR_RETURN_ERROR(
288+
output_values.size() == outputs_size,
289+
InvalidArgument,
290+
"output size: %zu is not equal to method output size: %zu",
291+
output_values.size(),
292+
outputs_size);
293+
for (auto index = 0; index < outputs_size; ++index) {
294+
ET_CHECK_OK_OR_RETURN_ERROR(
295+
set_output(method_name, output_values[index], index));
296+
}
297+
return runtime::Error::Ok;
298+
}
299+
281300
} // namespace ET_MODULE_NAMESPACE
282301
} // namespace extension
283302
} // namespace executorch

extension/module/module.h

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -498,6 +498,41 @@ class Module {
498498
return set_output("forward", std::move(output_value), output_index);
499499
}
500500

501+
/**
502+
* Sets all output tensors for a specific method.
503+
*
504+
* Loads the program and method if needed, and for each output uses
505+
* the provided tensor's data buffer as the method's output buffer.
506+
*
507+
* @param[in] method_name The name of the method.
508+
* @param[in] output_values A vector of EValues to set as the method outputs.
509+
*
510+
* @returns An Error to indicate success or failure.
511+
*
512+
* @note Only Tensor outputs are currently supported for setting.
513+
* @note Will fail for outputs that are memory-planned or constants.
514+
*/
515+
ET_NODISCARD
516+
runtime::Error set_outputs(
517+
const std::string& method_name,
518+
const std::vector<runtime::EValue>& output_values);
519+
520+
/**
521+
* Sets all output tensors for the "forward" method.
522+
*
523+
* @param[in] output_values A vector of EValues to set as the method outputs.
524+
*
525+
* @returns An Error to indicate success or failure.
526+
*
527+
* @note Only Tensor outputs are currently supported for setting.
528+
* @note Will fail for outputs that are memory-planned or constants.
529+
*/
530+
ET_NODISCARD
531+
inline runtime::Error set_outputs(
532+
const std::vector<runtime::EValue>& output_values) {
533+
return set_outputs("forward", output_values);
534+
}
535+
501536
/**
502537
* Retrieves the EventTracer instance being used by the Module.
503538
* EventTracer is used for tracking and logging events during the execution

extension/module/test/module_test.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -477,6 +477,24 @@ TEST_F(ModuleTest, TestSetOutputInvalidType) {
477477
EXPECT_NE(module.set_output(EValue()), Error::Ok);
478478
}
479479

480+
TEST_F(ModuleTest, TestSetOutputsCountMismatch) {
481+
Module module(model_path_);
482+
483+
EXPECT_NE(module.set_outputs(std::vector<EValue>{}), Error::Ok);
484+
}
485+
486+
TEST_F(ModuleTest, TestSetOutputsInvalidType) {
487+
Module module(model_path_);
488+
489+
EXPECT_NE(module.set_outputs({EValue()}), Error::Ok);
490+
}
491+
492+
TEST_F(ModuleTest, TestSetOutputsMemoryPlanned) {
493+
Module module(model_path_);
494+
495+
EXPECT_NE(module.set_outputs({empty({1})}), Error::Ok);
496+
}
497+
480498
TEST_F(ModuleTest, TestPTD) {
481499
Module module(add_mul_path_, add_mul_data_path_);
482500

0 commit comments

Comments
 (0)