Skip to content

Commit 3d35a26

Browse files
authored
Merge pull request #192 from CarryFun/fanruikai/doc_update
[WIP] Update doc and notes for BMTrain.
2 parents 22a42af + 3d7d7d9 commit 3d35a26

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

56 files changed

+2590
-1106
lines changed

bmtrain/block_layer.py

Lines changed: 207 additions & 108 deletions
Large diffs are not rendered by default.

bmtrain/hook_func.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,58 +2,70 @@
22
from .global_var import config
33
from .zero_context import ZeroContext
44

5+
56
def zero_pre_forward(module, inputs):
    """Open a ZeroContext to gather this block's partitioned params before forward.

    Stores the context on ``module._forward_block_ctx`` so the matching
    post-forward hook can exit it later.
    """
    pipe = module._mode == "PIPE"
    # In pipeline mode, only the first micro-batch opens a fresh context;
    # later micro-batches reuse the one already gathered.
    enter = module._micro_idx == 0 if pipe else True
    if not enter:
        return
    zero_level = module._zero_level
    if module.all_param_no_grad:  # only forward
        forward_flag = 0
    elif zero_level != 2:
        forward_flag = 0
    elif module._need_release:
        forward_flag = 1
    else:
        forward_flag = 2  # repeating forward in same layer
    ctx = ZeroContext(module, module._layer_dict, pipe=pipe)
    module._forward_block_ctx = ctx
    ctx.enter(forward_flag)
2022

23+
2124
def zero_post_forward(module, inputs, outputs):
    """Exit the module's forward ZeroContext once the forward pass is done.

    In pipeline ("PIPE") mode the context is kept open until the last
    micro-batch; otherwise it exits immediately.
    """
    if module.all_param_no_grad:
        forward_flag = 0
    else:
        forward_flag = 1 if module._zero_level == 2 else 0

    # Non-PIPE: always exit. PIPE: exit only on the final micro-batch.
    should_exit = (
        module._mode != "PIPE" or module._micro_idx == config["micros"] - 1
    )
    if should_exit:
        module._forward_block_ctx.exit(forward_flag)
3135

36+
3237
def zero_pre_backward(module, grad_outputs):
    """Open a ZeroContext to set up grad buffers before the backward pass.

    Non-PIPE mode opens a context every time and releases the next module;
    PIPE mode opens it only for the last micro-batch.
    """
    backward_flag = 2 if module._zero_level == 2 else 0
    if module._mode == "PIPE":
        # Pipeline: the backward context is created once, on the last micro-batch.
        if module._micro_idx == config["micros"] - 1:
            ctx = ZeroContext(module, module._layer_dict, pipe=True)
            module._backward_block_ctx = ctx
            ctx.enter(backward_flag, True)
    else:
        ctx = ZeroContext(module, module._layer_dict)
        module._backward_block_ctx = ctx
        ctx.enter(backward_flag, True)
        module.release_next_module(backward_flag)
4250

51+
4352
def zero_post_backward(module, grad_inputs, grad_outputs):
    """Decide whether to release the module's resources after backward.

    Non-PIPE mode releases only at the first layer; PIPE mode releases
    when the last micro-batch (index 0) has run, then decrements the
    micro-batch counter.
    """
    backward_flag = 2 if module._zero_level == 2 else 0
    if module._mode == "PIPE":
        if module._micro_idx == 0:
            module.release(backward_flag)
        module._micro_idx -= 1
    else:
        if module._is_first_layer:
            module.release(backward_flag)
5262

63+
5364
class OneStepNoGradFunc(torch.autograd.Function):
5465
"""
55-
requires_grad = False for all inputs
66+
Requires_grad = False for all inputs.
5667
"""
68+
5769
@staticmethod
5870
def forward(ctx, module, placeholder, *x):
5971
ctx.x = x
@@ -80,7 +92,8 @@ def backward(ctx, grads):
8092
grads = []
8193
for _ in x:
8294
grads.append(None)
83-
return None, None, *grads
95+
return None, None, *grads
96+
8497

8598
class PreHookFunc(torch.autograd.Function):
8699
@staticmethod
@@ -94,6 +107,7 @@ def backward(ctx, *grads):
94107
zero_post_backward(ctx.module, grads, None)
95108
return None, *grads
96109

110+
97111
class PostHookFunc(torch.autograd.Function):
98112
@staticmethod
99113
def forward(ctx, module, *out):

0 commit comments

Comments
 (0)