Skip to content

Commit c93a160

Browse files
cccclaifacebook-github-bot
authored andcommitted
Add get_method API to module
Summary: It is for reusing the exention/module in the module definition in pybindings https://github.com/pytorch/executorch/blob/1a918c779e16c0ee903a08b30c1c666d1efb2c57/extension/pybindings/pybindings.cpp#L172 Differential Revision: D71135352
1 parent 79fda31 commit c93a160

File tree

2 files changed

+34
-0
lines changed

2 files changed

+34
-0
lines changed

extension/module/module.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -459,6 +459,15 @@ class Module {
459459
return event_tracer_.get();
460460
}
461461

462+
ET_NODISCARD
463+
runtime::Result<runtime::Method*> get_method(const std::string& method_name) {
464+
if (methods_.count(method_name) == 0) {
465+
ET_LOG(Info, "Method %s not found", method_name.c_str());
466+
return runtime::Error::NotFound;
467+
}
468+
return methods_[method_name].method.get();
469+
}
470+
462471
private:
463472
struct MethodHolder {
464473
std::vector<std::vector<uint8_t>> planned_buffers;

extension/module/test/module_test.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ TEST_F(ModuleTest, TestLoadNonExistent) {
5353
EXPECT_FALSE(module.is_loaded());
5454
}
5555

56+
5657
TEST_F(ModuleTest, TestLoadCorruptedFile) {
5758
Module module("/dev/null");
5859
const auto error = module.load();
@@ -95,6 +96,30 @@ TEST_F(ModuleTest, TestLoadNonExistentMethod) {
9596
EXPECT_TRUE(module.is_loaded());
9697
}
9798

99+
TEST_F(ModuleTest, TestGetMethod) {
100+
Module module(model_path_);
101+
102+
const auto error = module.load_method("forward");
103+
EXPECT_EQ(error, Error::Ok);
104+
auto method_res = module.get_method("forward");
105+
EXPECT_EQ(method_res.error(), Error::Ok);
106+
auto method = method_res.get();
107+
EXPECT_EQ(strcmp(method->method_meta().name(), "forward"), 0);
108+
109+
}
110+
111+
TEST_F(ModuleTest, TestGetNonExistMethod) {
112+
Module module(model_path_);
113+
114+
const auto error = module.load_method("forward");
115+
EXPECT_EQ(error, Error::Ok);
116+
117+
// Try to get a method that doesn't exist
118+
auto method_res = module.get_method("backward");
119+
EXPECT_EQ(method_res.error(), Error::NotFound);
120+
}
121+
122+
98123
TEST_F(ModuleTest, TestMethodMeta) {
99124
Module module(model_path_);
100125

0 commit comments

Comments
 (0)