
Commit c4df1cc

[shardformer] simplify execute_w_pass_grad_accum & execute_w_pass
1 parent 130b50c commit c4df1cc

2 files changed: +50 additions, -91 deletions

colossalai/shardformer/layer/_operation.py

Lines changed: 13 additions & 91 deletions

@@ -6,7 +6,13 @@
 
 from colossalai.pipeline.weight_grad_store import WeightGradStore
 
-from .utils import is_share_sp_tp
+from .utils import (
+    execute_conv1d_w_pass,
+    execute_conv1d_w_pass_grad_accum,
+    execute_w_pass,
+    execute_w_pass_grad_accum,
+    is_share_sp_tp,
+)
 
 try:
     import fused_mix_prec_layer_norm_cuda
@@ -117,18 +123,6 @@ def backward(ctx, grad_output):
         # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
         # all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py
 
-        def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_):
-            if total_input.dtype == torch.float32:
-                wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32
-            elif total_input.dtype in (torch.float16, torch.bfloat16):
-                wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16
-            else:
-                raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
-            wgrad_gemm_accum_func(_grad_output_, _input_, _weight_main_grad_)
-
-        def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None):
-            return wgrad_gemm_func(_input_.t(), _grad_output_)
-
         # split dx & dw
         if _grad_accum_fusion_available and weight.grad is not None:
             grad = weight.grad
@@ -138,7 +132,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
                     grad_output,
                     (weight, weight_origin),
                     functools.partial(
-                        execute_w_pass_grad_accum,
+                        execute_conv1d_w_pass_grad_accum,
                     ),
                 )
                 grad_weight = None
@@ -158,7 +152,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
                     grad_output,
                     (weight, weight_origin),
                     functools.partial(
-                        execute_w_pass,
+                        execute_conv1d_w_pass,
                         wgrad_gemm_func=torch.matmul,
                     ),
                 )
@@ -197,18 +191,6 @@ def backward(ctx, grad_output):
         use_bias = ctx.use_bias
         use_zbv = ctx.use_zbv
 
-        def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_):
-            if total_input.dtype == torch.float32:
-                wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32
-            elif total_input.dtype in (torch.float16, torch.bfloat16):
-                wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16
-            else:
-                raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
-            wgrad_gemm_accum_func(_grad_output_, _input_, _weight_main_grad_)
-
-        def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None):
-            return wgrad_gemm_func(_input_.t(), _grad_output_)
-
         # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias.
         weight_origin = weight
         weight = weight.view(weight.shape)
@@ -233,7 +215,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
                     grad_output,
                     (weight, weight_origin),
                     functools.partial(
-                        execute_w_pass_grad_accum,
+                        execute_conv1d_w_pass_grad_accum,
                     ),
                 )
                 grad_weight = None
@@ -253,7 +235,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
                     grad_output,
                     (weight, weight_origin),
                     functools.partial(
-                        execute_w_pass,
+                        execute_conv1d_w_pass,
                         wgrad_gemm_func=torch.matmul,
                     ),
                 )
@@ -293,18 +275,6 @@ def backward(ctx, grad_output):
         fp8_communication = ctx.fp8_communication
         use_zbv = ctx.use_zbv
 
-        def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_):
-            if total_input.dtype == torch.float32:
-                wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32
-            elif total_input.dtype in (torch.float16, torch.bfloat16):
-                wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16
-            else:
-                raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
-            wgrad_gemm_accum_func(_input_, _grad_output_, _weight_main_grad_)
-
-        def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None):
-            return wgrad_gemm_func(_grad_output_.t(), _input_)
-
         # In order to be hooked into Gemini's '__torch_function__', adding a view operation to bias.
         if use_bias:
             bias.view(bias.shape)
@@ -392,18 +362,6 @@ def backward(ctx, grad_output):
         use_bias = ctx.use_bias
         use_zbv = ctx.use_zbv
 
-        def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_):
-            if total_input.dtype == torch.float32:
-                wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32
-            elif total_input.dtype in (torch.float16, torch.bfloat16):
-                wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16
-            else:
-                raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
-            wgrad_gemm_accum_func(_input_, _grad_output_, _weight_main_grad_)
-
-        def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None):
-            return wgrad_gemm_func(_grad_output_.t(), _input_)
-
         # In order to be hooked into Gemini's '__torch_function__', adding a view operation to bias.
         if use_bias:
             bias.view(bias.shape)
@@ -641,18 +599,6 @@ def backward(ctx, grad_output):
         # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
         # all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py
 
-        def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_):
-            if total_input.dtype == torch.float32:
-                wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32
-            elif total_input.dtype in (torch.float16, torch.bfloat16):
-                wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16
-            else:
-                raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
-            wgrad_gemm_accum_func(_input_, _grad_output_, _weight_main_grad_)
-
-        def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None):
-            return wgrad_gemm_func(_grad_output_.t(), _input_)
-
         if _grad_accum_fusion_available and weight.grad is not None:
             grad = weight.grad
             if use_zbv:
@@ -828,18 +774,6 @@ def backward(ctx, grad_output):
         grad_output = grad_output.view(-1, grad_output.shape[-1])
         total_input = total_input.reshape(-1, total_input.shape[-1])
 
-        def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_):
-            if total_input.dtype == torch.float32:
-                wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32
-            elif total_input.dtype in (torch.float16, torch.bfloat16):
-                wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16
-            else:
-                raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
-            wgrad_gemm_accum_func(_input_, _grad_output_, _weight_main_grad_)
-
-        def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None):
-            return wgrad_gemm_func(_grad_output_.t(), _input_)
-
         if _grad_accum_fusion_available and weight.grad is not None:
             grad = weight.grad
             if use_zbv:
@@ -1000,18 +934,6 @@ def backward(ctx, grad_output):
         # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
         # all-reduce scheduled first and have GPU resources allocated
 
-        def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_):
-            if total_input.dtype == torch.float32:
-                wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32
-            elif total_input.dtype in (torch.float16, torch.bfloat16):
-                wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16
-            else:
-                raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
-            wgrad_gemm_accum_func(_grad_output_, _input_, _weight_main_grad_)
-
-        def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None):
-            return wgrad_gemm_func(_input_.t(), _grad_output_)
-
         # split dx & dw
         if _grad_accum_fusion_available and weight.grad is not None:
             grad = weight.grad
@@ -1021,7 +943,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
                     grad_output,
                     (weight, weight_origin),
                     functools.partial(
-                        execute_w_pass_grad_accum,
+                        execute_conv1d_w_pass_grad_accum,
                     ),
                 )
                 grad_weight = None
@@ -1041,7 +963,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
                     grad_output,
                     (weight, weight_origin),
                     functools.partial(
-                        execute_w_pass,
+                        execute_conv1d_w_pass,
                         wgrad_gemm_func=torch.matmul,
                     ),
                 )
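Note on the call-site pattern above: instead of computing grad_weight inline, the backward pass queues a functools.partial of the shared helper so the weight-gradient GEMM (the "W pass") can be executed later, e.g. during a zero-bubble (ZBV) pipeline phase. The sketch below mimics that deferral with a plain Python list; the list-based queue and its flush loop are illustrative stand-ins, not the WeightGradStore API, and only the functools.partial(execute_conv1d_w_pass, wgrad_gemm_func=torch.matmul) usage is taken from the diff. It assumes this commit is applied so the helper is importable from colossalai.shardformer.layer.utils.

# Minimal sketch of the deferred W-pass pattern, under the assumptions stated above.
import functools

import torch

from colossalai.shardformer.layer.utils import execute_conv1d_w_pass

_pending_w_pass = []  # illustrative queue of (input, grad_output, weight, w_pass_fn)


def queue_w_pass(total_input, grad_output, weight, w_pass_fn):
    # Backward time: record what is needed for dW instead of computing it now.
    _pending_w_pass.append((total_input, grad_output, weight, w_pass_fn))


def flush_w_pass():
    # Later (e.g. in a pipeline bubble): run the postponed weight-gradient GEMMs.
    for total_input, grad_output, weight, w_pass_fn in _pending_w_pass:
        weight.grad = w_pass_fn(total_input, grad_output)
    _pending_w_pass.clear()


x = torch.randn(4, 8)                        # flattened activations: (tokens, in_features)
dy = torch.randn(4, 16)                      # upstream gradient:     (tokens, out_features)
w = torch.nn.Parameter(torch.randn(8, 16))   # Conv1D-style weight:   (in_features, out_features)

queue_w_pass(x, dy, w, functools.partial(execute_conv1d_w_pass, wgrad_gemm_func=torch.matmul))
flush_w_pass()
print(w.grad.shape)  # torch.Size([8, 16]) == x.t() @ dy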

colossalai/shardformer/layer/utils.py

Lines changed: 37 additions & 0 deletions

@@ -9,6 +9,43 @@
 
 from colossalai.accelerator import get_accelerator
 
+try:
+    import fused_weight_gradient_mlp_cuda
+
+    _grad_accum_fusion_available = True
+except ImportError:
+    _grad_accum_fusion_available = False
+
+
+# execute_w_pass_grad_accum & execute_conv1d_w_pass for GPT2FusedLinearConv1D
+def execute_conv1d_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_):
+    if _input_.dtype == torch.float32:
+        wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32
+    elif _input_.dtype in (torch.float16, torch.bfloat16):
+        wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16
+    else:
+        raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
+    wgrad_gemm_accum_func(_grad_output_, _input_, _weight_main_grad_)
+
+
+def execute_conv1d_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None):
+    return wgrad_gemm_func(_input_.t(), _grad_output_)
+
+
+# execute_w_pass_grad_accum & execute_w_pass for Linear (except GPT2FusedLinearConv1D)
+def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_):
+    if _input_.dtype == torch.float32:
+        wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32
+    elif _input_.dtype in (torch.float16, torch.bfloat16):
+        wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16
+    else:
+        raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
+    wgrad_gemm_accum_func(_input_, _grad_output_, _weight_main_grad_)
+
+
+def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None):
+    return wgrad_gemm_func(_grad_output_.t(), _input_)
+
 
 class SeqParallelUtils:
     @staticmethod
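As a quick sanity check of the two conventions added above: the plain-Linear helper computes dW as grad_output.t() @ input (nn.Linear stores its weight as (out_features, in_features)), while the Conv1D helper computes input.t() @ grad_output (GPT-2's Conv1D stores its weight as (in_features, out_features)). A minimal sketch, using torch.matmul as the GEMM and assuming this commit is applied so the helpers are importable:

# Minimal check of the two dW conventions, with torch.matmul as wgrad_gemm_func.
import torch

from colossalai.shardformer.layer.utils import execute_conv1d_w_pass, execute_w_pass

x = torch.randn(4, 8)    # (tokens, in_features)
dy = torch.randn(4, 16)  # (tokens, out_features)

# nn.Linear convention: weight is (out_features, in_features), so dW = dy.t() @ x.
dw_linear = execute_w_pass(x, dy, wgrad_gemm_func=torch.matmul)
print(dw_linear.shape)   # torch.Size([16, 8])

# GPT-2 Conv1D convention: weight is (in_features, out_features), so dW = x.t() @ dy.
dw_conv1d = execute_conv1d_w_pass(x, dy, wgrad_gemm_func=torch.matmul)
print(dw_conv1d.shape)   # torch.Size([8, 16])

The fused grad-accum variants follow the same argument order but accumulate directly into the weight's main grad buffer via fused_weight_gradient_mlp_cuda, when that extension is available.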
