Skip to content

Commit f51b1d5

Browse files
committed
update
1 parent 83d11cc commit f51b1d5

File tree

10 files changed

+493
-0
lines changed

10 files changed

+493
-0
lines changed

infini_train/include/autograd/accumulate.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ class AccumulateGrad final : public Function {
1818

1919
std::vector<std::shared_ptr<Tensor>> Backward(const std::vector<std::shared_ptr<Tensor>> &) override;
2020

21+
std::shared_ptr<Tensor> tensor() const { return tensor_; }
22+
2123
private:
2224
std::shared_ptr<Tensor> tensor_ = nullptr;
2325
float learning_rate_ = 1.0f;

infini_train/include/autograd/function.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#pragma once
22

3+
#include <functional>
34
#include <memory>
45
#include <utility>
56
#include <vector>
@@ -9,6 +10,14 @@ class Tensor;
910
}
1011

1112
namespace infini_train::autograd {
13+
class HookHandle;
14+
using FunctionForwardPreHook = std::function<void(class Function*, const std::vector<std::shared_ptr<Tensor>>&)>;
15+
using FunctionForwardPostHook = std::function<void(class Function*, const std::vector<std::shared_ptr<Tensor>>&,
16+
const std::vector<std::shared_ptr<Tensor>>&)>;
17+
using FunctionBackwardPreHook = std::function<void(class Function*, const std::vector<std::shared_ptr<Tensor>>&)>;
18+
using FunctionBackwardPostHook = std::function<void(class Function*, const std::vector<std::shared_ptr<Tensor>>&,
19+
const std::vector<std::shared_ptr<Tensor>>&)>;
20+
1221
class Function : public std::enable_shared_from_this<Function> {
1322
public:
1423
static constexpr char kUndefinedType[] = "Undefined";
@@ -28,6 +37,11 @@ class Function : public std::enable_shared_from_this<Function> {
2837

2938
void IncreaseDependenciesNumber();
3039

40+
std::shared_ptr<HookHandle> RegisterForwardPreHook(FunctionForwardPreHook hook);
41+
std::shared_ptr<HookHandle> RegisterForwardPostHook(FunctionForwardPostHook hook);
42+
std::shared_ptr<HookHandle> RegisterBackwardPreHook(FunctionBackwardPreHook hook);
43+
std::shared_ptr<HookHandle> RegisterBackwardPostHook(FunctionBackwardPostHook hook);
44+
3145
protected:
3246
std::vector<std::shared_ptr<Tensor>> saved_tensors_;
3347

@@ -38,5 +52,9 @@ class Function : public std::enable_shared_from_this<Function> {
3852
int grad_outputs_reached_ = 0;
3953
std::vector<std::shared_ptr<Tensor>> grad_outputs_;
4054
const std::string type_ = kUndefinedType;
55+
std::vector<FunctionForwardPreHook> forward_pre_hooks_;
56+
std::vector<FunctionForwardPostHook> forward_post_hooks_;
57+
std::vector<FunctionBackwardPreHook> backward_pre_hooks_;
58+
std::vector<FunctionBackwardPostHook> backward_post_hooks_;
4159
};
4260
} // namespace infini_train::autograd

infini_train/include/autograd/function_hook.h

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#pragma once
22

3+
#include <functional>
34
#include <memory>
5+
#include <vector>
46

57
#include "infini_train/include/nn/parallel/reduce_op_type.h"
68

@@ -13,6 +15,8 @@ class ProcessGroup;
1315
} // namespace infini_train
1416

1517
namespace infini_train::autograd {
18+
class Function;
19+
class HookHandle;
1620
class PostAccumulateGradHook {
1721
public:
1822
virtual void operator()(const std::shared_ptr<Tensor> &tensor) = 0;
@@ -30,4 +34,36 @@ class AllReducePostAccumulateHook : public PostAccumulateGradHook {
3034
infini_train::nn::parallel::function::ReduceOpType reduce_op_;
3135
const infini_train::nn::parallel::ProcessGroup *pg_ = nullptr;
3236
};
37+
38+
// Forward pre-hook: called before forward pass
39+
using FunctionForwardPreHook = std::function<void(Function*, const std::vector<std::shared_ptr<Tensor>>&)>;
40+
41+
// Forward post-hook: called after forward pass
42+
using FunctionForwardPostHook = std::function<void(Function*, const std::vector<std::shared_ptr<Tensor>>&,
43+
const std::vector<std::shared_ptr<Tensor>>&)>;
44+
45+
// Backward pre-hook: called before backward pass
46+
using FunctionBackwardPreHook = std::function<void(Function*, const std::vector<std::shared_ptr<Tensor>>&)>;
47+
48+
// Backward post-hook: called after backward pass
49+
using FunctionBackwardPostHook = std::function<void(Function*, const std::vector<std::shared_ptr<Tensor>>&,
50+
const std::vector<std::shared_ptr<Tensor>>&)>;
51+
52+
template <typename HookType>
53+
class FunctionHookHandleImpl : public HookHandle {
54+
public:
55+
FunctionHookHandleImpl(std::vector<HookType>* hooks, size_t id) : hooks_(hooks), id_(id) {}
56+
57+
void Remove() override {
58+
if (!removed_ && hooks_ && id_ < hooks_->size()) {
59+
(*hooks_)[id_] = nullptr;
60+
removed_ = true;
61+
}
62+
}
63+
64+
private:
65+
std::vector<HookType>* hooks_;
66+
size_t id_;
67+
bool removed_ = false;
68+
};
3369
} // namespace infini_train::autograd
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
#pragma once
2+
3+
#include <functional>
4+
#include <memory>
5+
#include <vector>
6+
7+
namespace infini_train {
class Tensor;
} // namespace infini_train

namespace infini_train::autograd {

// Abstract handle for a registered hook; Remove() is the only operation.
class HookHandle {
public:
    virtual ~HookHandle() = default;

