Commit f4bf7cb

Add backward pass support to addmm and gemm operators

Differential Revision: D84263978
Pull Request resolved: #531
1 parent: b03df33

6 files changed: +171 -20
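
Every file below follows the same pattern: the Triton matmul entry point is wrapped in a torch.autograd.Function whose backward reuses the same kernel on transposed operands, since for C = A @ B the gradients are dA = dC @ B^T and dB = A^T @ dC. A minimal sketch of that pattern, with torch.matmul standing in for the Triton kernel (an illustration, not the tritonbench code):

import torch


class _SketchMatmul(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b):
        ctx.save_for_backward(a, b)  # keep the inputs for the backward pass
        return torch.matmul(a, b)

    @staticmethod
    def backward(ctx, grad_output):
        a, b = ctx.saved_tensors
        grad_a = grad_b = None
        if ctx.needs_input_grad[0]:
            grad_a = torch.matmul(grad_output, b.t())  # dA = dC @ B^T
        if ctx.needs_input_grad[1]:
            grad_b = torch.matmul(a.t(), grad_output)  # dB = A^T @ dC
        return grad_a, grad_b


# Cross-check against PyTorch's own autograd (float64 keeps the comparison tight).
a = torch.randn(64, 32, dtype=torch.float64, requires_grad=True)
b = torch.randn(32, 16, dtype=torch.float64, requires_grad=True)
_SketchMatmul.apply(a, b).sum().backward()
ga, gb = a.grad.clone(), b.grad.clone()
a.grad = b.grad = None
torch.matmul(a, b).sum().backward()
assert torch.allclose(ga, a.grad) and torch.allclose(gb, b.grad)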

tritonbench/operators/addmm/operator.py

Lines changed: 8 additions & 7 deletions

@@ -20,6 +20,7 @@
 from tritonbench.utils.triton_op import (
     BenchmarkOperator,
     BenchmarkOperatorMetrics,
+    Mode,
     PRECISION_DTYPE_MAPPING,
     register_benchmark,
     register_metric,
@@ -81,7 +82,6 @@
 class Operator(BenchmarkOperator):
     DEFAULT_METRICS = ["tflops", "best_config"]
     DEFAULT_PRECISION = "fp16"
-    FWD_ONLY = True

     def __init__(
         self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
@@ -171,6 +171,7 @@ def get_input_iter(self) -> Generator:
         if hasattr(self, "dtypes") and self.dtypes:
             self.tb_args.precision = "bypass"
             self.dtype = PRECISION_DTYPE_MAPPING[self.dtypes[shape_id]]
+        requires_grad = self.mode in (Mode.BWD, Mode.FWD_BWD)
         if hasattr(self, "strides"):
             # generate shapes with different strides
             strides = self.strides[shape_id]
@@ -188,13 +189,13 @@ def get_input_iter(self) -> Generator:
             original_n = max(n, strides[2][0])
             a = torch.randn(
                 (m, n), device=self.device, dtype=self.dtype
-            ).requires_grad_(False)
+            ).requires_grad_(requires_grad)
             mat1 = torch.randn(
                 (original_m, original_k), device=self.device, dtype=self.dtype
-            ).requires_grad_(False)
+            ).requires_grad_(requires_grad)
             mat2 = torch.randn(
                 (original_k, original_n), device=self.device, dtype=self.dtype
-            ).requires_grad_(False)
+            ).requires_grad_(requires_grad)
             a = a.as_strided((m, n), strides[0])
             mat1 = mat1.as_strided((m, k), strides[1])
             mat2 = mat2.as_strided((k, n), strides[2])
@@ -203,13 +204,13 @@ def get_input_iter(self) -> Generator:
             m, k, n = shape
             a = torch.randn(
                 (m, n), device=self.device, dtype=self.dtype
-            ).requires_grad_(False)
+            ).requires_grad_(requires_grad)
             mat1 = torch.randn(
                 (m, k), device=self.device, dtype=self.dtype
-            ).requires_grad_(False)
+            ).requires_grad_(requires_grad)
             mat2 = torch.randn(
                 (k, n), device=self.device, dtype=self.dtype
-            ).requires_grad_(False)
+            ).requires_grad_(requires_grad)
         if self.col_major:
             mat2 = mat2.T.contiguous().T
         yield a, mat1, mat2
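
With FWD_ONLY gone, the input iterator now sets requires_grad on the operands whenever the benchmark runs in BWD or FWD_BWD mode, so the backward pass can actually be timed. A quick sketch of what that enables, using torch.addmm as a stand-in for the benchmarked kernels (sizes and the float64 dtype are illustrative, chosen to keep the numerical check tight):

import torch

m, k, n = 128, 64, 32
a = torch.randn(m, n, dtype=torch.float64, requires_grad=True)     # the "bias" matrix
mat1 = torch.randn(m, k, dtype=torch.float64, requires_grad=True)
mat2 = torch.randn(k, n, dtype=torch.float64, requires_grad=True)

out = torch.addmm(a, mat1, mat2)   # out = a + mat1 @ mat2
out.sum().backward()               # grad_output is all ones

# The gradients follow the usual matmul rules:
#   d(a)    = grad_output
#   d(mat1) = grad_output @ mat2.T
#   d(mat2) = mat1.T @ grad_output
go = torch.ones_like(out)
assert torch.allclose(a.grad, go)
assert torch.allclose(mat1.grad, go @ mat2.t())
assert torch.allclose(mat2.grad, mat1.t() @ go)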

tritonbench/operators/gemm/kernels/matmul.py

Lines changed: 34 additions & 0 deletions

@@ -414,6 +414,13 @@ def forward(
         fp8_fast_accum=True,
         output_dtype=None,
     ):
+        # Save tensors for backward
+        ctx.save_for_backward(a, b)
+        ctx.acc_dtype = acc_dtype
+        ctx.input_precision = input_precision
+        ctx.fp8_fast_accum = fp8_fast_accum
+        ctx.output_dtype = output_dtype
+
         return _matmul._call(
             a,
             b,
@@ -423,5 +430,32 @@ def forward(
             output_dtype=output_dtype,
         )

+    @staticmethod
+    def backward(ctx, grad_output):
+        a, b = ctx.saved_tensors
+        grad_a = grad_b = None
+
+        if ctx.needs_input_grad[0]:
+            grad_a = _matmul._call(
+                grad_output,
+                b.t(),
+                acc_dtype=ctx.acc_dtype,
+                input_precision=ctx.input_precision,
+                fp8_fast_accum=ctx.fp8_fast_accum,
+                output_dtype=None,
+            )
+
+        if ctx.needs_input_grad[1]:
+            grad_b = _matmul._call(
+                a.t(),
+                grad_output,
+                acc_dtype=ctx.acc_dtype,
+                input_precision=ctx.input_precision,
+                fp8_fast_accum=ctx.fp8_fast_accum,
+                output_dtype=None,
+            )
+
+        return grad_a, grad_b, None, None, None, None
+

 matmul = _matmul.apply
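
One detail worth calling out: backward must return one gradient per argument of forward, so the four trailing None values line up with the non-tensor arguments acc_dtype, input_precision, fp8_fast_accum and output_dtype. A small sketch of that contract, with a scalar scale standing in for those config arguments and torch.autograd.gradcheck (in float64) confirming the tensor gradients (an illustration, not the tritonbench code):

import torch


class _ScaledMatmul(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b, scale=1.0):
        ctx.save_for_backward(a, b)
        ctx.scale = scale
        return torch.matmul(a, b) * scale

    @staticmethod
    def backward(ctx, grad_output):
        a, b = ctx.saved_tensors
        grad_a = grad_b = None
        if ctx.needs_input_grad[0]:
            grad_a = torch.matmul(grad_output, b.t()) * ctx.scale
        if ctx.needs_input_grad[1]:
            grad_b = torch.matmul(a.t(), grad_output) * ctx.scale
        # One return value per forward argument; non-tensor args get None,
        # just like the dtype/precision flags in the diff above.
        return grad_a, grad_b, None


x = torch.randn(8, 4, dtype=torch.float64, requires_grad=True)
y = torch.randn(4, 3, dtype=torch.float64, requires_grad=True)
assert torch.autograd.gradcheck(lambda a, b: _ScaledMatmul.apply(a, b, 2.0), (x, y))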

tritonbench/operators/gemm/operator.py

Lines changed: 14 additions & 9 deletions

@@ -47,6 +47,7 @@ def _tlx_matmul(*args, **kwargs):
     BenchmarkOperator,
     BenchmarkOperatorMetrics,
     llama_shapes,
+    Mode,
     PRECISION_DTYPE_MAPPING,
     register_benchmark,
     register_metric,
@@ -176,7 +177,6 @@ def read_shapes_from_csv(csv_path: str) -> List[List[int]]:
 class Operator(BenchmarkOperator):
     DEFAULT_METRICS = ["latency", "speedup", "tflops"]
     DEFAULT_PRECISION = "fp16"
-    FWD_ONLY = True

     def __init__(
         self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
@@ -543,6 +543,7 @@ def get_input_iter(self) -> Generator:
         if hasattr(self, "dtypes") and self.dtypes:
             self.tb_args.precision = "bypass"
             self.dtype = PRECISION_DTYPE_MAPPING[self.dtypes[shape_id]]
+        requires_grad = self.mode in (Mode.BWD, Mode.FWD_BWD)
         if hasattr(self, "strides"):
             strides = self.strides[shape_id]
             assert (
@@ -558,28 +559,32 @@ def get_input_iter(self) -> Generator:
             actual_n = max(n, strides[1][0])
             a = self._scaled_randn(
                 (actual_m, actual_k), scale=k, device=self.device, dtype=self.dtype
-            )
+            ).requires_grad_(requires_grad)
             w = self._scaled_randn(
                 (actual_k, actual_n), scale=k, device=self.device, dtype=self.dtype
+            ).requires_grad_(requires_grad)
+            a = a.as_strided(size=[m, k], stride=strides[0]).requires_grad_(
+                requires_grad
+            )
+            w = w.as_strided(size=[k, n], stride=strides[1]).requires_grad_(
+                requires_grad
             )
-            a = a.as_strided(size=[m, k], stride=strides[0])
-            w = w.as_strided(size=[k, n], stride=strides[1])
         else:
             a = self._scaled_randn(
                 (m, k), scale=k, device=self.device, dtype=self.dtype
-            )
+            ).requires_grad_(requires_grad)
             w = self._scaled_randn(
                 (k, n), scale=k, device=self.device, dtype=self.dtype
-            )
+            ).requires_grad_(requires_grad)
         # Convert inputs to column-major if layout is "n" (non-transposed)
         if self.layout[0] == "n":
-            a = a.T.contiguous().T
+            a = a.T.contiguous().T.requires_grad_(requires_grad)
         if self.layout[1] == "n":
-            w = w.T.contiguous().T
+            w = w.T.contiguous().T.requires_grad_(requires_grad)
         if not bias == None:
             bias = torch.randn(
                 (bias), device=self.device, dtype=self.dtype
-            ).requires_grad_(False)
+            ).requires_grad_(requires_grad)

         yield a, w, bias
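
The a.T.contiguous().T idiom that the new requires_grad_ calls are chained onto is the file's existing layout trick: it keeps the logical shape but leaves the storage column-major. A tiny plain-PyTorch illustration (sizes arbitrary):

import torch

m, k = 4, 6
a = torch.randn(m, k)
a_col = a.T.contiguous().T

print(a.shape, a.stride())          # torch.Size([4, 6]) (6, 1)  -> row-major
print(a_col.shape, a_col.stride())  # torch.Size([4, 6]) (1, 4)  -> column-major
assert torch.equal(a, a_col)        # same values, different memory layout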

tritonbench/operators/gemm/partition_k.py

Lines changed: 31 additions & 1 deletion

@@ -220,7 +220,7 @@ def torch_reduction(c_buf, a):
 compiled_reduction = torch.compile(torch_reduction)


-def matmul_partition_k(a, b, triton_reduce=False):
+def _matmul_partition_k_impl(a, b, triton_reduce=False):
     # Check constraints.
     assert a.shape[1] == b.shape[0], "Incompatible dimensions"
     assert a.is_contiguous(), "Matrix A must be contiguous"
@@ -284,3 +284,33 @@ def matmul_partition_k(a, b, triton_reduce=False):
         return c
     else:
         return compiled_reduction(c_buf, a)
+
+
+class _PartitionKMatmul(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, a, b, triton_reduce=False):
+        # Save tensors for backward
+        ctx.save_for_backward(a, b)
+        ctx.triton_reduce = triton_reduce
+        return _matmul_partition_k_impl(a, b, triton_reduce)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        a, b = ctx.saved_tensors
+        grad_a = grad_b = None
+
+        if ctx.needs_input_grad[0]:
+            grad_a = _matmul_partition_k_impl(
+                grad_output, b.t().contiguous(), ctx.triton_reduce
+            )
+
+        if ctx.needs_input_grad[1]:
+            grad_b = _matmul_partition_k_impl(
+                a.t().contiguous(), grad_output, ctx.triton_reduce
+            )
+
+        return grad_a, grad_b, None
+
+
+def matmul_partition_k(a, b, triton_reduce=False):
+    return _PartitionKMatmul.apply(a, b, triton_reduce)
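
matmul_partition_k keeps its public signature but now routes through _PartitionKMatmul; its backward reuses the same partition-K implementation on transposed operands, made contiguous because the implementation asserts contiguous inputs. For reference, a plain-PyTorch sketch of the partition-K idea itself, with an illustrative partition count (this is not the Triton kernel):

import torch


def partition_k_matmul_sketch(a, b, partitions=4):
    m, k = a.shape
    k2, n = b.shape
    assert k == k2 and k % partitions == 0, "K must divide evenly into partitions"
    chunk = k // partitions
    # One partial product per K-partition, reduced at the end.
    c_buf = torch.empty(partitions, m, n, dtype=a.dtype, device=a.device)
    for p in range(partitions):
        ks = slice(p * chunk, (p + 1) * chunk)
        c_buf[p] = a[:, ks] @ b[ks, :]
    return c_buf.sum(dim=0)


a = torch.randn(32, 64, dtype=torch.float64)
b = torch.randn(64, 16, dtype=torch.float64)
assert torch.allclose(partition_k_matmul_sketch(a, b), a @ b)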

tritonbench/operators/gemm/stream_k.py

Lines changed: 57 additions & 2 deletions

@@ -287,7 +287,7 @@ def streamk_amd_gemm(
         start_iter = end_iter


-def streamk_amd_matmul(a, b, bias=None):
+def _streamk_amd_matmul_impl(a, b, bias=None):
     M, K = a.shape
     _, N = b.shape
     dtype = a.dtype
@@ -391,6 +391,36 @@ def streamk_amd_matmul(a, b, bias=None):
     return c


+class _StreamKAmdMatmul(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, a, b, bias=None):
+        # Save tensors for backward
+        ctx.save_for_backward(a, b, bias)
+        return _streamk_amd_matmul_impl(a, b, bias)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        a, b, bias = ctx.saved_tensors
+        grad_a = grad_b = grad_bias = None
+
+        if ctx.needs_input_grad[0]:
+            grad_a = _streamk_amd_matmul_impl(grad_output, b.t().contiguous())
+
+        if ctx.needs_input_grad[1]:
+            grad_b = _streamk_amd_matmul_impl(a.t().contiguous(), grad_output)
+
+        if ctx.needs_input_grad[2] and bias is not None:
+            grad_bias = grad_output.sum(dim=0)
+            if bias.dim() == 2:
+                grad_bias = grad_bias.unsqueeze(0)
+
+        return grad_a, grad_b, grad_bias
+
+
+def streamk_amd_matmul(a, b, bias=None):
+    return _StreamKAmdMatmul.apply(a, b, bias)
+
+
 def _matmul_launch_metadata(grid, kernel, args):
     ret = {}
     M, N, K = args["M"], args["N"], args["K"]
@@ -601,7 +631,7 @@ def streamk_cuda_gemm(
         c_desc.atomic_add([offs_am, offs_bn], c)


-def streamk_cuda_matmul(a, b):
+def _streamk_cuda_matmul_impl(a, b):
     assert a.dtype == b.dtype, "Incompatible dtypes"

     M, K = a.shape
@@ -649,3 +679,28 @@ def grid(META):
         NUM_SMS=num_sms, #
     )
     return c
+
+
+class _StreamKCudaMatmul(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, a, b):
+        # Save tensors for backward
+        ctx.save_for_backward(a, b)
+        return _streamk_cuda_matmul_impl(a, b)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        a, b = ctx.saved_tensors
+        grad_a = grad_b = None
+
+        if ctx.needs_input_grad[0]:
+            grad_a = _streamk_cuda_matmul_impl(grad_output, b.t().contiguous())
+
+        if ctx.needs_input_grad[1]:
+            grad_b = _streamk_cuda_matmul_impl(a.t().contiguous(), grad_output)
+
+        return grad_a, grad_b
+
+
+def streamk_cuda_matmul(a, b):
+    return _StreamKCudaMatmul.apply(a, b)
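
The bias rule in _StreamKAmdMatmul.backward is the standard one for a row-broadcast bias: d(bias) is the column-wise sum of grad_output, unsqueezed back to (1, N) when the bias is 2-D. A quick plain-PyTorch check of that rule against autograd (shapes here are illustrative):

import torch

m, k, n = 8, 5, 3
a = torch.randn(m, k, dtype=torch.float64, requires_grad=True)
b = torch.randn(k, n, dtype=torch.float64, requires_grad=True)
bias = torch.randn(1, n, dtype=torch.float64, requires_grad=True)   # 2-D bias, broadcast over rows

out = a @ b + bias
grad_output = torch.randn_like(out)
out.backward(grad_output)

manual_grad_bias = grad_output.sum(dim=0).unsqueeze(0)   # the rule from the diff
assert torch.allclose(bias.grad, manual_grad_bias)
assert torch.allclose(a.grad, grad_output @ b.t())
assert torch.allclose(b.grad, a.t() @ grad_output)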

tritonbench/operators/gemm/triton_matmul.py

Lines changed: 27 additions & 1 deletion

@@ -140,7 +140,7 @@ def leaky_relu(x):
 # and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel.


-def matmul(a, b, activation=""):
+def _matmul_impl(a, b, activation=""):
     # Check constraints.
     assert a.shape[1] == b.shape[0], "Incompatible dimensions"
     M, K = a.shape
@@ -176,3 +176,29 @@ def matmul(a, b, activation=""):
         ENABLE_BUFFER_OPS_ASSUMES=enable_buffer_ops_assumes,
     )
     return c
+
+
+class _TritonMatmul(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, a, b, activation=""):
+        # Save tensors for backward
+        ctx.save_for_backward(a, b)
+        ctx.activation = activation
+        return _matmul_impl(a, b, activation)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        a, b = ctx.saved_tensors
+        grad_a = grad_b = None
+
+        if ctx.needs_input_grad[0]:
+            grad_a = _matmul_impl(grad_output, b.t().contiguous(), "")
+
+        if ctx.needs_input_grad[1]:
+            grad_b = _matmul_impl(a.t().contiguous(), grad_output, "")
+
+        return grad_a, grad_b, None
+
+
+def matmul(a, b, activation=""):
+    return _TritonMatmul.apply(a, b, activation)
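
One caveat the diff leaves as-is: forward can fuse an activation such as leaky_relu, but backward calls _matmul_impl with activation="", so the gradients it produces are those of the plain matmul. If a nonlinear activation were exercised in backward mode, the chain rule would also need the activation derivative applied to grad_output, roughly as in this sketch (an illustration, not code from the repo):

import torch


def leaky_relu_matmul_backward(a, b, grad_output, negative_slope=0.01):
    # Chain rule for y = leaky_relu(a @ b): scale grad_output by the activation
    # derivative at the pre-activation values, then form the matmul gradients.
    z = a @ b
    dz = torch.where(z > 0, torch.ones_like(z), torch.full_like(z, negative_slope))
    g = grad_output * dz
    return g @ b.t(), a.t() @ g   # (grad_a, grad_b)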
