Skip to content

Commit b574747

Browse files
silverguofacebook-github-bot
authored andcommitted
Add unload method to module
Summary: We would like to do grid search on device when training the model, ideally we can unload/load the model weights after trained with each specific hyperparameter set. Differential Revision: D79184972
1 parent 8b204c0 commit b574747

File tree

3 files changed

+44
-0
lines changed

3 files changed

+44
-0
lines changed

extension/module/module.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,12 @@ runtime::Error Module::load_method(
228228
return runtime::Error::Ok;
229229
}
230230

231+
runtime::Error Module::unload_method(const std::string& method_name) {
232+
ET_CHECK_OK_OR_RETURN_ERROR(load());
233+
methods_.erase(method_name);
234+
return runtime::Error::Ok;
235+
}
236+
231237
ET_NODISCARD runtime::Result<Method*> Module::method(
232238
const std::string& method_name) {
233239
ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name));

extension/module/module.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,16 @@ class Module {
194194
return load_method(method_name, nullptr, event_tracer);
195195
}
196196

197+
/**
198+
* Unload a specific method from the program.
199+
*
200+
* @param[in] method_name The name of the method to unload.
201+
*
202+
* @returns An Error to indicate success or failure.
203+
*/
204+
ET_NODISCARD
205+
runtime::Error unload_method(const std::string& method_name);
206+
197207
/**
198208
* Get a method by it's name. Not recommended to use this method directly as
199209
* an end user. It's exposed to allow for composability of module in apis that
@@ -228,6 +238,16 @@ class Module {
228238
return load_forward(nullptr, event_tracer);
229239
}
230240

241+
/**
242+
* Unload the 'forward' method from the program.
243+
*
244+
* @returns An Error to indicate success or failure.
245+
*/
246+
ET_NODISCARD
247+
runtime::Error unload_forward() {
248+
return unload_method("forward");
249+
};
250+
231251
/**
232252
* Checks if a specific method is loaded.
233253
*

extension/module/test/module_test.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,24 @@ TEST_F(ModuleTest, TestLoadMethod) {
9191
EXPECT_TRUE(module.is_loaded());
9292
}
9393

94+
TEST_F(ModuleTest, TestUnloadMethod) {
95+
Module module(model_path_);
96+
97+
EXPECT_FALSE(module.is_method_loaded("forward"));
98+
const auto errorLoad = module.load_method("forward");
99+
EXPECT_EQ(errorLoad, Error::Ok);
100+
EXPECT_TRUE(module.is_method_loaded("forward"));
101+
// Unload method
102+
const auto errorUnload = module.unload_method("forward");
103+
EXPECT_EQ(errorUnload, Error::Ok);
104+
EXPECT_FALSE(module.is_method_loaded("forward"));
105+
// Load method again
106+
const auto errorReload = module.load_method("forward");
107+
EXPECT_EQ(errorReload, Error::Ok);
108+
EXPECT_TRUE(module.is_method_loaded("forward"));
109+
EXPECT_TRUE(module.is_loaded());
110+
}
111+
94112
TEST_F(ModuleTest, TestLoadNonExistentMethod) {
95113
Module module(model_path_);
96114

0 commit comments

Comments
 (0)