@@ -287,7 +287,7 @@ def streamk_amd_gemm(
         start_iter = end_iter


-def streamk_amd_matmul(a, b, bias=None):
+def _streamk_amd_matmul_impl(a, b, bias=None):
     M, K = a.shape
     _, N = b.shape
     dtype = a.dtype
@@ -391,6 +391,36 @@ def streamk_amd_matmul(a, b, bias=None):
     return c


+class _StreamKAmdMatmul(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, a, b, bias=None):
+        # Save tensors for backward
+        ctx.save_for_backward(a, b, bias)
+        return _streamk_amd_matmul_impl(a, b, bias)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        a, b, bias = ctx.saved_tensors
+        grad_a = grad_b = grad_bias = None
+
+        if ctx.needs_input_grad[0]:
+            grad_a = _streamk_amd_matmul_impl(grad_output, b.t().contiguous())
+
+        if ctx.needs_input_grad[1]:
+            grad_b = _streamk_amd_matmul_impl(a.t().contiguous(), grad_output)
+
+        if ctx.needs_input_grad[2] and bias is not None:
+            grad_bias = grad_output.sum(dim=0)
+            if bias.dim() == 2:
+                grad_bias = grad_bias.unsqueeze(0)
+
+        return grad_a, grad_b, grad_bias
+
+
+def streamk_amd_matmul(a, b, bias=None):
+    return _StreamKAmdMatmul.apply(a, b, bias)
+
+
 def _matmul_launch_metadata(grid, kernel, args):
     ret = {}
     M, N, K = args["M"], args["N"], args["K"]
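Both new `backward` methods implement the standard matmul gradient identities: for C = A @ B (+ bias), dL/dA = dL/dC @ B^T, dL/dB = A^T @ dL/dC, and the bias gradient reduces dL/dC over the rows. A plain-PyTorch sketch of the same computation, for reference only; the helper name `reference_backward` is not part of the patch:

```python
import torch

def reference_backward(a, b, grad_output, bias=None):
    # dL/dA = dL/dC @ B^T  (what the wrapper computes via the Stream-K kernel)
    grad_a = grad_output @ b.t()
    # dL/dB = A^T @ dL/dC
    grad_b = a.t() @ grad_output
    # dL/dbias: the bias is broadcast over the M dimension in the forward pass,
    # so its gradient reduces grad_output over that dimension.
    grad_bias = grad_output.sum(dim=0) if bias is not None else None
    return grad_a, grad_b, grad_bias
```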
@@ -601,7 +631,7 @@ def streamk_cuda_gemm(
             c_desc.atomic_add([offs_am, offs_bn], c)


-def streamk_cuda_matmul(a, b):
+def _streamk_cuda_matmul_impl(a, b):
     assert a.dtype == b.dtype, "Incompatible dtypes"

     M, K = a.shape
@@ -649,3 +679,28 @@ def grid(META):
         NUM_SMS=num_sms,  #
     )
     return c
+
+
+class _StreamKCudaMatmul(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, a, b):
+        # Save tensors for backward
+        ctx.save_for_backward(a, b)
+        return _streamk_cuda_matmul_impl(a, b)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        a, b = ctx.saved_tensors
+        grad_a = grad_b = None
+
+        if ctx.needs_input_grad[0]:
+            grad_a = _streamk_cuda_matmul_impl(grad_output, b.t().contiguous())
+
+        if ctx.needs_input_grad[1]:
+            grad_b = _streamk_cuda_matmul_impl(a.t().contiguous(), grad_output)
+
+        return grad_a, grad_b
+
+
+def streamk_cuda_matmul(a, b):
+    return _StreamKCudaMatmul.apply(a, b)
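A minimal smoke test one could run to check that gradients flow through the new wrapper and agree with PyTorch's reference matmul; the shapes, dtype, and tolerances below are illustrative assumptions, not part of this change:

```python
import torch

# Illustrative shapes/dtype; assumes a GPU where the Stream-K kernel runs.
a = torch.randn(512, 256, device="cuda", dtype=torch.float16, requires_grad=True)
b = torch.randn(256, 384, device="cuda", dtype=torch.float16, requires_grad=True)

streamk_cuda_matmul(a, b).sum().backward()

# Reference gradients from PyTorch's own matmul.
a_ref = a.detach().clone().requires_grad_(True)
b_ref = b.detach().clone().requires_grad_(True)
torch.matmul(a_ref, b_ref).sum().backward()

torch.testing.assert_close(a.grad, a_ref.grad, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(b.grad, b_ref.grad, rtol=1e-2, atol=1e-2)
```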