Skip to content

Commit b14dea8

Browse files
shoumikhinfacebook-github-bot
authored andcommitted
Add convenience load methond for forward. (#5446)
Summary: Pull Request resolved: #5446 People write `module->load_method("forward")` too often, let's simplify that a bit. Reviewed By: kirklandsign Differential Revision: D62906055 fbshipit-source-id: d45934a27c61fd0ea644e6b58a21116dbc02fa17
1 parent b7dfd8a commit b14dea8

File tree

1 file changed

+17
-1
lines changed

1 file changed

+17
-1
lines changed

extension/module/module.h

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ class Module {
4444
*
4545
* @param[in] file_path The path to the ExecuTorch program file to load.
4646
* @param[in] load_mode The loading mode to use.
47+
* @param[in] event_tracer A EventTracer used for tracking and logging events.
4748
*/
4849
explicit Module(
4950
const std::string& file_path,
@@ -132,13 +133,28 @@ class Module {
132133
* needed. The loaded method is cached to reuse the next time it's executed.
133134
*
134135
* @param[in] method_name The name of the method to load.
136+
* @param[in] event_tracer A EventTracer used for tracking and logging events.
135137
*
136138
* @returns An Error to indicate success or failure.
137139
*/
138140
ET_NODISCARD
139141
runtime::Error load_method(
140142
const std::string& method_name,
141-
torch::executor::EventTracer* tracer = nullptr);
143+
torch::executor::EventTracer* event_tracer = nullptr);
144+
145+
/**
146+
* Load the 'forward' method from the program and set up memory management if
147+
* needed. The loaded method is cached to reuse the next time it's executed.
148+
*
149+
* @param[in] event_tracer An event tracer used for tracking and logging
150+
* events.
151+
*
152+
* @returns An Error to indicate success or failure.
153+
*/
154+
ET_NODISCARD inline runtime::Error load_forward(
155+
torch::executor::EventTracer* event_tracer = nullptr) {
156+
return load_method("forward", event_tracer);
157+
}
142158

143159
/**
144160
* Checks if a specific method is loaded.

0 commit comments

Comments
 (0)