diff --git a/extension/module/module.h b/extension/module/module.h index 3189da2aaa6..f7c9b1c8c56 100644 --- a/extension/module/module.h +++ b/extension/module/module.h @@ -44,6 +44,7 @@ class Module { * * @param[in] file_path The path to the ExecuTorch program file to load. * @param[in] load_mode The loading mode to use. + * @param[in] event_tracer A EventTracer used for tracking and logging events. */ explicit Module( const std::string& file_path, @@ -132,13 +133,28 @@ class Module { * needed. The loaded method is cached to reuse the next time it's executed. * * @param[in] method_name The name of the method to load. + * @param[in] event_tracer A EventTracer used for tracking and logging events. * * @returns An Error to indicate success or failure. */ ET_NODISCARD runtime::Error load_method( const std::string& method_name, - torch::executor::EventTracer* tracer = nullptr); + torch::executor::EventTracer* event_tracer = nullptr); + + /** + * Load the 'forward' method from the program and set up memory management if + * needed. The loaded method is cached to reuse the next time it's executed. + * + * @param[in] event_tracer An event tracer used for tracking and logging + * events. + * + * @returns An Error to indicate success or failure. + */ + ET_NODISCARD inline runtime::Error load_forward( + torch::executor::EventTracer* event_tracer = nullptr) { + return load_method("forward", event_tracer); + } /** * Checks if a specific method is loaded.