@@ -199,3 +199,44 @@ TEST_F(TrainingModuleTest, DataExternalConstantsTest) {
199199 ASSERT_EQ (attributes.find (" b" )->second .sizes ()[0 ], 2 );
200200 ASSERT_EQ (attributes.find (" b" )->second .dim (), 2 );
201201}
202+
203+ TEST_F (TrainingModuleTest, UnloadMethodTest) {
204+ const char * ptd_path = std::getenv (" ET_MODULE_TRAIN_DATA_PATH" );
205+ Result<FileDataLoader> data_map_loader_res = FileDataLoader::from (ptd_path);
206+ ASSERT_EQ (data_map_loader_res.error (), Error::Ok);
207+
208+ auto data_map_loader =
209+ std::make_unique<torch::executor::util::FileDataLoader>(
210+ std::move (data_map_loader_res.get ()));
211+
212+ const char * pte_path = std::getenv (" ET_MODULE_TRAIN_PROGRAM_PATH" );
213+ Result<FileDataLoader> pte_loader_res = FileDataLoader::from (pte_path);
214+ ASSERT_EQ (pte_loader_res.error (), Error::Ok);
215+
216+ auto pte_loader = std::make_unique<torch::executor::util::FileDataLoader>(
217+ std::move (pte_loader_res.get ()));
218+
219+ auto mod = executorch::extension::training::TrainingModule (
220+ std::move (pte_loader),
221+ nullptr ,
222+ nullptr ,
223+ nullptr ,
224+ std::move (data_map_loader));
225+
226+ auto parameters_res = mod.named_parameters (" forward" );
227+ ASSERT_EQ (parameters_res.error (), Error::Ok);
228+ auto & parameters = parameters_res.get ();
229+
230+ ASSERT_NEAR (parameters_res.get ().find (" linear.bias" )->second .const_data_ptr <float >()[0 ], 0.1528 , 0.0001 );
231+
232+ // mock training
233+ auto linear_bias_ptr = parameters.find (" linear.bias" )->second .mutable_data_ptr <float >();
234+ linear_bias_ptr[0 ] += 0.5 ;
235+ ASSERT_NEAR (parameters.find (" linear.bias" )->second .const_data_ptr <float >()[0 ], 0.6528 , 0.0001 );
236+
237+ mod.unload_method (" forward" );
238+
239+ auto new_parameters_res = mod.named_parameters (" forward" );
240+ ASSERT_EQ (new_parameters_res.error (), Error::Ok);
241+ ASSERT_NEAR (new_parameters_res.get ().find (" linear.bias" )->second .const_data_ptr <float >()[0 ], 0.1528 , 0.0001 );
242+ }
0 commit comments