@@ -13,9 +13,19 @@ namespace extension {
1313namespace training {
1414
1515namespace {
16- std::string gradients_method_prefix = " __et_training_gradients_index_" ;
17- std::string parameters_method_prefix = " __et_training_parameters_index_" ;
18- std::string fqn_method_prefix = " __et_training_fqn_" ;
16+
17+ std::string make_parameters_method_name (const std::string& method_name) {
18+ return " __et_training_parameters_index_" + method_name;
19+ }
20+
21+ std::string make_gradients_method_name (const std::string& method_name) {
22+ return " __et_training_gradients_index_" + method_name;
23+ }
24+
25+ std::string make_fqn_method_name (const std::string& method_name) {
26+ return " __et_training_fqn_" + method_name;
27+ }
28+
1929} // namespace
2030
2131runtime::Result<std::vector<runtime::EValue>>
@@ -24,15 +34,15 @@ TrainingModule::execute_forward_backward(
2434 const std::vector<runtime::EValue>& input) {
2535 // Find where the user outputs end.
2636 const std::string gradients_method_name =
27- gradients_method_prefix + method_name;
37+ make_gradients_method_name ( method_name) ;
2838 auto res = executorch::extension::Module::execute (gradients_method_name);
2939 if (!res.ok ()) {
3040 return res.error ();
3141 }
3242 uint64_t grad_start = res.get ()[0 ].toInt ();
3343
3444 const std::string parameters_method_name =
35- parameters_method_prefix + method_name;
45+ make_parameters_method_name ( method_name) ;
3646 // get params start.
3747 auto param_res =
3848 executorch::extension::Module::execute (parameters_method_name);
@@ -66,7 +76,7 @@ TrainingModule::execute_forward_backward(
6676 auto & gradients_map = method_named_gradients_.at (method_name);
6777
6878 // Get names if we havent seen this method before.
69- const std::string fqn_method_name = fqn_method_prefix + method_name;
79+ const std::string fqn_method_name = make_fqn_method_name ( method_name) ;
7080 auto fqn_res = executorch::extension::Module::execute (fqn_method_name);
7181 if (!fqn_res.ok ()) {
7282 return fqn_res.error ();
@@ -92,9 +102,9 @@ TrainingModule::named_parameters(const std::string& method_name) {
92102 // If we haven't seen this method before, populate the dict.
93103 if (method_named_parameters_.find (method_name) ==
94104 method_named_parameters_.end ()) {
95- const std::string fqn_method_name = fqn_method_prefix + method_name;
105+ const std::string fqn_method_name = make_fqn_method_name ( method_name) ;
96106 const std::string parameters_method_name =
97- parameters_method_prefix + method_name;
107+ make_parameters_method_name ( method_name) ;
98108
99109 method_named_parameters_.insert ({method_name, {}});
100110
0 commit comments