Skip to content

Commit 3ade0e2

Browse files
committed
fix TB matmul integration
1 parent 0bf1a16 commit 3ade0e2

File tree

3 files changed

+18
-2
lines changed

3 files changed

+18
-2
lines changed

benchmarks/run.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,8 @@
8585
"gemm": (
8686
"tritonbench.operators.gemm.operator",
8787
[
88-
("examples.matmul", "matmul"),
89-
("examples.matmul_split_k", "matmul_split_k"),
88+
("examples.matmul", "matmul_tritonbench"),
89+
("examples.matmul_split_k", "matmul_split_k_tritonbench"),
9090
],
9191
),
9292
}

examples/matmul.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,14 @@ def check(m: int, k: int, n: int) -> None:
5050
run_example(kernel_with_bias, expected_with_bias, (x, y))
5151

5252

53+
def matmul_tritonbench(a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
54+
"""Wrapper for tritonbench compatibility that handles bias."""
55+
if bias is not None:
56+
return matmul(a, b, lambda acc, tile: acc + bias[tile[1]])
57+
else:
58+
return matmul(a, b)
59+
60+
5361
def main() -> None:
5462
check(1024, 1024, 1024)
5563

examples/matmul_split_k.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,14 @@ def check(m: int, k: int, n: int) -> None:
5858
run_example(kernel_with_bias, expected_with_bias, (x, y), atol=1)
5959

6060

61+
def matmul_split_k_tritonbench(a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
62+
"""Wrapper for tritonbench compatibility that handles bias."""
63+
if bias is not None:
64+
return matmul_split_k(a, b, lambda acc, tile: acc + bias[tile[1]])
65+
else:
66+
return matmul_split_k(a, b)
67+
68+
6169
def main() -> None:
6270
check(64, 32768, 64)
6371

0 commit comments

Comments
 (0)