Skip to content

Commit 3c2ec29

Browse files
shoumikhinfacebook-github-bot
authored andcommitted
Set inputs directly on Method without caching. (#13215)
Summary: . Differential Revision: D79850621
1 parent 2e9fc4a commit 3c2ec29

File tree

3 files changed

+8
-34
lines changed

3 files changed

+8
-34
lines changed

extension/module/module.cpp

Lines changed: 7 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,6 @@ runtime::Error Module::load_method(
222222
method_holder.memory_manager.get(),
223223
event_tracer ? event_tracer : this->event_tracer(),
224224
data_map_.get()));
225-
method_holder.inputs.resize(method_holder.method->inputs_size());
226225
methods_.emplace(method_name, std::move(method_holder));
227226
}
228227
return runtime::Error::Ok;
@@ -245,28 +244,10 @@ runtime::Result<std::vector<runtime::EValue>> Module::execute(
245244
const std::vector<runtime::EValue>& input_values) {
246245
ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name));
247246
auto& method = methods_.at(method_name).method;
248-
auto& inputs = methods_.at(method_name).inputs;
249-
250-
ET_CHECK_OR_RETURN_ERROR(
251-
input_values.size() <= inputs.size(),
252-
InvalidArgument,
253-
"input size: %zu does not match method input size: %zu",
254-
input_values.size(),
255-
inputs.size());
256-
for (size_t i = 0; i < input_values.size(); ++i) {
257-
if (!input_values[i].isNone()) {
258-
inputs[i] = input_values[i];
259-
}
247+
for (auto index = 0; index < input_values.size(); ++index) {
248+
ET_CHECK_OK_OR_RETURN_ERROR(method->set_input(input_values[index], index));
260249
}
261-
for (size_t i = 0; i < inputs.size(); ++i) {
262-
ET_CHECK_OR_RETURN_ERROR(
263-
!inputs[i].isNone(), InvalidArgument, "input %zu is none", i);
264-
}
265-
ET_CHECK_OK_OR_RETURN_ERROR(
266-
method->set_inputs(executorch::aten::ArrayRef<runtime::EValue>(
267-
inputs.data(), inputs.size())));
268250
ET_CHECK_OK_OR_RETURN_ERROR(method->execute());
269-
270251
const auto outputs_size = method->outputs_size();
271252
std::vector<runtime::EValue> outputs(outputs_size);
272253
ET_CHECK_OK_OR_RETURN_ERROR(
@@ -280,23 +261,17 @@ runtime::Error Module::set_input(
280261
const runtime::EValue& input_value,
281262
size_t input_index) {
282263
ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name));
283-
methods_.at(method_name).inputs.at(input_index) = input_value;
284-
return runtime::Error::Ok;
264+
auto& method = methods_.at(method_name).method;
265+
return method->set_input(input_value, input_index);
285266
}
286267

287268
runtime::Error Module::set_inputs(
288269
const std::string& method_name,
289270
const std::vector<runtime::EValue>& input_values) {
290271
ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name));
291-
auto& inputs = methods_.at(method_name).inputs;
292-
ET_CHECK_OR_RETURN_ERROR(
293-
inputs.size() == input_values.size(),
294-
InvalidArgument,
295-
"input size: %zu does not match method input size: %zu",
296-
input_values.size(),
297-
inputs.size());
298-
inputs = input_values;
299-
return runtime::Error::Ok;
272+
auto& method = methods_.at(method_name).method;
273+
return method->set_inputs(executorch::aten::ArrayRef<runtime::EValue>(
274+
input_values.data(), input_values.size()));
300275
}
301276

302277
runtime::Error Module::set_output(

extension/module/module.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -522,7 +522,6 @@ class Module {
522522
std::unique_ptr<runtime::HierarchicalAllocator> planned_memory;
523523
std::unique_ptr<runtime::MemoryManager> memory_manager;
524524
std::unique_ptr<Method> method;
525-
std::vector<runtime::EValue> inputs;
526525
};
527526

528527
std::string file_path_;

extension/module/test/module_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,7 @@ TEST_F(ModuleTest, TestForward) {
267267
EXPECT_TENSOR_CLOSE(result->at(0).toTensor(), *expected.get());
268268

269269
auto tensor2 = make_tensor_ptr({2, 2}, {2.f, 3.f, 4.f, 5.f});
270-
const auto result2 = module->forward({tensor2, tensor2});
270+
const auto result2 = module->forward({tensor2, tensor2, 1.0});
271271
EXPECT_EQ(result2.error(), Error::Ok);
272272

273273
const auto expected2 = make_tensor_ptr({2, 2}, {4.f, 6.f, 8.f, 10.f});

0 commit comments

Comments
 (0)