Commit 130b50c

[fix] rm useless comments
1 parent c0b6fbc commit 130b50c

1 file changed, +0 -168 lines

colossalai/shardformer/layer/_operation.py

Lines changed: 0 additions & 168 deletions
@@ -142,30 +142,6 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
                     ),
                 )
                 grad_weight = None
-                # if grad.dtype == torch.float32:
-                #     WeightGradStore.put(
-                #         total_input,
-                #         grad_output,
-                #         (weight, weight_origin),
-                #         functools.partial(
-                #             execute_w_pass_grad_accum,
-                #             wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32,
-                #         ),
-                #     )
-                #     grad_weight = None
-                # elif grad.dtype in (torch.float16, torch.bfloat16):
-                #     WeightGradStore.put(
-                #         total_input,
-                #         grad_output,
-                #         (weight, weight_origin),
-                #         functools.partial(
-                #             execute_w_pass_grad_accum,
-                #             wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16,
-                #         ),
-                #     )
-                #     grad_weight = None
-                # else:
-                #     raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
             else:
                 if grad.dtype == torch.float32:
                     fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
@@ -261,30 +237,6 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
                     ),
                 )
                 grad_weight = None
-                # if grad.dtype == torch.float32:
-                #     WeightGradStore.put(
-                #         total_input,
-                #         grad_output,
-                #         (weight, weight_origin),
-                #         functools.partial(
-                #             execute_w_pass_grad_accum,
-                #             wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32,
-                #         ),
-                #     )
-                #     grad_weight = None
-                # elif grad.dtype in (torch.float16, torch.bfloat16):
-                #     WeightGradStore.put(
-                #         total_input,
-                #         grad_output,
-                #         (weight, weight_origin),
-                #         functools.partial(
-                #             execute_w_pass_grad_accum,
-                #             wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16,
-                #         ),
-                #     )
-                #     grad_weight = None
-                # else:
-                #     raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
             else:
                 if grad.dtype == torch.float32:
                     fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
@@ -385,30 +337,6 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
                     ),
                 )
                 grad_weight = None
-                # if grad.dtype == torch.float32:
-                #     WeightGradStore.put(
-                #         total_input,
-                #         grad_output,
-                #         weight,
-                #         functools.partial(
-                #             execute_w_pass_grad_accum,
-                #             wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32,
-                #         ),
-                #     )
-                #     grad_weight = None
-                # elif grad.dtype in (torch.float16, torch.bfloat16):
-                #     WeightGradStore.put(
-                #         total_input,
-                #         grad_output,
-                #         weight,
-                #         functools.partial(
-                #             execute_w_pass_grad_accum,
-                #             wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16,
-                #         ),
-                #     )
-                #     grad_weight = None
-                # else:
-                #     raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
             else:
                 if grad.dtype == torch.float32:
                     fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
@@ -500,30 +428,6 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
                     ),
                 )
                 grad_weight = None
-                # if grad.dtype == torch.float32:
-                #     WeightGradStore.put(
-                #         total_input,
-                #         grad_output,
-                #         weight,
-                #         functools.partial(
-                #             execute_w_pass_grad_accum,
-                #             wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32,
-                #         ),
-                #     )
-                #     grad_weight = None
-                # elif grad.dtype in (torch.float16, torch.bfloat16):
-                #     WeightGradStore.put(
-                #         total_input,
-                #         grad_output,
-                #         weight,
-                #         functools.partial(
-                #             execute_w_pass_grad_accum,
-                #             wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16,
-                #         ),
-                #     )
-                #     grad_weight = None
-                # else:
-                #     raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
             else:
                 if grad.dtype == torch.float32:
                     fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
@@ -761,30 +665,6 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
                     ),
                 )
                 grad_weight = None
-                # if grad.dtype == torch.float32:
-                #     WeightGradStore.put(
-                #         total_input,
-                #         grad_output,
-                #         weight,
-                #         functools.partial(
-                #             execute_w_pass_grad_accum,
-                #             wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32,
-                #         ),
-                #     )
-                #     grad_weight = None
-                # elif grad.dtype in (torch.float16, torch.bfloat16):
-                #     WeightGradStore.put(
-                #         total_input,
-                #         grad_output,
-                #         weight,
-                #         functools.partial(
-                #             execute_w_pass_grad_accum,
-                #             wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16,
-                #         ),
-                #     )
-                #     grad_weight = None
-                # else:
-                #     raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
             else:
                 if grad.dtype == torch.float32:
                     fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
@@ -972,30 +852,6 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
                     ),
                 )
                 grad_weight = None
-                # if grad.dtype == torch.float32:
-                #     WeightGradStore.put(
-                #         total_input,
-                #         grad_output,
-                #         weight,
-                #         functools.partial(
-                #             execute_w_pass_grad_accum,
-                #             wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32,
-                #         ),
-                #     )
-                #     grad_weight = None
-                # elif grad.dtype in (torch.float16, torch.bfloat16):
-                #     WeightGradStore.put(
-                #         total_input,
-                #         grad_output,
-                #         weight,
-                #         functools.partial(
-                #             execute_w_pass_grad_accum,
-                #             wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16,
-                #         ),
-                #     )
-                #     grad_weight = None
-                # else:
-                #     raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
             else:
                 if grad.dtype == torch.float32:
                     fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
@@ -1169,30 +1025,6 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
                     ),
                 )
                 grad_weight = None
-                # if grad.dtype == torch.float32:
-                #     WeightGradStore.put(
-                #         total_input,
-                #         grad_output,
-                #         (weight, weight_origin),
-                #         functools.partial(
-                #             execute_w_pass_grad_accum,
-                #             wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32,
-                #         ),
-                #     )
-                #     grad_weight = None
-                # elif grad.dtype in (torch.float16, torch.bfloat16):
-                #     WeightGradStore.put(
-                #         total_input,
-                #         grad_output,
-                #         (weight, weight_origin),
-                #         functools.partial(
-                #             execute_w_pass_grad_accum,
-                #             wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16,
-                #         ),
-                #     )
-                #     grad_weight = None
-                # else:
-                #     raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
             else:
                 if grad.dtype == torch.float32:
                     fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
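For reference, the else branch retained in every hunk dispatches the fused weight-gradient accumulation kernel by gradient dtype, which is exactly what the removed commented-out code duplicated via WeightGradStore. A minimal sketch of that dispatch, assuming the optional fused_weight_gradient_mlp_cuda extension (built with Apex) is importable; the helper name accumulate_weight_grad is illustrative only and not part of the ColossalAI API:

import torch
import fused_weight_gradient_mlp_cuda  # optional fused-kernel extension, assumed built

def accumulate_weight_grad(total_input, grad_output, grad):
    # Accumulate the weight gradient in place with the fused wgrad GEMM kernel,
    # choosing the fp32 or fp16/bf16 variant from the gradient dtype.
    if grad.dtype == torch.float32:
        fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
    elif grad.dtype in (torch.float16, torch.bfloat16):
        fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad)
    else:
        raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")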
