22from .global_var import config
33from .zero_context import ZeroContext
44
5+
56def zero_pre_forward (module , inputs ):
7+ """Helper function for using ZeroContext to gather parmas before forward."""
68 enter = True
79 pipe = False
810 if module ._mode == "PIPE" :
911 enter = module ._micro_idx == 0
1012 pipe = True
1113 if enter :
12- zero_level = module ._zero_level
14+ zero_level = module ._zero_level
1315 forward_flag = 1 if zero_level == 2 else 0
1416 if zero_level == 2 and not module ._need_release :
15- forward_flag = 2 # repeating forward in same layer
16- if module .all_param_no_grad : # only forward
17+ forward_flag = 2 # repeating forward in same layer
18+ if module .all_param_no_grad : # only forward
1719 forward_flag = 0
1820 module ._forward_block_ctx = ZeroContext (module , module ._layer_dict , pipe = pipe )
1921 module ._forward_block_ctx .enter (forward_flag )
2022
23+
2124def zero_post_forward (module , inputs , outputs ):
25+ """Helper function for module _forwar_block_ctx weather exits after forward."""
2226 forward_flag = 1 if module ._zero_level == 2 else 0
2327 if module .all_param_no_grad :
2428 forward_flag = 0
2529 exit = True
2630 if module ._mode == "PIPE" :
27- exit = module ._micro_idx == config [' micros' ] - 1
31+ exit = module ._micro_idx == config [" micros" ] - 1
2832
2933 if exit :
3034 module ._forward_block_ctx .exit (forward_flag )
3135
36+
3237def zero_pre_backward (module , grad_outputs ):
38+ """Helper function for using ZeroContext to init grad buffer before backward."""
3339 backward_flag = 2 if module ._zero_level == 2 else 0
3440 if module ._mode != "PIPE" :
3541 module ._backward_block_ctx = ZeroContext (module , module ._layer_dict )
3642 module ._backward_block_ctx .enter (backward_flag , True )
3743 module .release_next_module (backward_flag )
3844 else :
39- if module ._micro_idx == config ['micros' ] - 1 :
40- module ._backward_block_ctx = ZeroContext (module , module ._layer_dict , pipe = True )
45+ if module ._micro_idx == config ["micros" ] - 1 :
46+ module ._backward_block_ctx = ZeroContext (
47+ module , module ._layer_dict , pipe = True
48+ )
4149 module ._backward_block_ctx .enter (backward_flag , True )
4250
51+
4352def zero_post_backward (module , grad_inputs , grad_outputs ):
53+ """Helper function for module weather release after backward."""
4454 backward_flag = 2 if module ._zero_level == 2 else 0
4555 if module ._mode != "PIPE" :
46- if module ._is_first_layer :
56+ if module ._is_first_layer :
4757 module .release (backward_flag )
4858 else :
4959 if module ._micro_idx == 0 :
5060 module .release (backward_flag )
5161 module ._micro_idx -= 1
5262
63+
5364class OneStepNoGradFunc (torch .autograd .Function ):
5465 """
55- requires_grad = False for all inputs
66+ Requires_grad = False for all inputs.
5667 """
68+
5769 @staticmethod
5870 def forward (ctx , module , placeholder , * x ):
5971 ctx .x = x
@@ -80,7 +92,8 @@ def backward(ctx, grads):
8092 grads = []
8193 for _ in x :
8294 grads .append (None )
83- return None , None , * grads
95+ return None , None , * grads
96+
8497
8598class PreHookFunc (torch .autograd .Function ):
8699 @staticmethod
@@ -94,6 +107,7 @@ def backward(ctx, *grads):
94107 zero_post_backward (ctx .module , grads , None )
95108 return None , * grads
96109
110+
97111class PostHookFunc (torch .autograd .Function ):
98112 @staticmethod
99113 def forward (ctx , module , * out ):
0 commit comments