77#include < vector>
88
99#include " infini_train/include/datatype.h"
10+ #include " infini_train/include/nn/module_hook.h"
1011
1112namespace infini_train {
1213class Tensor ;
@@ -46,6 +47,10 @@ class Module : public std::enable_shared_from_this<Module> {
4647
4748 std::unordered_map<std::string, std::shared_ptr<Tensor>> StateDict () const ;
4849
50+ // operator() calls hooks and Forward
51+ std::vector<std::shared_ptr<Tensor>> operator ()(const std::vector<std::shared_ptr<Tensor>> &input_tensors);
52+
53+ // Forward to be overridden by subclasses
4954 virtual std::vector<std::shared_ptr<Tensor>> Forward (const std::vector<std::shared_ptr<Tensor>> &input_tensors);
5055
5156 virtual float TrainStep (const std::vector<std::shared_ptr<Tensor>> &input_tensors,
@@ -62,13 +67,24 @@ class Module : public std::enable_shared_from_this<Module> {
6267
6368 virtual std::shared_ptr<Module> ReplicateForDataParallel (int device_idx) const ;
6469
70+ // Hook registration methods
71+ std::shared_ptr<ModuleHookHandle> RegisterForwardPreHook (ForwardPreHook hook);
72+ std::shared_ptr<ModuleHookHandle> RegisterForwardPostHook (ForwardPostHook hook);
73+ std::shared_ptr<ModuleHookHandle> RegisterBackwardPreHook (BackwardPreHook hook);
74+ std::shared_ptr<ModuleHookHandle> RegisterBackwardPostHook (BackwardPostHook hook);
75+
6576protected:
6677 const Device *device_ = nullptr ;
6778 const std::string type_ = kUndefinedType ;
6879 std::unordered_map<std::string, std::shared_ptr<Module>> modules_;
6980 std::unordered_map<std::string, std::shared_ptr<Tensor>> parameters_;
7081 std::unordered_map<std::string, std::shared_ptr<Tensor>> buffers_;
7182
83+ std::vector<ForwardPreHook> forward_pre_hooks_;
84+ std::vector<ForwardPostHook> forward_post_hooks_;
85+ std::vector<BackwardPreHook> backward_pre_hooks_;
86+ std::vector<BackwardPostHook> backward_post_hooks_;
87+
7288private:
7389 std::unordered_map<std::string, std::shared_ptr<Module>>
7490 NamedModules (const std::string &prefix = " " , bool remove_duplicate = true ,
0 commit comments