77#include < vector>
88
99#include " infini_train/include/datatype.h"
10+ #include " infini_train/include/nn/module_hook.h"
1011
1112namespace infini_train {
1213class Tensor ;
@@ -50,6 +51,10 @@ class Module : public std::enable_shared_from_this<Module> {
5051
5152 std::unordered_map<std::string, std::shared_ptr<Tensor>> StateDict () const ;
5253
54+ // operator() calls hooks and Forward
55+ std::vector<std::shared_ptr<Tensor>> operator ()(const std::vector<std::shared_ptr<Tensor>> &input_tensors);
56+
57+ // Forward to be overridden by subclasses
5358 virtual std::vector<std::shared_ptr<Tensor>> Forward (const std::vector<std::shared_ptr<Tensor>> &input_tensors);
5459
5560 virtual float TrainStep (const std::vector<std::shared_ptr<Tensor>> &input_tensors,
@@ -66,13 +71,24 @@ class Module : public std::enable_shared_from_this<Module> {
6671
6772 virtual std::shared_ptr<Module> ReplicateForDataParallel (int device_idx) const ;
6873
74+ // Hook registration methods
75+ std::shared_ptr<ModuleHookHandle> RegisterForwardPreHook (ForwardPreHook hook);
76+ std::shared_ptr<ModuleHookHandle> RegisterForwardPostHook (ForwardPostHook hook);
77+ std::shared_ptr<ModuleHookHandle> RegisterBackwardPreHook (BackwardPreHook hook);
78+ std::shared_ptr<ModuleHookHandle> RegisterBackwardPostHook (BackwardPostHook hook);
79+
6980protected:
7081 const Device *device_ = nullptr ;
7182 const std::string type_ = kUndefinedType ;
7283 std::unordered_map<std::string, std::shared_ptr<Module>> modules_;
7384 std::unordered_map<std::string, std::shared_ptr<Tensor>> parameters_;
7485 std::unordered_map<std::string, std::shared_ptr<Tensor>> buffers_;
7586
87+ std::vector<ForwardPreHook> forward_pre_hooks_;
88+ std::vector<ForwardPostHook> forward_post_hooks_;
89+ std::vector<BackwardPreHook> backward_pre_hooks_;
90+ std::vector<BackwardPostHook> backward_post_hooks_;
91+
7692private:
7793 std::unordered_map<std::string, std::shared_ptr<Module>>
7894 NamedModules (const std::string &prefix = " " , bool remove_duplicate = true ,
0 commit comments