Skip to content

Commit ec42267

Browse files
shoumikhinfacebook-github-bot
authored andcommitted
Set inputs directly on Method without caching.
Summary: . Differential Revision: D79850621
1 parent f1118c4 commit ec42267

File tree

3 files changed

+9
-34
lines changed

3 files changed

+9
-34
lines changed

extension/module/module.cpp

Lines changed: 8 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,6 @@ runtime::Error Module::load_method(
222222
method_holder.memory_manager.get(),
223223
event_tracer ? event_tracer : this->event_tracer(),
224224
data_map_.get()));
225-
method_holder.inputs.resize(method_holder.method->inputs_size());
226225
methods_.emplace(method_name, std::move(method_holder));
227226
}
228227
return runtime::Error::Ok;
@@ -245,28 +244,11 @@ runtime::Result<std::vector<runtime::EValue>> Module::execute(
245244
const std::vector<runtime::EValue>& input_values) {
246245
ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name));
247246
auto& method = methods_.at(method_name).method;
248-
auto& inputs = methods_.at(method_name).inputs;
249-
250-
ET_CHECK_OR_RETURN_ERROR(
251-
input_values.size() <= inputs.size(),
252-
InvalidArgument,
253-
"input size: %zu does not match method input size: %zu",
254-
input_values.size(),
255-
inputs.size());
256-
for (size_t i = 0; i < input_values.size(); ++i) {
257-
if (!input_values[i].isNone()) {
258-
inputs[i] = input_values[i];
259-
}
247+
for (auto index = 0; index < input_values.size(); ++index) {
248+
ET_CHECK_OK_OR_RETURN_ERROR(
249+
method->set_input(input_values[index], index));
260250
}
261-
for (size_t i = 0; i < inputs.size(); ++i) {
262-
ET_CHECK_OR_RETURN_ERROR(
263-
!inputs[i].isNone(), InvalidArgument, "input %zu is none", i);
264-
}
265-
ET_CHECK_OK_OR_RETURN_ERROR(
266-
method->set_inputs(executorch::aten::ArrayRef<runtime::EValue>(
267-
inputs.data(), inputs.size())));
268251
ET_CHECK_OK_OR_RETURN_ERROR(method->execute());
269-
270252
const auto outputs_size = method->outputs_size();
271253
std::vector<runtime::EValue> outputs(outputs_size);
272254
ET_CHECK_OK_OR_RETURN_ERROR(
@@ -280,23 +262,17 @@ runtime::Error Module::set_input(
280262
const runtime::EValue& input_value,
281263
size_t input_index) {
282264
ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name));
283-
methods_.at(method_name).inputs.at(input_index) = input_value;
284-
return runtime::Error::Ok;
265+
auto& method = methods_.at(method_name).method;
266+
return method->set_input(input_value, input_index);
285267
}
286268

287269
runtime::Error Module::set_inputs(
288270
const std::string& method_name,
289271
const std::vector<runtime::EValue>& input_values) {
290272
ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name));
291-
auto& inputs = methods_.at(method_name).inputs;
292-
ET_CHECK_OR_RETURN_ERROR(
293-
inputs.size() == input_values.size(),
294-
InvalidArgument,
295-
"input size: %zu does not match method input size: %zu",
296-
input_values.size(),
297-
inputs.size());
298-
inputs = input_values;
299-
return runtime::Error::Ok;
273+
auto& method = methods_.at(method_name).method;
274+
return method->set_inputs(executorch::aten::ArrayRef<runtime::EValue>(
275+
input_values.data(), input_values.size()));
300276
}
301277

302278
runtime::Error Module::set_output(

extension/module/module.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -522,7 +522,6 @@ class Module {
522522
std::unique_ptr<runtime::HierarchicalAllocator> planned_memory;
523523
std::unique_ptr<runtime::MemoryManager> memory_manager;
524524
std::unique_ptr<Method> method;
525-
std::vector<runtime::EValue> inputs;
526525
};
527526

528527
std::string file_path_;

extension/module/test/module_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,7 @@ TEST_F(ModuleTest, TestForward) {
267267
EXPECT_TENSOR_CLOSE(result->at(0).toTensor(), *expected.get());
268268

269269
auto tensor2 = make_tensor_ptr({2, 2}, {2.f, 3.f, 4.f, 5.f});
270-
const auto result2 = module->forward({tensor2, tensor2});
270+
const auto result2 = module->forward({tensor2, tensor2, 1.0});
271271
EXPECT_EQ(result2.error(), Error::Ok);
272272

273273
const auto expected2 = make_tensor_ptr({2, 2}, {4.f, 6.f, 8.f, 10.f});

0 commit comments

Comments
 (0)