Skip to content

Commit 76835e8

Browse files
JacobSzwejbkafacebook-github-bot
authored andcommitted
Expose a method getter api in Module (pytorch#11929)
Summary: Pull Request resolved: pytorch#11929 It is sometimes useful for modules to be passed around in user space for user convenience, but for utilities those users interact with to operate on the methods directly. This is a bit of a layering violation but it allows us to define narrower scoped helpers so I think its worth the risk that the user finds a way to trash the internal method object. In general method should be fairly robust to being trashed anyway since itself is a blessed front end api from the team. Differential Revision: D77248983 Reviewed By: larryliu0820
1 parent dd4488d commit 76835e8

File tree

2 files changed

+23
-0
lines changed

2 files changed

+23
-0
lines changed

extension/module/module.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,16 @@ runtime::Error Module::load_method(
228228
return runtime::Error::Ok;
229229
}
230230

231+
ET_NODISCARD inline runtime::Result<Method*> Module::method(
232+
const std::string& method_name) {
233+
ET_CHECK_OR_RETURN_ERROR(
234+
methods_.count(method_name) > 0,
235+
InvalidArgument,
236+
"no such method in program: %s",
237+
method_name.c_str());
238+
return methods_[method_name].method.get();
239+
}
240+
231241
runtime::Result<MethodMeta> Module::method_meta(
232242
const std::string& method_name) {
233243
ET_CHECK_OK_OR_RETURN_ERROR(load());

extension/module/module.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,19 @@ class Module {
194194
return load_method(method_name, nullptr, event_tracer);
195195
}
196196

197+
/**
198+
* Get a method by it's name. Not recommended to use this method directly as
199+
* an end user. It's exposed to allow for composability of module in apis that
200+
* operate on method.
201+
*
202+
* @param[in] method_name The name of the method to get.
203+
*
204+
* @returns A Result object containing either a pointer to the requested
205+
* method or an error to indicate failure.
206+
*/
207+
ET_NODISCARD inline runtime::Result<Method*> method(
208+
const std::string& method_name);
209+
197210
/**
198211
* Load the 'forward' method from the program and set up memory management if
199212
* needed. The loaded method is cached to reuse the next time it's executed.

0 commit comments

Comments
 (0)