Skip to content

Commit bbc281f

Browse files
authored
Add get_output API. (#13610)
Summary: . Differential Revision: D80845633
1 parent 3bb42ec commit bbc281f

File tree

4 files changed

+102
-1
lines changed

4 files changed

+102
-1
lines changed

extension/module/module.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,30 @@ runtime::Error Module::set_outputs(
297297
return runtime::Error::Ok;
298298
}
299299

300+
runtime::Result<std::vector<runtime::EValue>> Module::get_outputs(
301+
const std::string& method_name) {
302+
ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name));
303+
auto& method = methods_.at(method_name).method;
304+
const auto outputs_size = method->outputs_size();
305+
std::vector<runtime::EValue> outputs(outputs_size);
306+
ET_CHECK_OK_OR_RETURN_ERROR(
307+
method->get_outputs(outputs.data(), outputs_size));
308+
return outputs;
309+
}
310+
311+
runtime::Result<runtime::EValue> Module::get_output(
312+
const std::string& method_name,
313+
size_t output_index) {
314+
ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name));
315+
auto& method = methods_.at(method_name).method;
316+
ET_CHECK_OR_RETURN_ERROR(
317+
output_index < method->outputs_size(),
318+
InvalidArgument,
319+
"output index: %zu is out of range",
320+
output_index);
321+
return method->get_output(output_index);
322+
}
323+
300324
} // namespace ET_MODULE_NAMESPACE
301325
} // namespace extension
302326
} // namespace executorch

extension/module/module.h

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -533,6 +533,56 @@ class Module {
533533
return set_outputs("forward", output_values);
534534
}
535535

536+
/**
537+
* Retrieve all current output values of a specific method without executing
538+
* it. Loads the program and method before retrieval if needed.
539+
*
540+
* @param[in] method_name The name of the method.
541+
*
542+
* @returns A Result containing the vector of output values, or an error.
543+
*/
544+
ET_NODISCARD
545+
runtime::Result<std::vector<runtime::EValue>> get_outputs(
546+
const std::string& method_name);
547+
548+
/**
549+
* Retrieve all current output values of the "forward" method without
550+
* executing it. Loads the program and method before retrieval if needed.
551+
*
552+
* @returns A Result containing the vector of output values, or an error.
553+
*/
554+
ET_NODISCARD
555+
inline runtime::Result<std::vector<runtime::EValue>> get_outputs() {
556+
return get_outputs("forward");
557+
}
558+
559+
/**
560+
* Retrieve a single current output value of a specific method without
561+
* executing it. Loads the program and method before retrieval if needed.
562+
*
563+
* @param[in] method_name The name of the method.
564+
* @param[in] output_index Zero-based index of the output to retrieve.
565+
*
566+
* @returns A Result containing the requested output value, or an error.
567+
*/
568+
ET_NODISCARD
569+
runtime::Result<runtime::EValue> get_output(
570+
const std::string& method_name,
571+
size_t output_index = 0);
572+
573+
/**
574+
* Retrieve a single current output value of the "forward" method without
575+
* executing it. Loads the program and method before retrieval if needed.
576+
*
577+
* @param[in] output_index Zero-based index of the output to retrieve.
578+
*
579+
* @returns A Result containing the requested output value, or an error.
580+
*/
581+
ET_NODISCARD
582+
inline runtime::Result<runtime::EValue> get_output(size_t output_index = 0) {
583+
return get_output("forward", output_index);
584+
}
585+
536586
/**
537587
* Retrieves the EventTracer instance being used by the Module.
538588
* EventTracer is used for tracking and logging events during the execution

extension/module/test/module_test.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -495,6 +495,33 @@ TEST_F(ModuleTest, TestSetOutputsMemoryPlanned) {
495495
EXPECT_NE(module.set_outputs({empty({1})}), Error::Ok);
496496
}
497497

498+
TEST_F(ModuleTest, TestGetOutputAndGetOutputs) {
499+
Module module(model_path_);
500+
501+
auto tensor = make_tensor_ptr({2, 2}, {1.f, 2.f, 3.f, 4.f});
502+
503+
ASSERT_EQ(module.forward({tensor, tensor, 1.0}).error(), Error::Ok);
504+
505+
const auto single = module.get_output();
506+
EXPECT_EQ(single.error(), Error::Ok);
507+
const auto expected = make_tensor_ptr({2, 2}, {2.f, 4.f, 6.f, 8.f});
508+
EXPECT_TENSOR_CLOSE(single->toTensor(), *expected.get());
509+
510+
const auto all = module.get_outputs();
511+
EXPECT_EQ(all.error(), Error::Ok);
512+
ASSERT_EQ(all->size(), 1);
513+
EXPECT_TENSOR_CLOSE(all->at(0).toTensor(), *expected.get());
514+
}
515+
516+
TEST_F(ModuleTest, TestGetOutputInvalidIndex) {
517+
Module module(model_path_);
518+
519+
ASSERT_EQ(module.load_method("forward"), Error::Ok);
520+
521+
const auto bad = module.get_output("forward", 99);
522+
EXPECT_NE(bad.error(), Error::Ok);
523+
}
524+
498525
TEST_F(ModuleTest, TestPTD) {
499526
Module module(add_mul_path_, add_mul_data_path_);
500527

runtime/executor/method.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1278,7 +1278,7 @@ ET_NODISCARD Error Method::get_outputs(EValue* output_evalues, size_t length) {
12781278
InvalidArgument,
12791279
"The given array is not large enough to hold all outputs.");
12801280
for (size_t i = 0; i < n_output; ++i) {
1281-
output_evalues[i] = values_[get_output_index(i)];
1281+
output_evalues[i] = get_output(i);
12821282
}
12831283
for (size_t i = n_output; i < length; ++i) {
12841284
output_evalues[i] = EValue();

0 commit comments

Comments
 (0)