Skip to content

Commit 18098a4

Browse files
authored
Set inputs directly on Method without caching. (#13215)
Summary: . Differential Revision: D79850621
1 parent 3d7fd00 commit 18098a4

File tree

3 files changed

+8
-34
lines changed

3 files changed

+8
-34
lines changed

extension/module/module.cpp

Lines changed: 7 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,6 @@ runtime::Error Module::load_method(
210210
method_holder.memory_manager.get(),
211211
event_tracer ? event_tracer : this->event_tracer(),
212212
data_map_.get()));
213-
method_holder.inputs.resize(method_holder.method->inputs_size());
214213
methods_.emplace(method_name, std::move(method_holder));
215214
}
216215
return runtime::Error::Ok;
@@ -233,28 +232,10 @@ runtime::Result<std::vector<runtime::EValue>> Module::execute(
233232
const std::vector<runtime::EValue>& input_values) {
234233
ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name));
235234
auto& method = methods_.at(method_name).method;
236-
auto& inputs = methods_.at(method_name).inputs;
237-
238-
ET_CHECK_OR_RETURN_ERROR(
239-
input_values.size() <= inputs.size(),
240-
InvalidArgument,
241-
"input size: %zu does not match method input size: %zu",
242-
input_values.size(),
243-
inputs.size());
244-
for (size_t i = 0; i < input_values.size(); ++i) {
245-
if (!input_values[i].isNone()) {
246-
inputs[i] = input_values[i];
247-
}
235+
for (auto index = 0; index < input_values.size(); ++index) {
236+
ET_CHECK_OK_OR_RETURN_ERROR(method->set_input(input_values[index], index));
248237
}
249-
for (size_t i = 0; i < inputs.size(); ++i) {
250-
ET_CHECK_OR_RETURN_ERROR(
251-
!inputs[i].isNone(), InvalidArgument, "input %zu is none", i);
252-
}
253-
ET_CHECK_OK_OR_RETURN_ERROR(
254-
method->set_inputs(executorch::aten::ArrayRef<runtime::EValue>(
255-
inputs.data(), inputs.size())));
256238
ET_CHECK_OK_OR_RETURN_ERROR(method->execute());
257-
258239
const auto outputs_size = method->outputs_size();
259240
std::vector<runtime::EValue> outputs(outputs_size);
260241
ET_CHECK_OK_OR_RETURN_ERROR(
@@ -268,23 +249,17 @@ runtime::Error Module::set_input(
268249
const runtime::EValue& input_value,
269250
size_t input_index) {
270251
ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name));
271-
methods_.at(method_name).inputs.at(input_index) = input_value;
272-
return runtime::Error::Ok;
252+
auto& method = methods_.at(method_name).method;
253+
return method->set_input(input_value, input_index);
273254
}
274255

275256
runtime::Error Module::set_inputs(
276257
const std::string& method_name,
277258
const std::vector<runtime::EValue>& input_values) {
278259
ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name));
279-
auto& inputs = methods_.at(method_name).inputs;
280-
ET_CHECK_OR_RETURN_ERROR(
281-
inputs.size() == input_values.size(),
282-
InvalidArgument,
283-
"input size: %zu does not match method input size: %zu",
284-
input_values.size(),
285-
inputs.size());
286-
inputs = input_values;
287-
return runtime::Error::Ok;
260+
auto& method = methods_.at(method_name).method;
261+
return method->set_inputs(executorch::aten::ArrayRef<runtime::EValue>(
262+
input_values.data(), input_values.size()));
288263
}
289264

290265
runtime::Error Module::set_output(

extension/module/module.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -522,7 +522,6 @@ class Module {
522522
std::unique_ptr<runtime::HierarchicalAllocator> planned_memory;
523523
std::unique_ptr<runtime::MemoryManager> memory_manager;
524524
std::unique_ptr<Method> method;
525-
std::vector<runtime::EValue> inputs;
526525
};
527526

528527
std::string file_path_;

extension/module/test/module_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,7 @@ TEST_F(ModuleTest, TestForward) {
267267
EXPECT_TENSOR_CLOSE(result->at(0).toTensor(), *expected.get());
268268

269269
auto tensor2 = make_tensor_ptr({2, 2}, {2.f, 3.f, 4.f, 5.f});
270-
const auto result2 = module->forward({tensor2, tensor2});
270+
const auto result2 = module->forward({tensor2, tensor2, 1.0});
271271
EXPECT_EQ(result2.error(), Error::Ok);
272272

273273
const auto expected2 = make_tensor_ptr({2, 2}, {4.f, 6.f, 8.f, 10.f});

0 commit comments

Comments
 (0)