diff --git a/backends/xnnpack/test/runtime/test_xnn_data_separation.cpp b/backends/xnnpack/test/runtime/test_xnn_data_separation.cpp index 342e3478e0f..85cac66c62d 100644 --- a/backends/xnnpack/test/runtime/test_xnn_data_separation.cpp +++ b/backends/xnnpack/test/runtime/test_xnn_data_separation.cpp @@ -108,6 +108,21 @@ TEST_F(DataSeparationTest, TestE2E) { "forward", &mmm.get(), nullptr, linear_data_map_.get()); ASSERT_EQ(method.error(), Error::Ok); + // Set a dummy input. + int32_t sizes[1] = {3}; + uint8_t dim_order[1] = {0}; + int32_t strides[1] = {1}; + executorch::aten::TensorImpl impl( + executorch::aten::ScalarType::Float, + 1, + sizes, + nullptr, + dim_order, + strides); + auto input_err = method->set_input( + executorch::runtime::EValue(executorch::aten::Tensor(&impl)), 0); + ASSERT_EQ(input_err, Error::Ok); + // Can execute the method. Error err = method->execute(); ASSERT_EQ(err, Error::Ok); diff --git a/extension/module/module.cpp b/extension/module/module.cpp index 0a33deabd9e..76304d20e25 100644 --- a/extension/module/module.cpp +++ b/extension/module/module.cpp @@ -210,7 +210,6 @@ runtime::Error Module::load_method( method_holder.memory_manager.get(), event_tracer ? event_tracer : this->event_tracer(), data_map_.get())); - method_holder.inputs.resize(method_holder.method->inputs_size()); methods_.emplace(method_name, std::move(method_holder)); } return runtime::Error::Ok; @@ -233,28 +232,10 @@ runtime::Result> Module::execute( const std::vector& input_values) { ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name)); auto& method = methods_.at(method_name).method; - auto& inputs = methods_.at(method_name).inputs; - - ET_CHECK_OR_RETURN_ERROR( - input_values.size() <= inputs.size(), - InvalidArgument, - "input size: %zu does not match method input size: %zu", - input_values.size(), - inputs.size()); - for (size_t i = 0; i < input_values.size(); ++i) { - if (!input_values[i].isNone()) { - inputs[i] = input_values[i]; - } + for (auto index = 0; index < input_values.size(); ++index) { + ET_CHECK_OK_OR_RETURN_ERROR(method->set_input(input_values[index], index)); } - for (size_t i = 0; i < inputs.size(); ++i) { - ET_CHECK_OR_RETURN_ERROR( - !inputs[i].isNone(), InvalidArgument, "input %zu is none", i); - } - ET_CHECK_OK_OR_RETURN_ERROR( - method->set_inputs(executorch::aten::ArrayRef( - inputs.data(), inputs.size()))); ET_CHECK_OK_OR_RETURN_ERROR(method->execute()); - const auto outputs_size = method->outputs_size(); std::vector outputs(outputs_size); ET_CHECK_OK_OR_RETURN_ERROR( @@ -268,23 +249,17 @@ runtime::Error Module::set_input( const runtime::EValue& input_value, size_t input_index) { ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name)); - methods_.at(method_name).inputs.at(input_index) = input_value; - return runtime::Error::Ok; + auto& method = methods_.at(method_name).method; + return method->set_input(input_value, input_index); } runtime::Error Module::set_inputs( const std::string& method_name, const std::vector& input_values) { ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name)); - auto& inputs = methods_.at(method_name).inputs; - ET_CHECK_OR_RETURN_ERROR( - inputs.size() == input_values.size(), - InvalidArgument, - "input size: %zu does not match method input size: %zu", - input_values.size(), - inputs.size()); - inputs = input_values; - return runtime::Error::Ok; + auto& method = methods_.at(method_name).method; + return method->set_inputs(executorch::aten::ArrayRef( + input_values.data(), input_values.size())); } runtime::Error Module::set_output( diff --git a/extension/module/module.h b/extension/module/module.h index 9177eb9c95d..9350cdd3026 100644 --- a/extension/module/module.h +++ b/extension/module/module.h @@ -522,7 +522,6 @@ class Module { std::unique_ptr planned_memory; std::unique_ptr memory_manager; std::unique_ptr method; - std::vector inputs; }; std::string file_path_; diff --git a/extension/module/test/module_test.cpp b/extension/module/test/module_test.cpp index 8e6e7fa6c7b..9623e5a6745 100644 --- a/extension/module/test/module_test.cpp +++ b/extension/module/test/module_test.cpp @@ -267,7 +267,7 @@ TEST_F(ModuleTest, TestForward) { EXPECT_TENSOR_CLOSE(result->at(0).toTensor(), *expected.get()); auto tensor2 = make_tensor_ptr({2, 2}, {2.f, 3.f, 4.f, 5.f}); - const auto result2 = module->forward({tensor2, tensor2}); + const auto result2 = module->forward({tensor2, tensor2, 1.0}); EXPECT_EQ(result2.error(), Error::Ok); const auto expected2 = make_tensor_ptr({2, 2}, {4.f, 6.f, 8.f, 10.f}); diff --git a/extension/runner_util/test/inputs_test.cpp b/extension/runner_util/test/inputs_test.cpp index 7d6799fa9ab..aa3af2e145b 100644 --- a/extension/runner_util/test/inputs_test.cpp +++ b/extension/runner_util/test/inputs_test.cpp @@ -75,6 +75,8 @@ class InputsTest : public ::testing::Test { TEST_F(InputsTest, Smoke) { Result input_buffers = prepare_input_tensors(*method_); ASSERT_EQ(input_buffers.error(), Error::Ok); + auto input_err = method_->set_input(executorch::runtime::EValue(1.0), 2); + ASSERT_EQ(input_err, Error::Ok); // We can't look at the input tensors, but we can check that the outputs make // sense after executing the method. diff --git a/runtime/core/error.h b/runtime/core/error.h index 0450476ea93..b75f107314d 100644 --- a/runtime/core/error.h +++ b/runtime/core/error.h @@ -205,42 +205,37 @@ using ::executorch::runtime::error_code_t; * @param[in] ... Optional format string for the log error message and its * arguments. */ -#define ET_CHECK_OK_OR_RETURN_ERROR(error__, ...) \ - ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR(error__, ##__VA_ARGS__) - -// Internal only: Use ET_CHECK_OK_OR_RETURN_ERROR() instead. -#define ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR(...) \ - ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_SELECT( \ - __VA_ARGS__, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1) \ - (__VA_ARGS__) +#define ET_CHECK_OK_OR_RETURN_ERROR(...) \ + ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR(__VA_ARGS__) /** * Internal only: Use ET_CHECK_OK_OR_RETURN_ERROR() instead. * This macro selects the correct version of * ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR based on the number of arguments passed. - * It uses a trick with the preprocessor to count the number of arguments and - * then selects the appropriate macro. - * - * The macro expansion uses __VA_ARGS__ to accept any number of arguments and - * then appends them to ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_, followed by the - * count of arguments. The count is determined by the macro - * ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_SELECT which takes the arguments and - * passes them along with a sequence of numbers (2, 1). The preprocessor then - * matches this sequence to the correct number of arguments provided. - * - * If two arguments are passed, ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2 is - * selected, suitable for cases where an error code and a custom message are - * provided. If only one argument is passed, - * ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_1 is selected, which is used for cases - * with just an error code. - * - * Usage: - * ET_CHECK_OK_OR_RETURN_ERROR(error_code); // Calls v1 - * ET_CHECK_OK_OR_RETURN_ERROR(error_code, "Error message", ...); // Calls v2 + * It uses a helper that reliably picks the 1-arg or 2+-arg form on + * MSVC/Clang/GCC. */ -#define ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_SELECT( \ - _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, N, ...) \ - ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_##N +#define ET_INTERNAL_EXPAND(x) x +#define ET_INTERNAL_GET_MACRO( \ + _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, NAME, ...) \ + NAME + +// Internal only: Use ET_CHECK_OK_OR_RETURN_ERROR() instead. +// Picks _2 for 2..10 args, _1 for exactly 1 arg. +#define ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR(...) \ + ET_INTERNAL_EXPAND(ET_INTERNAL_GET_MACRO( \ + __VA_ARGS__, \ + ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2, /* 10 */ \ + ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2, /* 9 */ \ + ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2, /* 8 */ \ + ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2, /* 7 */ \ + ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2, /* 6 */ \ + ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2, /* 5 */ \ + ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2, /* 4 */ \ + ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2, /* 3 */ \ + ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2, /* 2 */ \ + ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_1 /* 1 */ \ + )(__VA_ARGS__)) // Internal only: Use ET_CHECK_OK_OR_RETURN_ERROR() instead. #define ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_1(error__) \ @@ -260,21 +255,3 @@ using ::executorch::runtime::error_code_t; return et_error__; \ } \ } while (0) - -// Internal only: Use ET_CHECK_OK_OR_RETURN_ERROR() instead. -#define ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_3 \ - ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2 -#define ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_4 \ - ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2 -#define ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_5 \ - ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2 -#define ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_6 \ - ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2 -#define ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_7 \ - ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2 -#define ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_8 \ - ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2 -#define ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_9 \ - ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2 -#define ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_10 \ - ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2 diff --git a/runtime/executor/method.cpp b/runtime/executor/method.cpp index 7d35ebe5054..2be5b92f418 100644 --- a/runtime/executor/method.cpp +++ b/runtime/executor/method.cpp @@ -407,11 +407,22 @@ Error Method::parse_values(const NamedDataMap* external_data_map) { auto flatbuffer_values = serialization_plan_->values(); ET_CHECK_OR_RETURN_ERROR( flatbuffer_values != nullptr, InvalidProgram, "Missing values"); - size_t n_value = flatbuffer_values->size(); + const size_t n_value = flatbuffer_values->size(); values_ = memory_manager_->method_allocator()->allocateList(n_value); if (values_ == nullptr) { return Error::MemoryAllocationFailed; } + const size_t n_input = inputs_size(); + if (n_input > 0) { + input_set_ = + memory_manager_->method_allocator()->allocateList(n_input); + if (input_set_ == nullptr) { + return Error::MemoryAllocationFailed; + } + for (size_t i = 0; i < n_input; ++i) { + input_set_[i] = false; + } + } // Count the number of tensors marked as EXTERNAL for this method. The actual // number of external constants may be smaller, eg. if multiple tensors point @@ -1076,26 +1087,22 @@ Method::set_input(const EValue& input_evalue, size_t input_idx) { executorch::runtime::toString(t_src.scalar_type())); // Reset the shape for the Method's input as the size of forwarded input // tensor for shape dynamism. Also is a safety check if need memcpy. - Error err = resize_tensor(t_dst, t_src.sizes()); - ET_CHECK_OR_RETURN_ERROR( - err == Error::Ok, - InvalidArgument, - "Error setting input %" ET_PRIsize_t ": 0x%" PRIx32, - input_idx, - static_cast(err)); - Error error; + ET_CHECK_OK_OR_RETURN_ERROR( + resize_tensor(t_dst, t_src.sizes()), + "Error resizing tensor at input %" ET_PRIsize_t, + input_idx); auto tensor_meta = this->method_meta().input_tensor_meta(input_idx); if (tensor_meta->is_memory_planned()) { - error = internal::copy_tensor_data(t_dst, t_src); + ET_CHECK_OK_OR_RETURN_ERROR( + internal::copy_tensor_data(t_dst, t_src), + "Error copying tensor data at input %" ET_PRIsize_t, + input_idx); } else { - error = internal::share_tensor_data(t_dst, t_src); + ET_CHECK_OK_OR_RETURN_ERROR( + internal::share_tensor_data(t_dst, t_src), + "Error sharing tensor data at input %" ET_PRIsize_t, + input_idx); } - ET_CHECK_OR_RETURN_ERROR( - error == Error::Ok, - InvalidArgument, - "Error setting data_ptr %" ET_PRIsize_t ": 0x%" PRIx32, - input_idx, - static_cast(error)); // Prims have to be the same as what was traced } else if (e.isInt()) { ET_CHECK_OR_RETURN_ERROR( @@ -1163,35 +1170,16 @@ Method::set_input(const EValue& input_evalue, size_t input_idx) { return Error::InvalidArgument; } + input_set_[input_idx] = true; + return Error::Ok; } ET_NODISCARD Error Method::set_inputs(const executorch::aten::ArrayRef& input_evalues) { - ET_CHECK_OR_RETURN_ERROR( - initialized(), - InvalidState, - "Inputs can not be set until method has been initialized."); - - ET_CHECK_OR_RETURN_ERROR( - step_state_.instr_idx == 0 && step_state_.chain_idx == 0, - InvalidState, - "Inputs can not be set mid execution."); - - size_t input_size = inputs_size(); - ET_CHECK_OR_RETURN_ERROR( - input_size == input_evalues.size(), - InvalidArgument, - "The length of given input array (%" ET_PRIsize_t - ") must be same as the number of inputs in method (%" ET_PRIsize_t ").", - input_evalues.size(), - input_size); - - for (size_t i = 0; i < input_size; i++) { - Error status = set_input(input_evalues[i], i); - if (status != Error::Ok) { - return status; - } + const size_t n_input = inputs_size(); + for (size_t i = 0; i < n_input; ++i) { + ET_CHECK_OK_OR_RETURN_ERROR(set_input(input_evalues[i], i)); } return Error::Ok; } @@ -1284,20 +1272,21 @@ ET_NODISCARD Error Method::get_inputs(EValue* input_evalues, size_t length) { initialized(), InvalidState, "Inputs can not be retrieved until method has been initialized."); - + const size_t n_input = inputs_size(); ET_CHECK_OR_RETURN_ERROR( - length >= inputs_size(), + length >= n_input, InvalidArgument, "The given array is not large enough to hold all inputs."); - for (size_t i = 0; i < inputs_size(); i++) { + for (size_t i = 0; i < n_input; ++i) { input_evalues[i] = values_[get_input_index(i)]; + // Accessing inputs this way is deprecated. + // We assume the users to be responsible to set the inputs they get. + input_set_[i] = true; } - - for (size_t i = inputs_size(); i < length; i++) { + for (size_t i = n_input; i < length; ++i) { input_evalues[i] = EValue(); } - return Error::Ok; } @@ -1545,6 +1534,14 @@ Error Method::execute() { initialized(), NotSupported, "Cannot execute until method has been initialized."); + const size_t n_input = inputs_size(); + for (size_t i = 0; i < n_input; ++i) { + ET_CHECK_OR_RETURN_ERROR( + input_set_[i], + InvalidArgument, + "Input %" ET_PRIsize_t " has not been set.", + i); + } ET_LOG(Debug, "Executing method: %s.", method_meta().name()); // Chains are executed sequentially today, but future async designs may @@ -1622,10 +1619,16 @@ size_t Method::get_input_index(size_t i) const { } const EValue& Method::get_input(size_t i) const { + // Accessing inputs this way is deprecated. + // We assume the users to be responsible to set the inputs they get. + input_set_[i] = true; return get_value(get_input_index(i)); } EValue& Method::mutable_input(size_t i) { + // Accessing inputs this way is deprecated. + // We assume the users to be responsible to set the inputs they get. + input_set_[i] = true; return mutable_value(get_input_index(i)); } diff --git a/runtime/executor/method.h b/runtime/executor/method.h index 30f1cd44f62..78b71945a5a 100644 --- a/runtime/executor/method.h +++ b/runtime/executor/method.h @@ -73,6 +73,7 @@ class Method final { event_tracer_(rhs.event_tracer_), n_value_(rhs.n_value_), values_(rhs.values_), + input_set_(rhs.input_set_), n_delegate_(rhs.n_delegate_), delegates_(rhs.delegates_), n_chains_(rhs.n_chains_), @@ -85,6 +86,7 @@ class Method final { // anything twice. rhs.n_value_ = 0; rhs.values_ = nullptr; + rhs.input_set_ = nullptr; rhs.n_delegate_ = 0; rhs.delegates_ = nullptr; @@ -181,6 +183,9 @@ class Method final { ET_NODISCARD Error get_outputs(EValue* output_evalues, size_t length); /** + * DEPRECATED: Use MethodMeta instead to access metadata, and set_input to + * update Method inputs. + * * Copies the method's inputs into the provided array. * * WARNING: The input contains shallow copies of internal tensor inputs. @@ -194,7 +199,8 @@ class Method final { * * @returns Error::Ok on success, non-Ok on failure. */ - ET_NODISCARD Error get_inputs(EValue* input_evalues, size_t length); + ET_DEPRECATED ET_NODISCARD Error + get_inputs(EValue* input_evalues, size_t length); /** * @@ -314,6 +320,7 @@ class Method final { event_tracer_(event_tracer), n_value_(0), values_(nullptr), + input_set_(nullptr), n_delegate_(0), delegates_(nullptr), n_chains_(0), @@ -362,6 +369,7 @@ class Method final { size_t n_value_; EValue* values_; + bool* input_set_; size_t n_delegate_; BackendDelegate* delegates_; diff --git a/runtime/executor/test/allocation_failure_stress_test.cpp b/runtime/executor/test/allocation_failure_stress_test.cpp index 8d9614c8580..37f3a519f8a 100644 --- a/runtime/executor/test/allocation_failure_stress_test.cpp +++ b/runtime/executor/test/allocation_failure_stress_test.cpp @@ -88,6 +88,8 @@ TEST_F(AllocationFailureStressTest, End2EndIncreaseRuntimeMemUntilSuccess) { // once load was successful. auto input_cleanup = prepare_input_tensors(*method); ASSERT_EQ(input_cleanup.error(), Error::Ok); + auto input_err = method->set_input(executorch::runtime::EValue(1.0), 2); + ASSERT_EQ(input_err, Error::Ok); err = method->execute(); ASSERT_EQ(err, Error::Ok); } @@ -123,6 +125,8 @@ TEST_F(AllocationFailureStressTest, End2EndNonConstantMemUntilSuccess) { // once load was successful. auto input_cleanup = prepare_input_tensors(*method); ASSERT_EQ(input_cleanup.error(), Error::Ok); + auto input_err = method->set_input(executorch::runtime::EValue(1.0), 2); + ASSERT_EQ(input_err, Error::Ok); err = method->execute(); ASSERT_EQ(err, Error::Ok); } diff --git a/runtime/executor/test/backend_data_separation_test.cpp b/runtime/executor/test/backend_data_separation_test.cpp index 32daf3686fc..f6af25c803b 100644 --- a/runtime/executor/test/backend_data_separation_test.cpp +++ b/runtime/executor/test/backend_data_separation_test.cpp @@ -95,6 +95,21 @@ TEST_F(BackendDataSeparationTest, TestSeparation) { /*named_data_map=*/linear_data_map_.get()); ASSERT_EQ(method.error(), Error::Ok); + // Set a dummy input. + int32_t sizes[1] = {3}; + uint8_t dim_order[1] = {0}; + int32_t strides[1] = {1}; + executorch::aten::TensorImpl impl( + executorch::aten::ScalarType::Float, + 1, + sizes, + nullptr, + dim_order, + strides); + auto input_err = method->set_input( + executorch::runtime::EValue(executorch::aten::Tensor(&impl)), 0); + ASSERT_EQ(input_err, Error::Ok); + // Can execute the method. Error err = method->execute(); ASSERT_EQ(err, Error::Ok); diff --git a/runtime/executor/test/backend_integration_test.cpp b/runtime/executor/test/backend_integration_test.cpp index e2e61f171eb..59e08ea72c5 100644 --- a/runtime/executor/test/backend_integration_test.cpp +++ b/runtime/executor/test/backend_integration_test.cpp @@ -603,6 +603,25 @@ TEST_P(BackendIntegrationTest, GetMethodNameDuringExecuteSuccess) { ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes); Result method = program->load_method("forward", &mmm.get()); EXPECT_TRUE(method.ok()); + + int32_t sizes[2] = {2, 2}; + uint8_t dim_order[2] = {0, 1}; + int32_t strides[2] = {2, 1}; + executorch::aten::TensorImpl impl( + executorch::aten::ScalarType::Float, + 2, + sizes, + nullptr, + dim_order, + strides); + auto input_err = method->set_input( + executorch::runtime::EValue(executorch::aten::Tensor(&impl)), 0); + input_err = method->set_input( + executorch::runtime::EValue(executorch::aten::Tensor(&impl)), 1); + input_err = method->set_input( + executorch::runtime::EValue(executorch::aten::Tensor(&impl)), 2); + ASSERT_EQ(input_err, Error::Ok); + Error err = method->execute(); ASSERT_EQ(err, Error::Ok); } diff --git a/runtime/executor/test/kernel_integration_test.cpp b/runtime/executor/test/kernel_integration_test.cpp index 8a855817770..14fcb1c5260 100644 --- a/runtime/executor/test/kernel_integration_test.cpp +++ b/runtime/executor/test/kernel_integration_test.cpp @@ -248,6 +248,8 @@ class KernelIntegrationTest : public ::testing::Test { ASSERT_EQ(inputs_cleanup.error(), Error::Ok); inputs_cleanup_ = std::make_unique( std::move(*inputs_cleanup)); + auto input_err = method_->set_input(executorch::runtime::EValue(1.0), 2); + ASSERT_EQ(input_err, Error::Ok); } void TearDown() override { diff --git a/runtime/executor/test/method_test.cpp b/runtime/executor/test/method_test.cpp index f597746e0fd..60f4e096bac 100644 --- a/runtime/executor/test/method_test.cpp +++ b/runtime/executor/test/method_test.cpp @@ -104,9 +104,13 @@ TEST_F(MethodTest, MoveTest) { Result method = programs_["add"]->load_method("forward", &mmm.get()); ASSERT_EQ(method.error(), Error::Ok); - // Can execute the method. + // Set dummy inputs. auto input_cleanup = prepare_input_tensors(*method); ASSERT_EQ(input_cleanup.error(), Error::Ok); + auto input_err = method->set_input(executorch::runtime::EValue(1.0), 2); + ASSERT_EQ(input_err, Error::Ok); + + // Can execute the method. Error err = method->execute(); ASSERT_EQ(err, Error::Ok); @@ -312,6 +316,21 @@ TEST_F(MethodTest, ConstantSegmentTest) { programs_["add_mul"]->load_method("forward", &mmm.get()); ASSERT_EQ(method.error(), Error::Ok); + // Set a dummy input. + int32_t sizes[2] = {2, 2}; + uint8_t dim_order[2] = {0, 1}; + int32_t strides[2] = {2, 1}; + executorch::aten::TensorImpl impl( + executorch::aten::ScalarType::Float, + 2, + sizes, + nullptr, + dim_order, + strides); + auto input_err = method->set_input( + executorch::runtime::EValue(executorch::aten::Tensor(&impl)), 0); + ASSERT_EQ(input_err, Error::Ok); + // Can execute the method. Error err = method->execute(); ASSERT_EQ(err, Error::Ok); @@ -324,6 +343,21 @@ TEST_F(MethodTest, ConstantBufferTest) { programs_["linear_constant_buffer"]->load_method("forward", &mmm.get()); ASSERT_EQ(method.error(), Error::Ok); + // Set a dummy input. + int32_t sizes[2] = {2, 2}; + uint8_t dim_order[2] = {0, 1}; + int32_t strides[2] = {2, 1}; + executorch::aten::TensorImpl impl( + executorch::aten::ScalarType::Float, + 2, + sizes, + nullptr, + dim_order, + strides); + auto input_err = method->set_input( + executorch::runtime::EValue(executorch::aten::Tensor(&impl)), 0); + ASSERT_EQ(input_err, Error::Ok); + // Can execute the method. Error err = method->execute(); ASSERT_EQ(err, Error::Ok); @@ -335,6 +369,21 @@ TEST_F(MethodTest, ProgramDataSeparationTest) { "forward", &mmm.get(), nullptr, data_maps_["add_mul_data"].get()); ASSERT_EQ(method.error(), Error::Ok); + // Set a dummy input. + int32_t sizes[2] = {2, 2}; + uint8_t dim_order[2] = {0, 1}; + int32_t strides[2] = {2, 1}; + executorch::aten::TensorImpl impl( + executorch::aten::ScalarType::Float, + 2, + sizes, + nullptr, + dim_order, + strides); + auto input_err = method->set_input( + executorch::runtime::EValue(executorch::aten::Tensor(&impl)), 0); + ASSERT_EQ(input_err, Error::Ok); + // Can execute the method. Error err = method->execute(); ASSERT_EQ(err, Error::Ok); @@ -357,6 +406,21 @@ TEST_F(MethodTest, MethodGetAttributeTest) { // expect data to be set EXPECT_EQ(res->const_data_ptr(), &data); + // Set a dummy input. + int32_t sizes[1] = {1}; + uint8_t dim_order[1] = {0}; + int32_t strides[1] = {1}; + executorch::aten::TensorImpl impl( + executorch::aten::ScalarType::Float, + 1, + sizes, + nullptr, + dim_order, + strides); + auto input_err = method->set_input( + executorch::runtime::EValue(executorch::aten::Tensor(&impl)), 0); + ASSERT_EQ(input_err, Error::Ok); + // Can execute the method. Error err = method->execute(); ASSERT_EQ(err, Error::Ok);