
from colossalai.pipeline.weight_grad_store import WeightGradStore

-from .utils import is_share_sp_tp
+from .utils import (
+    execute_conv1d_w_pass,
+    execute_conv1d_w_pass_grad_accum,
+    execute_w_pass,
+    execute_w_pass_grad_accum,
+    is_share_sp_tp,
+)

try:
    import fused_mix_prec_layer_norm_cuda
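For context: the four helpers now imported from .utils presumably carry the bodies of the inline closures deleted in the hunks below. The plain execute_w_pass / execute_w_pass_grad_accum pair keeps the nn.Linear orientation (weight shaped [out, in]), while the execute_conv1d_* pair keeps the GPT-2 Conv1D orientation (weight shaped [in, out]), which is why only the Conv1D call sites change names. A minimal sketch of what the shared utils side could look like, reconstructed from the deleted code; _pick_accum_kernel is an illustrative helper name and the real module may inline the dtype dispatch instead:

    import torch

    try:
        import fused_weight_gradient_mlp_cuda  # fused wgrad kernels (Apex/Megatron extension)
    except ImportError:
        fused_weight_gradient_mlp_cuda = None

    def _pick_accum_kernel(dtype):
        # Choose the fused gradient-accumulation GEMM matching the activation dtype.
        if dtype == torch.float32:
            return fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32
        if dtype in (torch.float16, torch.bfloat16):
            return fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16
        raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")

    # nn.Linear layout: weight is [out, in], so dW = dY^T @ X.
    def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None):
        return wgrad_gemm_func(_grad_output_.t(), _input_)

    def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_):
        _pick_accum_kernel(_input_.dtype)(_input_, _grad_output_, _weight_main_grad_)

    # GPT-2 Conv1D layout: weight is [in, out], so dW = X^T @ dY (operands swapped).
    def execute_conv1d_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None):
        return wgrad_gemm_func(_input_.t(), _grad_output_)

    def execute_conv1d_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_):
        _pick_accum_kernel(_input_.dtype)(_grad_output_, _input_, _weight_main_grad_)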
@@ -117,18 +123,6 @@ def backward(ctx, grad_output):
        # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
        # all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py

-        def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_):
-            if total_input.dtype == torch.float32:
-                wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32
-            elif total_input.dtype in (torch.float16, torch.bfloat16):
-                wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16
-            else:
-                raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
-            wgrad_gemm_accum_func(_grad_output_, _input_, _weight_main_grad_)
-
-        def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None):
-            return wgrad_gemm_func(_input_.t(), _grad_output_)
-
        # split dx & dw
        if _grad_accum_fusion_available and weight.grad is not None:
            grad = weight.grad
@@ -138,7 +132,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
                    grad_output,
                    (weight, weight_origin),
                    functools.partial(
-                        execute_w_pass_grad_accum,
+                        execute_conv1d_w_pass_grad_accum,
                    ),
                )
                grad_weight = None
@@ -158,7 +152,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
                    grad_output,
                    (weight, weight_origin),
                    functools.partial(
-                        execute_w_pass,
+                        execute_conv1d_w_pass,
                        wgrad_gemm_func=torch.matmul,
                    ),
                )
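The operand swap between the two helper families can be checked in isolation: for a Linear weight of shape [out, in] the weight gradient is dY^T X, while for a Conv1D weight of shape [in, out] it is X^T dY, the transpose of the former. A quick standalone sanity check, independent of the fused CUDA kernels:

    import torch

    x = torch.randn(4, 3)   # flattened activations: [tokens, in_features]
    dy = torch.randn(4, 5)  # upstream gradient:     [tokens, out_features]

    dw_linear = torch.matmul(dy.t(), x)   # execute_w_pass orientation, shape [out, in] = [5, 3]
    dw_conv1d = torch.matmul(x.t(), dy)   # execute_conv1d_w_pass orientation, shape [in, out] = [3, 5]

    assert torch.allclose(dw_linear, dw_conv1d.t())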
@@ -197,18 +191,6 @@ def backward(ctx, grad_output):
        use_bias = ctx.use_bias
        use_zbv = ctx.use_zbv

-        def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_):
-            if total_input.dtype == torch.float32:
-                wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32
-            elif total_input.dtype in (torch.float16, torch.bfloat16):
-                wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16
-            else:
-                raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
-            wgrad_gemm_accum_func(_grad_output_, _input_, _weight_main_grad_)
-
-        def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None):
-            return wgrad_gemm_func(_input_.t(), _grad_output_)
-
        # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias.
        weight_origin = weight
        weight = weight.view(weight.shape)
@@ -233,7 +215,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
                    grad_output,
                    (weight, weight_origin),
                    functools.partial(
-                        execute_w_pass_grad_accum,
+                        execute_conv1d_w_pass_grad_accum,
                    ),
                )
                grad_weight = None
@@ -253,7 +235,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
                    grad_output,
                    (weight, weight_origin),
                    functools.partial(
-                        execute_w_pass,
+                        execute_conv1d_w_pass,
                        wgrad_gemm_func=torch.matmul,
                    ),
                )
@@ -293,18 +275,6 @@ def backward(ctx, grad_output):
        fp8_communication = ctx.fp8_communication
        use_zbv = ctx.use_zbv

-        def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_):
-            if total_input.dtype == torch.float32:
-                wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32
-            elif total_input.dtype in (torch.float16, torch.bfloat16):
-                wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16
-            else:
-                raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
-            wgrad_gemm_accum_func(_input_, _grad_output_, _weight_main_grad_)
-
-        def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None):
-            return wgrad_gemm_func(_grad_output_.t(), _input_)
-
        # In order to be hooked into Gemini's '__torch_function__', adding a view operation to bias.
        if use_bias:
            bias.view(bias.shape)
@@ -392,18 +362,6 @@ def backward(ctx, grad_output):
        use_bias = ctx.use_bias
        use_zbv = ctx.use_zbv

-        def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_):
-            if total_input.dtype == torch.float32:
-                wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32
-            elif total_input.dtype in (torch.float16, torch.bfloat16):
-                wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16
-            else:
-                raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
-            wgrad_gemm_accum_func(_input_, _grad_output_, _weight_main_grad_)
-
-        def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None):
-            return wgrad_gemm_func(_grad_output_.t(), _input_)
-
        # In order to be hooked into Gemini's '__torch_function__', adding a view operation to bias.
        if use_bias:
            bias.view(bias.shape)
@@ -641,18 +599,6 @@ def backward(ctx, grad_output):
        # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
        # all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py

-        def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_):
-            if total_input.dtype == torch.float32:
-                wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32
-            elif total_input.dtype in (torch.float16, torch.bfloat16):
-                wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16
-            else:
-                raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
-            wgrad_gemm_accum_func(_input_, _grad_output_, _weight_main_grad_)
-
-        def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None):
-            return wgrad_gemm_func(_grad_output_.t(), _input_)
-
        if _grad_accum_fusion_available and weight.grad is not None:
            grad = weight.grad
            if use_zbv:
@@ -828,18 +774,6 @@ def backward(ctx, grad_output):
        grad_output = grad_output.view(-1, grad_output.shape[-1])
        total_input = total_input.reshape(-1, total_input.shape[-1])

-        def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_):
-            if total_input.dtype == torch.float32:
-                wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32
-            elif total_input.dtype in (torch.float16, torch.bfloat16):
-                wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16
-            else:
-                raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
-            wgrad_gemm_accum_func(_input_, _grad_output_, _weight_main_grad_)
-
-        def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None):
-            return wgrad_gemm_func(_grad_output_.t(), _input_)
-
        if _grad_accum_fusion_available and weight.grad is not None:
            grad = weight.grad
            if use_zbv:
@@ -1000,18 +934,6 @@ def backward(ctx, grad_output):
        # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
        # all-reduce scheduled first and have GPU resources allocated

-        def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_):
-            if total_input.dtype == torch.float32:
-                wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32
-            elif total_input.dtype in (torch.float16, torch.bfloat16):
-                wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16
-            else:
-                raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
-            wgrad_gemm_accum_func(_grad_output_, _input_, _weight_main_grad_)
-
-        def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None):
-            return wgrad_gemm_func(_input_.t(), _grad_output_)
-
        # split dx & dw
        if _grad_accum_fusion_available and weight.grad is not None:
            grad = weight.grad
@@ -1021,7 +943,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
                    grad_output,
                    (weight, weight_origin),
                    functools.partial(
-                        execute_w_pass_grad_accum,
+                        execute_conv1d_w_pass_grad_accum,
                    ),
                )
                grad_weight = None
@@ -1041,7 +963,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
                    grad_output,
                    (weight, weight_origin),
                    functools.partial(
-                        execute_w_pass,
+                        execute_conv1d_w_pass,
                        wgrad_gemm_func=torch.matmul,
                    ),
                )
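In every use_zbv branch touched above, backward() stops computing dW itself and instead enqueues the inputs plus a functools.partial of the matching helper into WeightGradStore; the zero-bubble (ZBV) pipeline schedule replays that closure later as a separate W pass. The store's interface beyond put(...) is not visible in this diff, so the following is only a hypothetical stand-in (DeferredWGradStore and pop_and_run are illustrative names, not ColossalAI API) showing the defer-and-replay pattern:

    import queue

    class DeferredWGradStore:
        """Hypothetical minimal stand-in for a weight-gradient store (illustration only)."""

        _queue = queue.Queue()

        @classmethod
        def put(cls, total_input, grad_output, weight, grad_func):
            # backward() only produces dX; the dW GEMM is parked as a closure.
            cls._queue.put((total_input, grad_output, weight, grad_func))

        @classmethod
        def pop_and_run(cls):
            # Called later by the pipeline schedule in its dedicated W phase,
            # once the B (input-gradient) pass is off the critical path.
            total_input, grad_output, weight, grad_func = cls._queue.get()
            return grad_func(total_input, grad_output, weight)

Decoupling the B and W passes this way is what lets a zero-bubble schedule fill pipeline bubbles with deferred weight-gradient work.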