@@ -43,7 +43,6 @@ TrainingModule::execute_forward_backward(
4343 uint64_t param_start = param_res.get ()[0 ].toInt ();
4444
4545 // Execute the forward and backward pass.
46-
4746 auto outputs = torch::executor::Module::execute (method_name, input);
4847 if (!outputs.ok ()) {
4948 return outputs.error ();
@@ -56,19 +55,23 @@ TrainingModule::execute_forward_backward(
5655 user_outputs.push_back (outputs.get ().at (i));
5756 }
5857
59- // Extract and store the gradients.
58+ // Extract and store the gradients and params if this is the first time seeing
59+ // this method.
6060 if (method_named_gradients_.find (method_name) ==
6161 method_named_gradients_.end ()) {
62+ // Fully qualified names
63+ std::vector<runtime::EValue> fqn_list;
6264 method_named_gradients_.insert ({method_name, {}});
6365
6466 auto & gradients_map = method_named_gradients_.at (method_name);
65- // Get names.
67+
68+ // Get names if we haven't seen this method before.
6669 const std::string fqn_method_name = fqn_method_prefix + method_name;
6770 auto fqn_res = executorch::extension::Module::execute (fqn_method_name);
6871 if (!fqn_res.ok ()) {
6972 return fqn_res.error ();
7073 }
71- const auto & fqn_list = fqn_res.get ();
74+ fqn_list = fqn_res.get ();
7275
7376 // Only have to initialize the dict once because the tensors in the dict and
7477 // the tensors in the method alias the same TensorImpl, so updating one will
@@ -87,43 +90,49 @@ TrainingModule::execute_forward_backward(
8790runtime::Result<
8891 const std::map<executorch::aten::string_view, executorch::aten::Tensor>>
8992TrainingModule::named_parameters (const std::string& method_name) {
90- std::map<executorch::aten::string_view, executorch::aten::Tensor>
91- named_parameters;
92- const std::string fqn_method_name = fqn_method_prefix + method_name;
93- const std::string parameters_method_name =
94- parameters_method_prefix + method_name;
93+ // If we haven't seen this method before, populate the dict.
94+ if (method_named_parameters_.find (method_name) ==
95+ method_named_parameters_.end ()) {
96+ const std::string fqn_method_name = fqn_method_prefix + method_name;
97+ const std::string parameters_method_name =
98+ parameters_method_prefix + method_name;
9599
96- // get names.
97- auto fqn_res = executorch::extension::Module::execute (fqn_method_name);
98- if (!fqn_res.ok ()) {
99- return fqn_res.error ();
100- }
101- const auto & fqn_list = fqn_res.get ();
100+ method_named_parameters_.insert ({method_name, {}});
102101
103- // get params start .
104- auto param_res =
105- executorch::extension::Module::execute (parameters_method_name);
106- if (!param_res. ok ()) {
107- return param_res. error ();
108- }
102+ // get names .
103+ auto fqn_res = executorch::extension::Module::execute (fqn_method_name);
104+ if (!fqn_res. ok ()) {
105+ return fqn_res. error ();
106+ }
107+ const auto & fqn_list = fqn_res. get ();
109108
110- uint64_t param_start = param_res.get ()[0 ].toInt ();
109+ // get params start.
110+ auto param_res =
111+ executorch::extension::Module::execute (parameters_method_name);
112+ if (!param_res.ok ()) {
113+ return param_res.error ();
114+ }
111115
112- auto e = executorch::extension::Module::load_method (method_name);
113- if (e != runtime::Error::Ok) {
114- return e;
115- }
116- auto & method = methods_.at (method_name).method ;
117-
118- // create dict
119- size_t name_index = 0 ;
120- for (size_t param_index = param_start; param_index < method->outputs_size ();
121- ++param_index, ++name_index) {
122- executorch::aten::string_view fqn = fqn_list.at (name_index).toString ();
123- executorch::aten::Tensor param = method->get_output (param_index).toTensor ();
124- named_parameters.insert ({fqn, param});
116+ uint64_t param_start = param_res.get ()[0 ].toInt ();
117+
118+ // Load the method if it is not already loaded.
119+ auto e = executorch::extension::Module::load_method (method_name);
120+ if (e != runtime::Error::Ok) {
121+ return e;
122+ }
123+ auto & method = methods_.at (method_name).method ;
124+
125+ // populate dict
126+ size_t name_index = 0 ;
127+ for (size_t param_index = param_start; param_index < method->outputs_size ();
128+ ++param_index, ++name_index) {
129+ executorch::aten::string_view fqn = fqn_list.at (name_index).toString ();
130+ executorch::aten::Tensor param =
131+ method->get_output (param_index).toTensor ();
132+ method_named_parameters_.at (method_name).insert ({fqn, param});
133+ }
125134 }
126- return named_parameters ;
135+ return method_named_parameters_. at (method_name) ;
127136}
128137
129138runtime::Result<
0 commit comments