@@ -43,7 +43,6 @@ TrainingModule::execute_forward_backward(
4343 uint64_t param_start = param_res.get ()[0 ].toInt ();
4444
4545 // Execute the forward and backward pass.
46-
4746 auto outputs = torch::executor::Module::execute (method_name, input);
4847 if (!outputs.ok ()) {
4948 return outputs.error ();
@@ -56,19 +55,22 @@ TrainingModule::execute_forward_backward(
5655 user_outputs.push_back (outputs.get ().at (i));
5756 }
5857
59- // Extract and store the gradients.
58+ // Extract and store the gradients and params if this is the first time seeing this method.
6059 if (method_named_gradients_.find (method_name) ==
6160 method_named_gradients_.end ()) {
61+ // Fully qualified names
62+ std::vector<runtime::EValue> fqn_list;
6263 method_named_gradients_.insert ({method_name, {}});
6364
6465 auto & gradients_map = method_named_gradients_.at (method_name);
65- // Get names.
66+
67+ // Get names if we haven't seen this method before.
6668 const std::string fqn_method_name = fqn_method_prefix + method_name;
6769 auto fqn_res = executorch::extension::Module::execute (fqn_method_name);
6870 if (!fqn_res.ok ()) {
6971 return fqn_res.error ();
7072 }
71- const auto & fqn_list = fqn_res.get ();
73+ fqn_list = fqn_res.get ();
7274
7375 // Only have to initialize the dict once because the tensors in the dict and
7476 // the tensors in the method alias the same TensorImpl, so updating one will
@@ -87,43 +89,48 @@ TrainingModule::execute_forward_backward(
8789runtime::Result<
8890 const std::map<executorch::aten::string_view, executorch::aten::Tensor>>
8991TrainingModule::named_parameters (const std::string& method_name) {
90- std::map<executorch::aten::string_view, executorch::aten::Tensor>
91- named_parameters;
92- const std::string fqn_method_name = fqn_method_prefix + method_name;
93- const std::string parameters_method_name =
94- parameters_method_prefix + method_name;
92+ // If we haven't seen this method before, populate the dict.
93+ if (method_named_parameters_.find (method_name) ==
94+ method_named_parameters_.end ()) {
95+ const std::string fqn_method_name = fqn_method_prefix + method_name;
96+ const std::string parameters_method_name =
97+ parameters_method_prefix + method_name;
9598
96- // get names.
97- auto fqn_res = executorch::extension::Module::execute (fqn_method_name);
98- if (!fqn_res.ok ()) {
99- return fqn_res.error ();
100- }
101- const auto & fqn_list = fqn_res.get ();
99+ method_named_parameters_.insert ({method_name, {}});
102100
103- // get params start .
104- auto param_res =
105- executorch::extension::Module::execute (parameters_method_name);
106- if (!param_res. ok ()) {
107- return param_res. error ();
108- }
101+ // get names .
102+ auto fqn_res = executorch::extension::Module::execute (fqn_method_name);
103+ if (!fqn_res. ok ()) {
104+ return fqn_res. error ();
105+ }
106+ const auto & fqn_list = fqn_res. get ();
109107
110- uint64_t param_start = param_res.get ()[0 ].toInt ();
108+ // get params start.
109+ auto param_res =
110+ executorch::extension::Module::execute (parameters_method_name);
111+ if (!param_res.ok ()) {
112+ return param_res.error ();
113+ }
111114
112- auto e = executorch::extension::Module::load_method (method_name);
113- if (e != runtime::Error::Ok) {
114- return e;
115- }
116- auto & method = methods_.at (method_name).method ;
117-
118- // create dict
119- size_t name_index = 0 ;
120- for (size_t param_index = param_start; param_index < method->outputs_size ();
121- ++param_index, ++name_index) {
122- executorch::aten::string_view fqn = fqn_list.at (name_index).toString ();
123- executorch::aten::Tensor param = method->get_output (param_index).toTensor ();
124- named_parameters.insert ({fqn, param});
115+ uint64_t param_start = param_res.get ()[0 ].toInt ();
116+
117+ // Load the method if it is not already loaded.
118+ auto e = executorch::extension::Module::load_method (method_name);
119+ if (e != runtime::Error::Ok) {
120+ return e;
121+ }
122+ auto & method = methods_.at (method_name).method ;
123+
124+ // populate dict
125+ size_t name_index = 0 ;
126+ for (size_t param_index = param_start; param_index < method->outputs_size ();
127+ ++param_index, ++name_index) {
128+ executorch::aten::string_view fqn = fqn_list.at (name_index).toString ();
129+ executorch::aten::Tensor param = method->get_output (param_index).toTensor ();
130+ method_named_parameters_.at (method_name).insert ({fqn, param});
131+ }
125132 }
126- return named_parameters ;
133+ return method_named_parameters_. at (method_name) ;
127134}
128135
129136runtime::Result<
0 commit comments