@@ -964,7 +964,7 @@ def _matmul_kernel(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.con
964
964
acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32')
965
965
tl.store(out + (indices_0[:, None] * 128 + indices_1[None, :] * 1), acc, None)
966
966
967
- def matmul(x: torch. Tensor, y: torch. Tensor, *, _launcher=_default_launcher):
967
+ def matmul(x: Tensor, y: Tensor, epilogue: Callable[[Tensor, list[Tensor]], Tensor]=lambda acc, tile: acc , *, _launcher=_default_launcher):
968
968
m, k = x.size()
969
969
k2, n = y.size()
970
970
assert k == k2, f'size mismatch {k} != {k2}'
@@ -1131,9 +1131,13 @@ def _matmul_split_k_kernel(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1
1131
1131
load = tl.load(x + (indices_0[:, None] * 1024 + indices_3[None, :] * 1), mask_3[None, :], other=0)
1132
1132
load_1 = tl.load(y + (indices_3[:, None] * 64 + indices_1[None, :] * 1), mask_3[:, None], other=0)
1133
1133
acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32')
1134
+ eq = offset_2 == 0
1135
+ if eq:
1136
+ acc_copy_1 = acc
1137
+ acc = acc_copy_1
1134
1138
tl.atomic_add(out + (indices_0[:, None] * 64 + indices_1[None, :] * 1), acc, mask=None, sem='relaxed')
1135
1139
1136
- def matmul_split_k(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
1140
+ def matmul_split_k(x: torch.Tensor, y: torch.Tensor, epilogue: Callable[[torch.Tensor, list[torch.Tensor]], torch.Tensor]=lambda acc, tile: acc, *, _launcher=_default_launcher):
1137
1141
m, k = x.size()
1138
1142
k2, n = y.size()
1139
1143
assert k == k2, f'size mismatch {k} != {k2}'
@@ -1610,7 +1614,7 @@ from helion.runtime import default_launcher as _default_launcher
1610
1614
import test.test_examples as _global_source0
1611
1615
1612
1616
@triton.jit
1613
- def _matmul_with_epilogue_kernel (x, y, epilogue_closure_0, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
1617
+ def _matmul_kernel (x, y, epilogue_closure_0, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
1614
1618
num_pid_m = tl.cdiv(1024, _BLOCK_SIZE_0)
1615
1619
num_pid_n = tl.cdiv(1024, _BLOCK_SIZE_1)
1616
1620
inner_2d_pid = tl.program_id(0)
@@ -1640,15 +1644,15 @@ def _matmul_with_epilogue_kernel(x, y, epilogue_closure_0, out, _BLOCK_SIZE_0: t
1640
1644
v_4 = v_3.to(tl.float16)
1641
1645
tl.store(out + (indices_0[:, None] * 1024 + indices_1[None, :] * 1), v_4, None)
1642
1646
1643
- def matmul_with_epilogue (x: Tensor, y: Tensor, epilogue: Callable[[Tensor, list[Tensor]], Tensor], *, _launcher=_default_launcher):
1647
+ def matmul (x: Tensor, y: Tensor, epilogue: Callable[[Tensor, list[Tensor]], Tensor]=lambda acc, tile: acc , *, _launcher=_default_launcher):
1644
1648
m, k = x.size()
1645
1649
k2, n = y.size()
1646
1650
assert k == k2, f'size mismatch {k} != {k2}'
1647
1651
out = torch.empty([m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device)
1648
1652
_BLOCK_SIZE_0 = 64
1649
1653
_BLOCK_SIZE_1 = 64
1650
1654
_BLOCK_SIZE_2 = 16
1651
- _launcher(_matmul_with_epilogue_kernel , (triton.cdiv(1024, _BLOCK_SIZE_0) * triton.cdiv(1024, _BLOCK_SIZE_1),), x, y, epilogue.__closure__[0].cell_contents, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=2, num_stages=4)
1655
+ _launcher(_matmul_kernel , (triton.cdiv(1024, _BLOCK_SIZE_0) * triton.cdiv(1024, _BLOCK_SIZE_1),), x, y, epilogue.__closure__[0].cell_contents, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=2, num_stages=4)
1652
1656
return out
1653
1657
1654
1658
--- assertExpectedJournal(TestExamples.test_template_via_closure1)
@@ -1663,7 +1667,7 @@ from helion.runtime import default_launcher as _default_launcher
1663
1667
import test.test_examples as _global_source0
1664
1668
1665
1669
@triton.jit
1666
- def _matmul_with_epilogue_kernel (x, y, epilogue_closure_0, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
1670
+ def _matmul_kernel (x, y, epilogue_closure_0, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
1667
1671
num_pid_m = tl.cdiv(1024, _BLOCK_SIZE_0)
1668
1672
num_pid_n = tl.cdiv(1024, _BLOCK_SIZE_1)
1669
1673
inner_2d_pid = tl.program_id(0)
@@ -1690,15 +1694,15 @@ def _matmul_with_epilogue_kernel(x, y, epilogue_closure_0, out, _BLOCK_SIZE_0: t
1690
1694
v_4 = v_3.to(tl.float16)
1691
1695
tl.store(tl.make_block_ptr(out, [1024, 1024], [1024, 1], [offset_0, offset_1], [_BLOCK_SIZE_0, _BLOCK_SIZE_1], [1, 0]), v_4, boundary_check=[0, 1])
1692
1696
1693
- def matmul_with_epilogue (x: Tensor, y: Tensor, epilogue: Callable[[Tensor, list[Tensor]], Tensor], *, _launcher=_default_launcher):
1697
+ def matmul (x: Tensor, y: Tensor, epilogue: Callable[[Tensor, list[Tensor]], Tensor]=lambda acc, tile: acc , *, _launcher=_default_launcher):
1694
1698
m, k = x.size()
1695
1699
k2, n = y.size()
1696
1700
assert k == k2, f'size mismatch {k} != {k2}'
1697
1701
out = torch.empty([m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device)
1698
1702
_BLOCK_SIZE_0 = 64
1699
1703
_BLOCK_SIZE_1 = 64
1700
1704
_BLOCK_SIZE_2 = 16
1701
- _launcher(_matmul_with_epilogue_kernel , (triton.cdiv(1024, _BLOCK_SIZE_0) * triton.cdiv(1024, _BLOCK_SIZE_1),), x, y, epilogue.__closure__[0].cell_contents, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=2, num_stages=4)
1705
+ _launcher(_matmul_kernel , (triton.cdiv(1024, _BLOCK_SIZE_0) * triton.cdiv(1024, _BLOCK_SIZE_1),), x, y, epilogue.__closure__[0].cell_contents, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=2, num_stages=4)
1702
1706
return out
1703
1707
1704
1708
--- assertExpectedJournal(TestExamples.test_template_via_closure2)
@@ -1713,7 +1717,7 @@ from helion.runtime import default_launcher as _default_launcher
1713
1717
import test.test_examples as _global_source0
1714
1718
1715
1719
@triton.jit
1716
- def _matmul_with_epilogue_kernel (x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
1720
+ def _matmul_kernel (x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
1717
1721
num_pid_m = tl.cdiv(1024, _BLOCK_SIZE_0)
1718
1722
num_pid_n = tl.cdiv(1024, _BLOCK_SIZE_1)
1719
1723
inner_2d_pid = tl.program_id(0)
@@ -1737,13 +1741,13 @@ def _matmul_with_epilogue_kernel(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_
1737
1741
v_2 = v_1.to(tl.float16)
1738
1742
tl.store(tl.make_block_ptr(out, [1024, 1024], [1024, 1], [offset_0, offset_1], [_BLOCK_SIZE_0, _BLOCK_SIZE_1], [1, 0]), v_2, boundary_check=[0, 1])
1739
1743
1740
- def matmul_with_epilogue (x: Tensor, y: Tensor, epilogue: Callable[[Tensor, list[Tensor]], Tensor], *, _launcher=_default_launcher):
1744
+ def matmul (x: Tensor, y: Tensor, epilogue: Callable[[Tensor, list[Tensor]], Tensor]=lambda acc, tile: acc , *, _launcher=_default_launcher):
1741
1745
m, k = x.size()
1742
1746
k2, n = y.size()
1743
1747
assert k == k2, f'size mismatch {k} != {k2}'
1744
1748
out = torch.empty([m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device)
1745
1749
_BLOCK_SIZE_0 = 64
1746
1750
_BLOCK_SIZE_1 = 64
1747
1751
_BLOCK_SIZE_2 = 16
1748
- _launcher(_matmul_with_epilogue_kernel , (triton.cdiv(1024, _BLOCK_SIZE_0) * triton.cdiv(1024, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=2, num_stages=4)
1752
+ _launcher(_matmul_kernel , (triton.cdiv(1024, _BLOCK_SIZE_0) * triton.cdiv(1024, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=2, num_stages=4)
1749
1753
return out
0 commit comments