Commit 90e2d69

[Examples] Add matmul variants with bias support and tests
- Add wrapper functions for tritonbench dispatch in matmul.py and matmul_split_k.py
- Implement bias handling in both matmul and matmul_split_k
- Add comprehensive tests in test_examples.py for all matmul variants

stack-info: PR: #379, branch: yf225/stack/41
1 parent 9064eda commit 90e2d69
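
For orientation before the diffs: the bias support below is carried by a new optional `epilogue` callback on both kernels. A minimal usage sketch follows; the shapes and the `bias` tensor are illustrative, and running it assumes the helion package and a CUDA device, neither of which this page guarantees:

    import torch

    from examples.matmul import matmul  # module edited in this commit

    x = torch.randn(128, 256, device="cuda", dtype=torch.float16)
    y = torch.randn(256, 64, device="cuda", dtype=torch.float16)
    bias = torch.randn(64, device="cuda", dtype=torch.float16)

    # The epilogue receives the fp32 accumulator for one output tile plus its
    # [tile_m, tile_n] indices; indexing bias by tile[1] broadcasts one bias
    # value per output column, matching the tests added in this commit.
    out = matmul(x, y, lambda acc, tile: acc + bias[tile[1]])
    torch.testing.assert_close(out, (x @ y + bias).to(out.dtype), rtol=1e-2, atol=1e-2)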

5 files changed, +106 -52 lines

examples/matmul.py

Lines changed: 23 additions & 3 deletions
@@ -1,15 +1,25 @@
 from __future__ import annotations
 
+from typing import TYPE_CHECKING
+
 import torch
 
 import helion
 from helion._testing import run_example
 import helion.language as hl
 
+if TYPE_CHECKING:
+    from collections.abc import Callable
+
 
 # static_shapes=True gives a performance boost for matmuls
 @helion.kernel(static_shapes=True)
-def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+def matmul(
+    x: torch.Tensor,
+    y: torch.Tensor,
+    epilogue: Callable[[torch.Tensor, list[torch.Tensor]], torch.Tensor] = lambda acc,
+    tile: acc,
+) -> torch.Tensor:
     m, k = x.size()
     k2, n = y.size()
     assert k == k2, f"size mismatch {k} != {k2}"
@@ -20,14 +30,24 @@ def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
         for tile_k in hl.tile(k):
             acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n])
-        out[tile_m, tile_n] = acc
+        out[tile_m, tile_n] = epilogue(acc, [tile_m, tile_n])
     return out
 
 
 def check(m: int, k: int, n: int) -> None:
     x = torch.randn([m, k], device="cuda", dtype=torch.float16)
     y = torch.randn([k, n], device="cuda", dtype=torch.float16)
-    run_example(matmul, torch.matmul, (x, y))
+
+    # Test without bias
+    kernel_no_bias = lambda x, y: matmul(x, y)  # noqa: E731
+    expected_no_bias = lambda x, y: torch.matmul(x, y)  # noqa: E731
+    run_example(kernel_no_bias, expected_no_bias, (x, y))
+
+    # Test with bias
+    bias = torch.randn([n], device="cuda", dtype=torch.float16)
+    kernel_with_bias = lambda x, y: matmul(x, y, lambda acc, tile: acc + bias[tile[1]])  # noqa: E731
+    expected_with_bias = lambda x, y: torch.matmul(x, y) + bias  # noqa: E731
+    run_example(kernel_with_bias, expected_with_bias, (x, y))
 
 
 def main() -> None:
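
The contract of the new `epilogue` argument, restated as an eager PyTorch sketch; treating the whole output as a single tile is an illustration here, not how the kernel actually tiles:

    import torch

    # epilogue(acc, tile) gets the accumulator for one output tile and its
    # [tile_m, tile_n] indices, and returns the values to store.  With the
    # default identity epilogue this reduces to a plain matmul.
    def matmul_reference(x, y, epilogue=lambda acc, tile: acc):
        acc = x @ y
        tile = [slice(None), slice(None)]  # one "tile" covering the whole output
        return epilogue(acc, tile)

    x, y = torch.randn(64, 32), torch.randn(32, 16)
    bias = torch.randn(16)

    # Same call shape as the diff's bias test: index bias by the column tile.
    out = matmul_reference(x, y, lambda acc, tile: acc + bias[tile[1]])
    torch.testing.assert_close(out, x @ y + bias)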

examples/matmul_split_k.py

Lines changed: 27 additions & 2 deletions
@@ -1,16 +1,26 @@
 from __future__ import annotations
 
+from typing import TYPE_CHECKING
+
 import torch
 
 import helion
 from helion._testing import run_example
 from helion.autotuner import PowerOfTwoFragment
 import helion.language as hl
 
+if TYPE_CHECKING:
+    from collections.abc import Callable
+
 
 # static_shapes=True gives a performance boost for matmuls
 @helion.kernel(static_shapes=True)
-def matmul_split_k(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+def matmul_split_k(
+    x: torch.Tensor,
+    y: torch.Tensor,
+    epilogue: Callable[[torch.Tensor, list[torch.Tensor]], torch.Tensor] = lambda acc,
+    tile: acc,
+) -> torch.Tensor:
     m, k = x.size()
     k2, n = y.size()
     assert k == k2, f"size mismatch {k} != {k2}"
@@ -23,14 +33,29 @@ def matmul_split_k(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
         for inner_k in hl.tile(outer_k.begin, outer_k.end):
             acc = torch.addmm(acc, x[tile_m, inner_k], y[inner_k, tile_n])
+        # Apply epilogue only on the first k-split iteration
+        if outer_k.begin == 0:
+            acc = epilogue(acc, [tile_m, tile_n])
         hl.atomic_add(out, [tile_m, tile_n], acc)
     return out
 
 
 def check(m: int, k: int, n: int) -> None:
     x = torch.randn([m, k], device="cuda", dtype=torch.float16)
     y = torch.randn([k, n], device="cuda", dtype=torch.float16)
-    run_example(matmul_split_k, torch.matmul, (x, y), atol=1)
+
+    # Test without bias
+    kernel_no_bias = lambda x, y: matmul_split_k(x, y)  # noqa: E731
+    expected_no_bias = lambda x, y: torch.matmul(x, y)  # noqa: E731
+    run_example(kernel_no_bias, expected_no_bias, (x, y), atol=1)
+
+    # Test with bias using closure approach
+    bias = torch.randn([n], device="cuda", dtype=torch.float16)
+    kernel_with_bias = lambda x, y: matmul_split_k(  # noqa: E731
+        x, y, epilogue=lambda acc, tile: acc + bias[tile[1]]
+    )
+    expected_with_bias = lambda x, y: torch.matmul(x, y) + bias  # noqa: E731
+    run_example(kernel_with_bias, expected_with_bias, (x, y), atol=1)
 
 
 def main() -> None:
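
Why the guard on `outer_k.begin == 0` matters: in split-k, every k-slice adds a partial product into `out` via `hl.atomic_add`, so an additive epilogue such as a bias must be contributed by exactly one slice, or it would be applied once per split. A self-contained sketch of that invariant; the split count and shapes are made up for illustration:

    import torch

    def split_k_reference(x, y, bias, num_splits=4):
        m, k = x.shape
        out = torch.zeros(m, y.shape[1])
        for s in range(num_splits):
            lo, hi = s * k // num_splits, (s + 1) * k // num_splits
            partial = x[:, lo:hi] @ y[lo:hi, :]
            if lo == 0:  # only the first split owns the epilogue
                partial = partial + bias
            out += partial  # stands in for hl.atomic_add
        return out

    x, y, bias = torch.randn(8, 32), torch.randn(32, 4), torch.randn(4)
    torch.testing.assert_close(split_k_reference(x, y, bias), x @ y + bias)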

test/test_examples.expected

Lines changed: 6 additions & 2 deletions
@@ -969,7 +969,7 @@ def _matmul_kernel(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.con
         acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32')
     tl.store(out + (indices_0[:, None] * 128 + indices_1[None, :] * 1), acc, None)
 
-def matmul(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
+def matmul(x: torch.Tensor, y: torch.Tensor, epilogue: Callable[[torch.Tensor, list[torch.Tensor]], torch.Tensor]=lambda acc, tile: acc, *, _launcher=_default_launcher):
     m, k = x.size()
     k2, n = y.size()
     assert k == k2, f'size mismatch {k} != {k2}'
@@ -1136,9 +1136,13 @@ def _matmul_split_k_kernel(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1
         load = tl.load(x + (indices_0[:, None] * 1024 + indices_3[None, :] * 1), mask_3[None, :], other=0)
         load_1 = tl.load(y + (indices_3[:, None] * 64 + indices_1[None, :] * 1), mask_3[:, None], other=0)
         acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32')
+    eq = offset_2 == 0
+    if eq:
+        acc_copy_1 = acc
+        acc = acc_copy_1
     tl.atomic_add(out + (indices_0[:, None] * 64 + indices_1[None, :] * 1), acc, mask=None, sem='relaxed')
 
-def matmul_split_k(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
+def matmul_split_k(x: torch.Tensor, y: torch.Tensor, epilogue: Callable[[torch.Tensor, list[torch.Tensor]], torch.Tensor]=lambda acc, tile: acc, *, _launcher=_default_launcher):
     m, k = x.size()
     k2, n = y.size()
     assert k == k2, f'size mismatch {k} != {k2}'

test/test_matmul.expected

Lines changed: 3 additions & 3 deletions
@@ -75,7 +75,7 @@ def _matmul_kernel(x, y, out, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_0: tl.con
         acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32')
     tl.store(out + (indices_0[:, None] * 128 + indices_1[None, :] * 1), acc, None)
 
-def matmul(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
+def matmul(x: torch.Tensor, y: torch.Tensor, epilogue: Callable[[torch.Tensor, list[torch.Tensor]], torch.Tensor]=lambda acc, tile: acc, *, _launcher=_default_launcher):
     m, k = x.size()
     k2, n = y.size()
     assert k == k2, f'size mismatch {k} != {k2}'
@@ -162,7 +162,7 @@ def _matmul_kernel(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.con
         acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32')
     tl.store(tl.make_block_ptr(out, [128, 128], [128, 1], [offset_0, offset_1], [_BLOCK_SIZE_0, _BLOCK_SIZE_1], [1, 0]), acc, boundary_check=[0, 1])
 
-def matmul(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
+def matmul(x: torch.Tensor, y: torch.Tensor, epilogue: Callable[[torch.Tensor, list[torch.Tensor]], torch.Tensor]=lambda acc, tile: acc, *, _launcher=_default_launcher):
     m, k = x.size()
     k2, n = y.size()
     assert k == k2, f'size mismatch {k} != {k2}'
@@ -435,7 +435,7 @@ def _matmul_kernel(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.con
         acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32')
     out_desc.store([offset_0, offset_1], acc)
 
-def matmul(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
+def matmul(x: torch.Tensor, y: torch.Tensor, epilogue: Callable[[torch.Tensor, list[torch.Tensor]], torch.Tensor]=lambda acc, tile: acc, *, _launcher=_default_launcher):
     m, k = x.size()
     k2, n = y.size()
     assert k == k2, f'size mismatch {k} != {k2}'
