@@ -154,6 +154,7 @@ runtime::Error Module::load_method(
154154 temp_allocator_.get ());
155155 method_holder.method = ET_UNWRAP_UNIQUE (program_->load_method (
156156 method_name.c_str (), method_holder.memory_manager .get (), tracer));
157+ method_holder.inputs .resize (method_holder.method ->inputs_size ());
157158 methods_.emplace (method_name, std::move (method_holder));
158159 }
159160 return runtime::Error::Ok;
@@ -170,10 +171,19 @@ runtime::Result<std::vector<runtime::EValue>> Module::execute(
170171 const std::vector<runtime::EValue>& input_values) {
171172 ET_CHECK_OK_OR_RETURN_ERROR (load_method (method_name));
172173 auto & method = methods_.at (method_name).method ;
174+ auto & inputs = methods_.at (method_name).inputs ;
173175
174- ET_CHECK_OK_OR_RETURN_ERROR (
175- method->set_inputs (exec_aten::ArrayRef<runtime::EValue>(
176- input_values.data (), input_values.size ())));
176+ for (size_t i = 0 ; i < input_values.size (); ++i) {
177+ if (!input_values[i].isNone ()) {
178+ inputs[i] = input_values[i];
179+ }
180+ }
181+ for (size_t i = 0 ; i < inputs.size (); ++i) {
182+ ET_CHECK_OR_RETURN_ERROR (
183+ !inputs[i].isNone (), InvalidArgument, " input %zu is none" , i);
184+ }
185+ ET_CHECK_OK_OR_RETURN_ERROR (method->set_inputs (
186+ exec_aten::ArrayRef<runtime::EValue>(inputs.data (), inputs.size ())));
177187 ET_CHECK_OK_OR_RETURN_ERROR (method->execute ());
178188
179189 const auto outputs_size = method->outputs_size ();
@@ -184,6 +194,30 @@ runtime::Result<std::vector<runtime::EValue>> Module::execute(
184194 return outputs;
185195}
186196
197+ runtime::Error Module::set_input (
198+ const std::string& method_name,
199+ const runtime::EValue& input_value,
200+ size_t input_index) {
201+ ET_CHECK_OK_OR_RETURN_ERROR (load_method (method_name));
202+ methods_.at (method_name).inputs .at (input_index) = input_value;
203+ return runtime::Error::Ok;
204+ }
205+
206+ runtime::Error Module::set_inputs (
207+ const std::string& method_name,
208+ const std::vector<runtime::EValue>& input_values) {
209+ ET_CHECK_OK_OR_RETURN_ERROR (load_method (method_name));
210+ auto & inputs = methods_.at (method_name).inputs ;
211+ ET_CHECK_OR_RETURN_ERROR (
212+ inputs.size () == input_values.size (),
213+ InvalidArgument,
214+ " input size: %zu does not match method input size: %zu" ,
215+ input_values.size (),
216+ inputs.size ());
217+ inputs = input_values;
218+ return runtime::Error::Ok;
219+ }
220+
187221runtime::Error Module::set_output_data_ptr (
188222 runtime::EValue output_value,
189223 size_t output_index,
0 commit comments