diff --git a/extension/module/module.cpp b/extension/module/module.cpp index 6c534b8d560..11d71d1ae08 100644 --- a/extension/module/module.cpp +++ b/extension/module/module.cpp @@ -240,6 +240,12 @@ runtime::Result> Module::execute( auto& method = methods_.at(method_name).method; auto& inputs = methods_.at(method_name).inputs; + ET_CHECK_OR_RETURN_ERROR( + input_values.size() <= inputs.size(), + InvalidArgument, + "input size: %zu does not match method input size: %zu", + input_values.size(), + inputs.size()); for (size_t i = 0; i < input_values.size(); ++i) { if (!input_values[i].isNone()) { inputs[i] = input_values[i]; diff --git a/extension/module/test/module_test.cpp b/extension/module/test/module_test.cpp index 6d0b941706d..e0444c2aefb 100644 --- a/extension/module/test/module_test.cpp +++ b/extension/module/test/module_test.cpp @@ -216,6 +216,16 @@ TEST_F(ModuleTest, TestExecuteOnCurrupted) { EXPECT_NE(result.error(), Error::Ok); } +TEST_F(ModuleTest, TestExecuteWithTooManyInputs) { + Module module(model_path_); + + auto tensor = make_tensor_ptr({2, 2}, {1.f, 2.f, 3.f, 4.f}); + + const auto result = module.execute("forward", {tensor, tensor, 1.0, 1.0}); + + EXPECT_NE(result.error(), Error::Ok); +} + TEST_F(ModuleTest, TestGet) { Module module(model_path_);