    // Detaches the hook this handle refers to.
    virtual void Remove() = 0;
};

// Backward hook attached to a tensor. It receives a gradient and returns
// either a replacement gradient or nullptr to keep the original one.
using TensorBackwardHook = std::function<std::shared_ptr<Tensor>(const std::shared_ptr<Tensor> &)>;

// Handle for one entry of a tensor's backward-hook list, addressed by index.
// The list pointer is stored as-is (non-owning); Remove() is defined out of line.
class TensorBackwardHookHandle : public HookHandle {
public:
    TensorBackwardHookHandle(std::vector<TensorBackwardHook> *hooks, size_t id) : hooks_{hooks}, id_{id} {}

    void Remove() override;

private:
    std::vector<TensorBackwardHook> *hooks_;
    size_t id_;
    bool removed_ = false;
};

} // namespace infini_train::autograd
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
#pragma once
2+
3+
#include <functional>
4+
#include <memory>
5+
#include <vector>
6+
7+
namespace infini_train {
class Tensor;
} // namespace infini_train

namespace infini_train::nn {
class Module;

// Hook signatures for nn::Module. Every hook receives the module it is
// attached to as its first argument.

// Runs before the forward pass: (module, input_tensors).
using ForwardPreHook = std::function<void(Module *, const std::vector<std::shared_ptr<Tensor>> &)>;

// Runs after the forward pass: (module, input_tensors, output_tensors).
using ForwardPostHook = std::function<void(Module *, const std::vector<std::shared_ptr<Tensor>> &,
                                           const std::vector<std::shared_ptr<Tensor>> &)>;

// Runs before the backward pass: (module, grad_output).
using BackwardPreHook = std::function<void(Module *, const std::vector<std::shared_ptr<Tensor>> &)>;

// Runs after the backward pass: (module, grad_input, grad_output).
using BackwardPostHook = std::function<void(Module *, const std::vector<std::shared_ptr<Tensor>> &,
                                            const std::vector<std::shared_ptr<Tensor>> &)>;

// Abstract handle for a registered module hook.
class ModuleHookHandle {
public:
    virtual ~ModuleHookHandle() = default;

