Commit 8fd5a4b

[Examples] Add matmul variants with bias support and tests (#379)
- Add wrapper functions for tritonbench dispatch in matmul.py and matmul_split_k.py
- Implement bias handling in both matmul and matmul_split_k
- Add comprehensive tests in test_examples.py for all matmul variants
1 parent d66e5c3 commit 8fd5a4b

File tree

7 files changed: +154, −143 lines


examples/matmul.py

Lines changed: 56 additions & 4 deletions
@@ -1,15 +1,27 @@
 from __future__ import annotations
 
+from typing import TYPE_CHECKING
+
 import torch
+from torch import Tensor
 
 import helion
 from helion._testing import run_example
 import helion.language as hl
 
+if TYPE_CHECKING:
+    from collections.abc import Callable
+
 
-# static_shapes=True gives a performance boost for matmuls
-@helion.kernel(static_shapes=True)
-def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+@helion.kernel(
+    # static_shapes=True gives a performance boost for matmuls
+    static_shapes=True,
+)
+def matmul(
+    x: Tensor,
+    y: Tensor,
+    epilogue: Callable[[Tensor, list[Tensor]], Tensor] = lambda acc, tile: acc,
+) -> Tensor:
     m, k = x.size()
     k2, n = y.size()
     assert k == k2, f"size mismatch {k} != {k2}"
@@ -20,17 +32,57 @@ def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
         for tile_k in hl.tile(k):
             acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n])
-        out[tile_m, tile_n] = acc
+        out[tile_m, tile_n] = epilogue(acc, [tile_m, tile_n])
     return out
 
 
+def autotune(m: int, k: int, n: int) -> None:
+    x = torch.randn([m, k], device="cuda", dtype=torch.float16)
+    y = torch.randn([k, n], device="cuda", dtype=torch.float16)
+    bias = torch.randn([n], device="cuda", dtype=torch.float16)
+    args = (x, y, lambda acc, tile: torch.relu(acc + bias[tile[1]]))
+    best_config = matmul.autotune(args, force=True)
+    print(f"Best config: {best_config}")
+    best_config.save("best_config.json")
+
+
 def check(m: int, k: int, n: int) -> None:
     x = torch.randn([m, k], device="cuda", dtype=torch.float16)
     y = torch.randn([k, n], device="cuda", dtype=torch.float16)
+    bias = torch.randn([n], device="cuda", dtype=torch.float16)
+
+    # Test without bias
     run_example(matmul, torch.matmul, (x, y))
 
+    # Test with bias
+    def helion_linear(x: Tensor, y: Tensor, bias: Tensor) -> Tensor:
+        return matmul(x, y, lambda acc, tile: acc + bias[tile[1]])
+
+    def baseline_linear(x: Tensor, y: Tensor, bias: Tensor) -> Tensor:
+        return torch.nn.functional.linear(x, y.T, bias)
+
+    run_example(helion_linear, baseline_linear, (x, y, bias))
+
+    # Test more complex epilogue
+    def epilogue(acc: Tensor, tile: list[Tensor]) -> Tensor:
+        # The epilogue can use the captured bias tensor that is implicitly lifted to a kernel arg
+        return torch.relu(acc + bias[tile[1]])
+
+    def kernel_wrapper(x: Tensor, y: Tensor) -> Tensor:
+        return matmul(x, y, epilogue)
+
+    def baseline_wrapper(x: Tensor, y: Tensor) -> Tensor:
+        return torch.relu(x @ y + bias)
+
+    run_example(
+        kernel_wrapper,
+        baseline_wrapper,
+        (x, y),
+    )
+
 
 def main() -> None:
+    # autotune(1024, 1024, 1024)
     check(1024, 1024, 1024)
 
examples/matmul_split_k.py

Lines changed: 27 additions & 2 deletions
@@ -1,16 +1,26 @@
 from __future__ import annotations
 
+from typing import TYPE_CHECKING
+
 import torch
 
 import helion
 from helion._testing import run_example
 from helion.autotuner import PowerOfTwoFragment
 import helion.language as hl
 
+if TYPE_CHECKING:
+    from collections.abc import Callable
+
 
 # static_shapes=True gives a performance boost for matmuls
 @helion.kernel(static_shapes=True)
-def matmul_split_k(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+def matmul_split_k(
+    x: torch.Tensor,
+    y: torch.Tensor,
+    epilogue: Callable[[torch.Tensor, list[torch.Tensor]], torch.Tensor] = lambda acc,
+    tile: acc,
+) -> torch.Tensor:
     m, k = x.size()
     k2, n = y.size()
     assert k == k2, f"size mismatch {k} != {k2}"
@@ -23,14 +33,29 @@ def matmul_split_k(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
         for inner_k in hl.tile(outer_k.begin, outer_k.end):
             acc = torch.addmm(acc, x[tile_m, inner_k], y[inner_k, tile_n])
+        # Apply epilogue only on the first k-split iteration
+        if outer_k.begin == 0:
+            acc = epilogue(acc, [tile_m, tile_n])
         hl.atomic_add(out, [tile_m, tile_n], acc)
     return out
 
 
 def check(m: int, k: int, n: int) -> None:
     x = torch.randn([m, k], device="cuda", dtype=torch.float16)
     y = torch.randn([k, n], device="cuda", dtype=torch.float16)
-    run_example(matmul_split_k, torch.matmul, (x, y), atol=1)
+
+    # Test without bias
+    kernel_no_bias = lambda x, y: matmul_split_k(x, y)  # noqa: E731
+    expected_no_bias = lambda x, y: torch.matmul(x, y)  # noqa: E731
+    run_example(kernel_no_bias, expected_no_bias, (x, y), atol=1)
+
+    # Test with bias using closure approach
+    bias = torch.randn([n], device="cuda", dtype=torch.float16)
+    kernel_with_bias = lambda x, y: matmul_split_k(  # noqa: E731
+        x, y, epilogue=lambda acc, tile: acc + bias[tile[1]]
+    )
+    expected_with_bias = lambda x, y: torch.nn.functional.linear(x, y.T, bias)  # noqa: E731
+    run_example(kernel_with_bias, expected_with_bias, (x, y), atol=1)
 
 
 def main() -> None:
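The first-split guard is what keeps the bias correct under split-k: every split atomically adds its partial product into out, so an additive epilogue must run exactly once or the bias would be accumulated once per split. A plain-PyTorch emulation of the scheme (shapes and split size chosen arbitrarily for illustration):

    import torch

    x, y = torch.randn(64, 128), torch.randn(128, 32)
    bias = torch.randn(32)

    out = torch.zeros(64, 32)
    for i, (xs, ys) in enumerate(zip(x.split(32, dim=1), y.split(32, dim=0))):
        acc = xs @ ys          # partial product for this k-split
        if i == 0:             # epilogue on the first split only
            acc = acc + bias
        out += acc             # stands in for hl.atomic_add

    torch.testing.assert_close(out, x @ y + bias, rtol=1e-4, atol=1e-5)

Note that this trick only works for epilogues that commute with the later accumulation, such as a bias add; a nonlinearity like relu would have to wait until all splits have landed, which is presumably why check() exercises only the bias case here.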

examples/template_via_closure.py

Lines changed: 0 additions & 75 deletions
This file was deleted. Its matmul_with_epilogue kernel is superseded by the epilogue parameter on matmul in examples/matmul.py; the template_via_closure tests below now target that example instead.

test/test_examples.expected

Lines changed: 15 additions & 11 deletions
@@ -964,7 +964,7 @@ def _matmul_kernel(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.con
         acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32')
     tl.store(out + (indices_0[:, None] * 128 + indices_1[None, :] * 1), acc, None)
 
-def matmul(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
+def matmul(x: Tensor, y: Tensor, epilogue: Callable[[Tensor, list[Tensor]], Tensor]=lambda acc, tile: acc, *, _launcher=_default_launcher):
     m, k = x.size()
     k2, n = y.size()
     assert k == k2, f'size mismatch {k} != {k2}'
@@ -1131,9 +1131,13 @@ def _matmul_split_k_kernel(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1
         load = tl.load(x + (indices_0[:, None] * 1024 + indices_3[None, :] * 1), mask_3[None, :], other=0)
         load_1 = tl.load(y + (indices_3[:, None] * 64 + indices_1[None, :] * 1), mask_3[:, None], other=0)
         acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32')
+    eq = offset_2 == 0
+    if eq:
+        acc_copy_1 = acc
+        acc = acc_copy_1
     tl.atomic_add(out + (indices_0[:, None] * 64 + indices_1[None, :] * 1), acc, mask=None, sem='relaxed')
 
-def matmul_split_k(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
+def matmul_split_k(x: torch.Tensor, y: torch.Tensor, epilogue: Callable[[torch.Tensor, list[torch.Tensor]], torch.Tensor]=lambda acc, tile: acc, *, _launcher=_default_launcher):
     m, k = x.size()
     k2, n = y.size()
     assert k == k2, f'size mismatch {k} != {k2}'
@@ -1610,7 +1614,7 @@ from helion.runtime import default_launcher as _default_launcher
 import test.test_examples as _global_source0
 
 @triton.jit
-def _matmul_with_epilogue_kernel(x, y, epilogue_closure_0, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
+def _matmul_kernel(x, y, epilogue_closure_0, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
     num_pid_m = tl.cdiv(1024, _BLOCK_SIZE_0)
     num_pid_n = tl.cdiv(1024, _BLOCK_SIZE_1)
     inner_2d_pid = tl.program_id(0)
@@ -1640,15 +1644,15 @@ def _matmul_with_epilogue_kernel(x, y, epilogue_closure_0, out, _BLOCK_SIZE_0: t
     v_4 = v_3.to(tl.float16)
     tl.store(out + (indices_0[:, None] * 1024 + indices_1[None, :] * 1), v_4, None)
 
-def matmul_with_epilogue(x: Tensor, y: Tensor, epilogue: Callable[[Tensor, list[Tensor]], Tensor], *, _launcher=_default_launcher):
+def matmul(x: Tensor, y: Tensor, epilogue: Callable[[Tensor, list[Tensor]], Tensor]=lambda acc, tile: acc, *, _launcher=_default_launcher):
     m, k = x.size()
     k2, n = y.size()
     assert k == k2, f'size mismatch {k} != {k2}'
     out = torch.empty([m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device)
     _BLOCK_SIZE_0 = 64
     _BLOCK_SIZE_1 = 64
     _BLOCK_SIZE_2 = 16
-    _launcher(_matmul_with_epilogue_kernel, (triton.cdiv(1024, _BLOCK_SIZE_0) * triton.cdiv(1024, _BLOCK_SIZE_1),), x, y, epilogue.__closure__[0].cell_contents, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=2, num_stages=4)
+    _launcher(_matmul_kernel, (triton.cdiv(1024, _BLOCK_SIZE_0) * triton.cdiv(1024, _BLOCK_SIZE_1),), x, y, epilogue.__closure__[0].cell_contents, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=2, num_stages=4)
     return out
 
 --- assertExpectedJournal(TestExamples.test_template_via_closure1)
@@ -1663,7 +1667,7 @@ from helion.runtime import default_launcher as _default_launcher
 import test.test_examples as _global_source0
 
 @triton.jit
-def _matmul_with_epilogue_kernel(x, y, epilogue_closure_0, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
+def _matmul_kernel(x, y, epilogue_closure_0, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
     num_pid_m = tl.cdiv(1024, _BLOCK_SIZE_0)
     num_pid_n = tl.cdiv(1024, _BLOCK_SIZE_1)
     inner_2d_pid = tl.program_id(0)
@@ -1690,15 +1694,15 @@ def _matmul_with_epilogue_kernel(x, y, epilogue_closure_0, out, _BLOCK_SIZE_0: t
     v_4 = v_3.to(tl.float16)
     tl.store(tl.make_block_ptr(out, [1024, 1024], [1024, 1], [offset_0, offset_1], [_BLOCK_SIZE_0, _BLOCK_SIZE_1], [1, 0]), v_4, boundary_check=[0, 1])
 
-def matmul_with_epilogue(x: Tensor, y: Tensor, epilogue: Callable[[Tensor, list[Tensor]], Tensor], *, _launcher=_default_launcher):
+def matmul(x: Tensor, y: Tensor, epilogue: Callable[[Tensor, list[Tensor]], Tensor]=lambda acc, tile: acc, *, _launcher=_default_launcher):
     m, k = x.size()
     k2, n = y.size()
     assert k == k2, f'size mismatch {k} != {k2}'
     out = torch.empty([m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device)
     _BLOCK_SIZE_0 = 64
     _BLOCK_SIZE_1 = 64
     _BLOCK_SIZE_2 = 16
-    _launcher(_matmul_with_epilogue_kernel, (triton.cdiv(1024, _BLOCK_SIZE_0) * triton.cdiv(1024, _BLOCK_SIZE_1),), x, y, epilogue.__closure__[0].cell_contents, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=2, num_stages=4)
+    _launcher(_matmul_kernel, (triton.cdiv(1024, _BLOCK_SIZE_0) * triton.cdiv(1024, _BLOCK_SIZE_1),), x, y, epilogue.__closure__[0].cell_contents, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=2, num_stages=4)
     return out
 
 --- assertExpectedJournal(TestExamples.test_template_via_closure2)
@@ -1713,7 +1717,7 @@ from helion.runtime import default_launcher as _default_launcher
 import test.test_examples as _global_source0
 
 @triton.jit
-def _matmul_with_epilogue_kernel(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
+def _matmul_kernel(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
     num_pid_m = tl.cdiv(1024, _BLOCK_SIZE_0)
     num_pid_n = tl.cdiv(1024, _BLOCK_SIZE_1)
     inner_2d_pid = tl.program_id(0)
@@ -1737,13 +1741,13 @@ def _matmul_with_epilogue_kernel(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_
     v_2 = v_1.to(tl.float16)
     tl.store(tl.make_block_ptr(out, [1024, 1024], [1024, 1], [offset_0, offset_1], [_BLOCK_SIZE_0, _BLOCK_SIZE_1], [1, 0]), v_2, boundary_check=[0, 1])
 
-def matmul_with_epilogue(x: Tensor, y: Tensor, epilogue: Callable[[Tensor, list[Tensor]], Tensor], *, _launcher=_default_launcher):
+def matmul(x: Tensor, y: Tensor, epilogue: Callable[[Tensor, list[Tensor]], Tensor]=lambda acc, tile: acc, *, _launcher=_default_launcher):
     m, k = x.size()
     k2, n = y.size()
     assert k == k2, f'size mismatch {k} != {k2}'
     out = torch.empty([m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device)
     _BLOCK_SIZE_0 = 64
     _BLOCK_SIZE_1 = 64
     _BLOCK_SIZE_2 = 16
-    _launcher(_matmul_with_epilogue_kernel, (triton.cdiv(1024, _BLOCK_SIZE_0) * triton.cdiv(1024, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=2, num_stages=4)
+    _launcher(_matmul_kernel, (triton.cdiv(1024, _BLOCK_SIZE_0) * triton.cdiv(1024, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=2, num_stages=4)
     return out
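The epilogue.__closure__[0].cell_contents expression in the journals above is ordinary Python closure introspection: the bias captured by the epilogue lambda lives in a closure cell, which the generated launcher reads and forwards to the kernel as epilogue_closure_0. A tiny sketch of that mechanic (make_epilogue is a hypothetical helper, not part of this commit):

    import torch

    def make_epilogue(bias: torch.Tensor):
        # bias becomes a free variable of the lambda, stored in a closure cell
        return lambda acc, tile: acc + bias[tile[1]]

    bias = torch.randn(4)
    epilogue = make_epilogue(bias)

    # This is exactly what the generated launcher reads:
    assert epilogue.__closure__[0].cell_contents is bias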

test/test_examples.py

Lines changed: 6 additions & 6 deletions
@@ -143,10 +143,10 @@ def test_template_via_closure0(self):
         )
         self.assertExpectedJournal(
             check_example(
-                "template_via_closure",
+                "matmul",
                 args,
                 torch.relu(args[0] @ args[1] + bias),
-                fn_name="matmul_with_epilogue",
+                fn_name="matmul",
                 block_sizes=[64, 64, 16],
                 loop_orders=[[0, 1]],
                 num_warps=2,
@@ -165,10 +165,10 @@ def test_template_via_closure1(self):
         )
         self.assertExpectedJournal(
             check_example(
-                "template_via_closure",
+                "matmul",
                 args,
                 torch.relu(args[0] @ args[1] + bias),
-                fn_name="matmul_with_epilogue",
+                fn_name="matmul",
                 block_sizes=[64, 64, 16],
                 loop_orders=[[0, 1]],
                 num_warps=2,
@@ -186,10 +186,10 @@ def test_template_via_closure2(self):
         )
         self.assertExpectedJournal(
             check_example(
-                "template_via_closure",
+                "matmul",
                 args,
                 torch.relu(args[0] @ args[1]),
-                fn_name="matmul_with_epilogue",
+                fn_name="matmul",
                 block_sizes=[64, 64, 16],
                 loop_orders=[[0, 1]],
                 num_warps=2,

test/test_matmul.expected

Lines changed: 3 additions & 3 deletions
@@ -75,7 +75,7 @@ def _matmul_kernel(x, y, out, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_0: tl.con
         acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32')
     tl.store(out + (indices_0[:, None] * 128 + indices_1[None, :] * 1), acc, None)
 
-def matmul(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
+def matmul(x: Tensor, y: Tensor, epilogue: Callable[[Tensor, list[Tensor]], Tensor]=lambda acc, tile: acc, *, _launcher=_default_launcher):
     m, k = x.size()
     k2, n = y.size()
     assert k == k2, f'size mismatch {k} != {k2}'
@@ -162,7 +162,7 @@ def _matmul_kernel(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.con
         acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32')
     tl.store(tl.make_block_ptr(out, [128, 128], [128, 1], [offset_0, offset_1], [_BLOCK_SIZE_0, _BLOCK_SIZE_1], [1, 0]), acc, boundary_check=[0, 1])
 
-def matmul(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
+def matmul(x: Tensor, y: Tensor, epilogue: Callable[[Tensor, list[Tensor]], Tensor]=lambda acc, tile: acc, *, _launcher=_default_launcher):
     m, k = x.size()
     k2, n = y.size()
     assert k == k2, f'size mismatch {k} != {k2}'
@@ -435,7 +435,7 @@ def _matmul_kernel(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.con
         acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32')
     out_desc.store([offset_0, offset_1], acc)
 
-def matmul(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
+def matmul(x: Tensor, y: Tensor, epilogue: Callable[[Tensor, list[Tensor]], Tensor]=lambda acc, tile: acc, *, _launcher=_default_launcher):
     m, k = x.size()
     k2, n = y.size()
     assert k == k2, f'size mismatch {k} != {k2}'
