@@ -142,30 +142,6 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
                 ),
             )
             grad_weight = None
-            # if grad.dtype == torch.float32:
-            # WeightGradStore.put(
-            # total_input,
-            # grad_output,
-            # (weight, weight_origin),
-            # functools.partial(
-            # execute_w_pass_grad_accum,
-            # wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32,
-            # ),
-            # )
-            # grad_weight = None
-            # elif grad.dtype in (torch.float16, torch.bfloat16):
-            # WeightGradStore.put(
-            # total_input,
-            # grad_output,
-            # (weight, weight_origin),
-            # functools.partial(
-            # execute_w_pass_grad_accum,
-            # wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16,
-            # ),
-            # )
-            # grad_weight = None
-            # else:
-            # raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
         else:
             if grad.dtype == torch.float32:
                 fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
@@ -261,30 +237,6 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
                 ),
             )
             grad_weight = None
-            # if grad.dtype == torch.float32:
-            # WeightGradStore.put(
-            # total_input,
-            # grad_output,
-            # (weight, weight_origin),
-            # functools.partial(
-            # execute_w_pass_grad_accum,
-            # wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32,
-            # ),
-            # )
-            # grad_weight = None
-            # elif grad.dtype in (torch.float16, torch.bfloat16):
-            # WeightGradStore.put(
-            # total_input,
-            # grad_output,
-            # (weight, weight_origin),
-            # functools.partial(
-            # execute_w_pass_grad_accum,
-            # wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16,
-            # ),
-            # )
-            # grad_weight = None
-            # else:
-            # raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
         else:
             if grad.dtype == torch.float32:
                 fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
@@ -385,30 +337,6 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
                 ),
             )
             grad_weight = None
-            # if grad.dtype == torch.float32:
-            # WeightGradStore.put(
-            # total_input,
-            # grad_output,
-            # weight,
-            # functools.partial(
-            # execute_w_pass_grad_accum,
-            # wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32,
-            # ),
-            # )
-            # grad_weight = None
-            # elif grad.dtype in (torch.float16, torch.bfloat16):
-            # WeightGradStore.put(
-            # total_input,
-            # grad_output,
-            # weight,
-            # functools.partial(
-            # execute_w_pass_grad_accum,
-            # wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16,
-            # ),
-            # )
-            # grad_weight = None
-            # else:
-            # raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
         else:
             if grad.dtype == torch.float32:
                 fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
@@ -500,30 +428,6 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
                 ),
             )
             grad_weight = None
-            # if grad.dtype == torch.float32:
-            # WeightGradStore.put(
-            # total_input,
-            # grad_output,
-            # weight,
-            # functools.partial(
-            # execute_w_pass_grad_accum,
-            # wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32,
-            # ),
-            # )
-            # grad_weight = None
-            # elif grad.dtype in (torch.float16, torch.bfloat16):
-            # WeightGradStore.put(
-            # total_input,
-            # grad_output,
-            # weight,
-            # functools.partial(
-            # execute_w_pass_grad_accum,
-            # wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16,
-            # ),
-            # )
-            # grad_weight = None
-            # else:
-            # raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
         else:
             if grad.dtype == torch.float32:
                 fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
@@ -761,30 +665,6 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
                 ),
             )
             grad_weight = None
-            # if grad.dtype == torch.float32:
-            # WeightGradStore.put(
-            # total_input,
-            # grad_output,
-            # weight,
-            # functools.partial(
-            # execute_w_pass_grad_accum,
-            # wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32,
-            # ),
-            # )
-            # grad_weight = None
-            # elif grad.dtype in (torch.float16, torch.bfloat16):
-            # WeightGradStore.put(
-            # total_input,
-            # grad_output,
-            # weight,
-            # functools.partial(
-            # execute_w_pass_grad_accum,
-            # wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16,
-            # ),
-            # )
-            # grad_weight = None
-            # else:
-            # raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
         else:
             if grad.dtype == torch.float32:
                 fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
@@ -972,30 +852,6 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
                 ),
             )
             grad_weight = None
-            # if grad.dtype == torch.float32:
-            # WeightGradStore.put(
-            # total_input,
-            # grad_output,
-            # weight,
-            # functools.partial(
-            # execute_w_pass_grad_accum,
-            # wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32,
-            # ),
-            # )
-            # grad_weight = None
-            # elif grad.dtype in (torch.float16, torch.bfloat16):
-            # WeightGradStore.put(
-            # total_input,
-            # grad_output,
-            # weight,
-            # functools.partial(
-            # execute_w_pass_grad_accum,
-            # wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16,
-            # ),
-            # )
-            # grad_weight = None
-            # else:
-            # raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
         else:
             if grad.dtype == torch.float32:
                 fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
@@ -1169,30 +1025,6 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
                 ),
             )
             grad_weight = None
-            # if grad.dtype == torch.float32:
-            # WeightGradStore.put(
-            # total_input,
-            # grad_output,
-            # (weight, weight_origin),
-            # functools.partial(
-            # execute_w_pass_grad_accum,
-            # wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32,
-            # ),
-            # )
-            # grad_weight = None
-            # elif grad.dtype in (torch.float16, torch.bfloat16):
-            # WeightGradStore.put(
-            # total_input,
-            # grad_output,
-            # (weight, weight_origin),
-            # functools.partial(
-            # execute_w_pass_grad_accum,
-            # wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16,
-            # ),
-            # )
-            # grad_weight = None
-            # else:
-            # raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
         else:
             if grad.dtype == torch.float32:
                 fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
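Note on the code path these hunks keep: every hunk follows the same shape. The deferred branch pushes total_input, grad_output, the weight (or a (weight, weight_origin) pair) and a functools.partial(execute_w_pass_grad_accum, wgrad_gemm_accum_func=...) callback into WeightGradStore, while the else branch immediately runs Apex's fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32 or _fp16 kernel selected by grad.dtype. The sketch below is only a rough, pure-PyTorch illustration of that deferral idea: MiniWeightGradStore, wgrad_accum, deferred_wgrad and queue_or_run_wgrad are hypothetical names invented here, and a plain matmul stands in for the fused CUDA kernels; it is not the repository's implementation.

import functools
from typing import Callable, List, Tuple

import torch


class MiniWeightGradStore:
    """Hypothetical stand-in for WeightGradStore: queues deferred weight-gradient work."""

    _queue: List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Callable]] = []

    @classmethod
    def put(cls, total_input, grad_output, weight, compute_fn):
        # Defer the wgrad GEMM instead of running it inside backward().
        cls._queue.append((total_input, grad_output, weight, compute_fn))

    @classmethod
    def flush(cls):
        # Run every deferred weight-gradient computation (e.g. during a pipeline bubble).
        for total_input, grad_output, weight, compute_fn in cls._queue:
            compute_fn(total_input, grad_output, weight.main_grad)
        cls._queue.clear()


def wgrad_accum(total_input, grad_output, main_grad):
    # Plain-PyTorch substitute for the fused wgrad_gemm_accum_* kernels:
    # accumulate grad_output^T @ total_input into the main_grad buffer.
    main_grad.add_((grad_output.t().float() @ total_input.float()).to(main_grad.dtype))


def deferred_wgrad(_input_, _grad_output_, _weight_main_grad_, wgrad_gemm_accum_func=None):
    # Loosely mirrors execute_w_pass_grad_accum in the diff: invoke the chosen kernel later.
    wgrad_gemm_accum_func(_input_, _grad_output_, _weight_main_grad_)


def queue_or_run_wgrad(total_input, grad_output, weight, defer):
    grad = weight.main_grad
    if defer:
        MiniWeightGradStore.put(
            total_input,
            grad_output,
            weight,
            functools.partial(deferred_wgrad, wgrad_gemm_accum_func=wgrad_accum),
        )
        return None  # dW is produced later, nothing to return now
    if grad.dtype == torch.float32:
        wgrad_accum(total_input, grad_output, grad)  # fused fp32 kernel in the real code
    elif grad.dtype in (torch.float16, torch.bfloat16):
        wgrad_accum(total_input, grad_output, grad)  # fused fp16/bf16 kernel in the real code
    else:
        raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
    return None


# Usage: the parameter carries an fp32 main_grad accumulation buffer, Megatron-style.
weight = torch.nn.Parameter(torch.randn(16, 32))
weight.main_grad = torch.zeros(16, 32)
x = torch.randn(8, 32)       # total_input
g = torch.randn(8, 16)       # grad_output
queue_or_run_wgrad(x, g, weight, defer=True)
MiniWeightGradStore.flush()  # executes the deferred dW accumulation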