Skip to content

Commit c359d73

Browse files
silverguofacebook-github-bot
authored andcommitted
Add unload method to module (#12984)
Summary: We would like to do grid search on device when training the model, ideally we can unload/load the model weights after trained with each specific hyperparameter set. Differential Revision: D79184972
1 parent 8b204c0 commit c359d73

File tree

3 files changed

+41
-0
lines changed

3 files changed

+41
-0
lines changed

extension/module/module.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,10 @@ runtime::Error Module::load_method(
228228
return runtime::Error::Ok;
229229
}
230230

231+
bool Module::unload_method(const std::string& method_name) {
232+
return methods_.erase(method_name);
233+
}
234+
231235
ET_NODISCARD runtime::Result<Method*> Module::method(
232236
const std::string& method_name) {
233237
ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name));

extension/module/module.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,15 @@ class Module {
194194
return load_method(method_name, nullptr, event_tracer);
195195
}
196196

197+
/**
198+
* Unload a specific method from the program.
199+
*
200+
* @param[in] method_name The name of the method to unload.
201+
*
202+
* @returns True if the method is unloaded, false if no-op.
203+
*/
204+
bool unload_method(const std::string& method_name);
205+
197206
/**
198207
* Get a method by it's name. Not recommended to use this method directly as
199208
* an end user. It's exposed to allow for composability of module in apis that
@@ -228,6 +237,15 @@ class Module {
228237
return load_forward(nullptr, event_tracer);
229238
}
230239

240+
/**
241+
* Unload the 'forward' method from the program.
242+
*
243+
* @returns True if the 'forward' method is unloaded, false if no-op.
244+
*/
245+
inline bool unload_forward() {
246+
return unload_method("forward");
247+
};
248+
231249
/**
232250
* Checks if a specific method is loaded.
233251
*

extension/module/test/module_test.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,25 @@ TEST_F(ModuleTest, TestLoadMethod) {
9191
EXPECT_TRUE(module.is_loaded());
9292
}
9393

94+
TEST_F(ModuleTest, TestUnloadMethod) {
95+
Module module(model_path_);
96+
97+
EXPECT_FALSE(module.is_method_loaded("forward"));
98+
const auto errorLoad = module.load_method("forward");
99+
EXPECT_EQ(errorLoad, Error::Ok);
100+
EXPECT_TRUE(module.is_method_loaded("forward"));
101+
// Unload method
102+
EXPECT_TRUE(module.unload_method("forward"));
103+
EXPECT_FALSE(module.is_method_loaded("forward"));
104+
// Try unload method again
105+
EXPECT_FALSE(module.unload_method("forward"));
106+
// Load method again
107+
const auto errorReload = module.load_method("forward");
108+
EXPECT_EQ(errorReload, Error::Ok);
109+
EXPECT_TRUE(module.is_method_loaded("forward"));
110+
EXPECT_TRUE(module.is_loaded());
111+
}
112+
94113
TEST_F(ModuleTest, TestLoadNonExistentMethod) {
95114
Module module(model_path_);
96115

0 commit comments

Comments
 (0)