Commit 9710f88

add bias variants to Helion gemm kernels
1 parent: a0f65f0

File tree: 2 files changed, +75 -4 lines

examples/matmul.py

Lines changed: 37 additions & 2 deletions

@@ -9,7 +9,7 @@
 
 # static_shapes=True gives a performance boost for matmuls
 @helion.kernel(static_shapes=True)
-def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+def matmul_no_bias(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
     m, k = x.size()
     k2, n = y.size()
     assert k == k2, f"size mismatch {k} != {k2}"
@@ -24,10 +24,45 @@ def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
     return out
 
 
+@helion.kernel(static_shapes=True)
+def matmul_with_bias(x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
+    m, k = x.size()
+    k2, n = y.size()
+    assert k == k2, f"size mismatch {k} != {k2}"
+    bias_size = bias.size(0)
+    assert bias_size == n, f"bias size mismatch, expected {n}, got {bias_size}"
+    out = torch.empty(
+        [m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device
+    )
+    for tile_m, tile_n in hl.tile([m, n]):
+        acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
+        for tile_k in hl.tile(k):
+            acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n])
+        # Add bias
+        acc = acc + bias[tile_n]
+        out[tile_m, tile_n] = acc
+    return out
+
+
+def matmul(x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor = None) -> torch.Tensor:
+    """Wrapper function for tritonbench that dispatches based on bias presence."""
+    if bias is None:
+        return matmul_no_bias(x, y)
+    else:
+        return matmul_with_bias(x, y, bias)
+
+
 def check(m: int, k: int, n: int) -> None:
     x = torch.randn([m, k], device="cuda", dtype=torch.float16)
     y = torch.randn([k, n], device="cuda", dtype=torch.float16)
-    run_example(matmul, torch.matmul, (x, y))
+
+    # Test without bias
+    run_example(matmul_no_bias, torch.matmul, (x, y))
+
+    # Test with bias
+    bias = torch.randn([n], device="cuda", dtype=torch.float16)
+    expected_with_bias = lambda x, y, bias: torch.matmul(x, y) + bias
+    run_example(matmul_with_bias, expected_with_bias, (x, y, bias))
 
 
 def main() -> None:
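
Note (illustrative, not part of the diff): the new matmul wrapper keeps a single entry point for tritonbench and dispatches on whether a bias is supplied. A minimal usage sketch, assuming the example module is importable as examples.matmul and a CUDA device is available:

    import torch
    from examples.matmul import matmul  # assumed import path

    x = torch.randn([256, 512], device="cuda", dtype=torch.float16)
    y = torch.randn([512, 128], device="cuda", dtype=torch.float16)
    bias = torch.randn([128], device="cuda", dtype=torch.float16)

    out_plain = matmul(x, y)         # dispatches to matmul_no_bias
    out_biased = matmul(x, y, bias)  # dispatches to matmul_with_bias

    # Compare against the eager reference; fp16 inputs accumulate in fp32
    # inside the kernel, so use a loose tolerance.
    torch.testing.assert_close(
        out_biased, torch.matmul(x, y) + bias, atol=1e-1, rtol=1e-2
    )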

examples/matmul_split_k.py

Lines changed: 38 additions & 2 deletions

@@ -10,7 +10,7 @@
 
 # static_shapes=True gives a performance boost for matmuls
 @helion.kernel(static_shapes=True)
-def matmul_split_k(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+def matmul_split_k_no_bias(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
     m, k = x.size()
     k2, n = y.size()
     assert k == k2, f"size mismatch {k} != {k2}"
@@ -27,10 +27,46 @@ def matmul_split_k(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
     return out
 
 
+@helion.kernel(static_shapes=True)
+def matmul_split_k_with_bias(x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
+    m, k = x.size()
+    k2, n = y.size()
+    assert k == k2, f"size mismatch {k} != {k2}"
+    bias_size = bias.size(0)
+    assert bias_size == n, f"bias size mismatch, expected {n}, got {bias_size}"
+
+    # Initialize output with bias instead of zeros
+    out = bias.expand(m, n).contiguous()
+
+    split_k = hl.register_tunable("split_k", PowerOfTwoFragment(1, 256))
+    k_block = helion.next_power_of_2(helion.cdiv(k, split_k))
+    for tile_m, tile_n, outer_k in hl.tile([m, n, k], block_size=[None, None, k_block]):
+        acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
+        for inner_k in hl.tile(outer_k.begin, outer_k.end):
+            acc = torch.addmm(acc, x[tile_m, inner_k], y[inner_k, tile_n])
+        hl.atomic_add(out, [tile_m, tile_n], acc)
+    return out
+
+
+def matmul_split_k(x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor = None) -> torch.Tensor:
+    """Wrapper function for tritonbench that dispatches based on bias presence."""
+    if bias is None:
+        return matmul_split_k_no_bias(x, y)
+    else:
+        return matmul_split_k_with_bias(x, y, bias)
+
+
 def check(m: int, k: int, n: int) -> None:
     x = torch.randn([m, k], device="cuda", dtype=torch.float16)
     y = torch.randn([k, n], device="cuda", dtype=torch.float16)
-    run_example(matmul_split_k, torch.matmul, (x, y), atol=1)
+
+    # Test without bias
+    run_example(matmul_split_k_no_bias, torch.matmul, (x, y), atol=1)
+
+    # Test with bias
+    bias = torch.randn([n], device="cuda", dtype=torch.float16)
+    expected_with_bias = lambda x, y, bias: torch.matmul(x, y) + bias
+    run_example(matmul_split_k_with_bias, expected_with_bias, (x, y, bias), atol=1)
 
 
 def main() -> None:
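
Note (illustrative, not part of the diff): the split-k variant cannot add the bias after its inner loop the way matmul_with_bias does, because several k-slices atomically accumulate into the same output tile; seeding out with the bias folds it in exactly once. A plain-PyTorch sketch of that strategy (a reference for the idea, not the Helion kernel itself, with a fixed split factor assumed):

    import torch

    def split_k_bias_reference(x: torch.Tensor, y: torch.Tensor,
                               bias: torch.Tensor, split_k: int = 4) -> torch.Tensor:
        m, k = x.shape
        _, n = y.shape
        # Start from the bias so it is added exactly once per output element.
        out = bias.to(torch.float32).expand(m, n).contiguous()
        k_block = -(-k // split_k)  # ceil division: size of each k slice
        for start in range(0, k, k_block):
            end = min(start + k_block, k)
            # Each slice contributes a partial product, mirroring the kernel's
            # atomic adds into the shared output tile.
            out += x[:, start:end].float() @ y[start:end, :].float()
        return out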
