Skip to content

Commit daf2a80

Browse files
author
zhenyanzhang
committed
[ExecuTorch][#9638] Introduce Protected Method Getter in Extension.Module
# Context This issue is a step of #9638. In the discussion, we want to unblock having `extension/Module` as the single source of implementation, which means that `pybindings/PyModule` should use `extension/Module` rather than its own. Although we are decouping method getter from `pybindings` implementation, method getter itself is still needed. To keep having the method getter while not exposing it, we can create a protected method getter and confine it's usage inside child classes that we are about to create. # Proposal Add a protected `get_method` to `extension.Module`, taking method name string as an input. Differential Revision: [D73473766](https://our.internmc.facebook.com/intern/diff/D73473766/) [ghstack-poisoned]
1 parent 95c663e commit daf2a80

File tree

2 files changed

+20
-0
lines changed

2 files changed

+20
-0
lines changed

extension/module/module.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,5 +302,15 @@ runtime::Error Module::set_output(
302302
output_tensor.mutable_data_ptr(), output_tensor.nbytes(), output_index);
303303
}
304304

305+
ET_NODISCARD inline runtime::Result<Method*> Module::get_method(
306+
const std::string& method_name) {
307+
ET_CHECK_OR_RETURN_ERROR(
308+
methods_.count(method_name) > 0,
309+
InvalidArgument,
310+
"no such method in program: %s",
311+
method_name.c_str());
312+
return methods_[method_name].method.get();
313+
}
314+
305315
} // namespace extension
306316
} // namespace executorch

extension/module/module.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -493,6 +493,16 @@ class Module {
493493
std::unique_ptr<NamedDataMap> data_map_;
494494

495495
protected:
496+
/**
497+
* Get a method by method name.
498+
*
499+
* @param[in] method_name The name of the method to get.
500+
*
501+
* @returns A Result object containing either a pointer to the requested
502+
* method or an error to indicate failure.
503+
*/
504+
ET_NODISCARD inline runtime::Result<Method*> get_method(
505+
const std::string& method_name);
496506
std::unordered_map<std::string, MethodHolder> methods_;
497507

498508
friend class ExecuTorchJni;

0 commit comments

Comments
 (0)