    // Detaches the hook this handle refers to.
    virtual void Remove() = 0;
};

// Handle that refers to one slot of a hook list by index.
//
// Remove() assigns nullptr to the slot instead of erasing it, so the list keeps
// its size. The list pointer is non-owning.
// NOTE(review): the handle must not be used after the list's owner has been
// destroyed -- confirm that callers guarantee this.
template <typename HookType> class ModuleHookHandleImpl : public ModuleHookHandle {
public:
    // hooks: list that contains the registered hook (may be null).
    // id:    index of the hook inside *hooks.
    ModuleHookHandleImpl(std::vector<HookType> *hooks, size_t id) : hooks_(hooks), id_(id) {}

    // Clears the referenced slot once. Does nothing when the handle was already
    // used, when no list is attached, or when the index is out of range.
    void Remove() override {
        if (removed_ || hooks_ == nullptr || id_ >= hooks_->size()) {
            return;
        }
        HookType &slot = (*hooks_)[id_];
        slot = nullptr;
        removed_ = true;
    }

private:
    std::vector<HookType> *hooks_;
    size_t id_;
    bool removed_ = false;
};

} // namespace infini_train::nn

infini_train/include/nn/modules/module.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include <vector>
88

99
#include "infini_train/include/datatype.h"
10+
#include "infini_train/include/nn/module_hook.h"
1011

1112
namespace infini_train {
1213
class Tensor;
@@ -50,6 +51,10 @@ class Module : public std::enable_shared_from_this<Module> {
5051

5152
std::unordered_map<std::string, std::shared_ptr<Tensor>> StateDict() const;
5253

54+
// operator() calls hooks and Forward
55+
std::vector<std::shared_ptr<Tensor>> operator()(const std::vector<std::shared_ptr<Tensor>> &input_tensors);
56+
57+
// Forward to be overridden by subclasses
5358
virtual std::vector<std::shared_ptr<Tensor>> Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors);
5459

5560
virtual float TrainStep(const std::vector<std::shared_ptr<Tensor>> &input_tensors,
@@ -66,13 +71,24 @@ class Module : public std::enable_shared_from_this<Module> {
6671

6772
virtual std::shared_ptr<Module> ReplicateForDataParallel(int device_idx) const;
6873

74+
// Hook registration methods
75+
std::shared_ptr<ModuleHookHandle> RegisterForwardPreHook(ForwardPreHook hook);
76+
std::shared_ptr<ModuleHookHandle> RegisterForwardPostHook(ForwardPostHook hook);
77+
std::shared_ptr<ModuleHookHandle> RegisterBackwardPreHook(BackwardPreHook hook);
78+
std::shared_ptr<ModuleHookHandle> RegisterBackwardPostHook(BackwardPostHook hook);
79+
6980
protected:
7081
const Device *device_ = nullptr;
7182
const std::string type_ = kUndefinedType;
7283
std::unordered_map<std::string, std::shared_ptr<Module>> modules_;
7384
std::unordered_map<std::string, std::shared_ptr<Tensor>> parameters_;
7485
std::unordered_map<std::string, std::shared_ptr<Tensor>> buffers_;
7586

87+
std::vector<ForwardPreHook> forward_pre_hooks_;
88+
std::vector<ForwardPostHook> forward_post_hooks_;
89+
std::vector<BackwardPreHook> backward_pre_hooks_;
90+
std::vector<BackwardPostHook> backward_post_hooks_;
91+
7692
private:
7793
std::unordered_map<std::string, std::shared_ptr<Module>>
7894
NamedModules(const std::string &prefix = "", bool remove_duplicate = true,
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
#pragma once
2+
3+
#include <memory>
4+
#include <string>
5+
#include <vector>
6+
7+
namespace infini_train {
class Tensor;

namespace autograd {
class Function;
class HookHandle;
} // namespace autograd

namespace nn {
class Module;
} // namespace nn

namespace utils {

// Static-only entry points that register checks on autograd Functions and
// nn Modules. All non-inline members are defined in the implementation file.
class PrecisionChecker {
public:
    // Options for the registered checks.
    // NOTE(review): the field names suggest NaN / Inf detection, statistics
    // printing and aborting on failure -- confirm against the implementation.
    struct Config {
        bool check_nan = true;
        bool check_inf = true;
        bool print_stats = true;
        bool abort_on_error = false;
    };

    // The "default Config" call forms below are inline forwarding overloads
    // rather than a `= Config()` default argument. A default argument of the
    // enclosing class is parsed before the default member initializers of the
    // nested Config, so some compilers (notably Clang) reject it.

    // Registers checks for a single Function.
    static void RegisterForFunction(autograd::Function *func, const std::string &name, const Config &config);
    static void RegisterForFunction(autograd::Function *func, const std::string &name = "") {
        RegisterForFunction(func, name, Config{});
    }

    // Registers checks for every Function in `functions`.
    static void RegisterForAllFunctions(const std::vector<std::shared_ptr<autograd::Function>> &functions,
                                        const Config &config);
    static void RegisterForAllFunctions(const std::vector<std::shared_ptr<autograd::Function>> &functions) {
        RegisterForAllFunctions(functions, Config{});
    }

    // Register hooks for a Module (checks forward inputs/outputs).
    static void RegisterForModule(nn::Module *module, const std::string &name, const Config &config);
    static void RegisterForModule(nn::Module *module, const std::string &name = "") {
        RegisterForModule(module, name, Config{});
    }

private:
    // Shared checking routine; `stage` and `name` label the tensors checked.
    static void CheckTensors(const std::string &stage, const std::string &name,
                             const std::vector<std::shared_ptr<Tensor>> &tensors, const Config &config);
};

} // namespace utils
} // namespace infini_train

0 commit comments

Comments
 (0)