@@ -964,7 +964,7 @@ def _matmul_kernel(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.con
964
964
acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32')
965
965
tl.store(out + (indices_0[:, None] * 128 + indices_1[None, :] * 1), acc, None)
966
966
967
- def matmul(x: Tensor, y: Tensor, epilogue: Callable[[Tensor, list [Tensor]], Tensor]=lambda acc, tile: acc, *, _launcher=_default_launcher):
967
+ def matmul(x: Tensor, y: Tensor, epilogue: Callable[[Tensor, tuple [Tensor, ... ]], Tensor]=lambda acc, tile: acc, *, _launcher=_default_launcher):
968
968
m, k = x.size()
969
969
k2, n = y.size()
970
970
assert k == k2, f'size mismatch {k} != {k2}'
@@ -1137,7 +1137,7 @@ def _matmul_split_k_kernel(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1
1137
1137
acc = acc_copy_1
1138
1138
tl.atomic_add(out + (indices_0[:, None] * 64 + indices_1[None, :] * 1), acc, mask=None, sem='relaxed')
1139
1139
1140
- def matmul_split_k(x: torch.Tensor, y: torch.Tensor, epilogue: Callable[[torch.Tensor, list [torch.Tensor]], torch.Tensor]=lambda acc, tile: acc, *, _launcher=_default_launcher):
1140
+ def matmul_split_k(x: torch.Tensor, y: torch.Tensor, epilogue: Callable[[torch.Tensor, tuple [torch.Tensor, ... ]], torch.Tensor]=lambda acc, tile: acc, *, _launcher=_default_launcher):
1141
1141
m, k = x.size()
1142
1142
k2, n = y.size()
1143
1143
assert k == k2, f'size mismatch {k} != {k2}'
@@ -1644,7 +1644,7 @@ def _matmul_kernel(x, y, epilogue_closure_0, out, _BLOCK_SIZE_0: tl.constexpr, _
1644
1644
v_4 = v_3.to(tl.float16)
1645
1645
tl.store(out + (indices_0[:, None] * 1024 + indices_1[None, :] * 1), v_4, None)
1646
1646
1647
- def matmul(x: Tensor, y: Tensor, epilogue: Callable[[Tensor, list [Tensor]], Tensor]=lambda acc, tile: acc, *, _launcher=_default_launcher):
1647
+ def matmul(x: Tensor, y: Tensor, epilogue: Callable[[Tensor, tuple [Tensor, ... ]], Tensor]=lambda acc, tile: acc, *, _launcher=_default_launcher):
1648
1648
m, k = x.size()
1649
1649
k2, n = y.size()
1650
1650
assert k == k2, f'size mismatch {k} != {k2}'
@@ -1694,7 +1694,7 @@ def _matmul_kernel(x, y, epilogue_closure_0, out, _BLOCK_SIZE_0: tl.constexpr, _
1694
1694
v_4 = v_3.to(tl.float16)
1695
1695
tl.store(tl.make_block_ptr(out, [1024, 1024], [1024, 1], [offset_0, offset_1], [_BLOCK_SIZE_0, _BLOCK_SIZE_1], [1, 0]), v_4, boundary_check=[0, 1])
1696
1696
1697
- def matmul(x: Tensor, y: Tensor, epilogue: Callable[[Tensor, list [Tensor]], Tensor]=lambda acc, tile: acc, *, _launcher=_default_launcher):
1697
+ def matmul(x: Tensor, y: Tensor, epilogue: Callable[[Tensor, tuple [Tensor, ... ]], Tensor]=lambda acc, tile: acc, *, _launcher=_default_launcher):
1698
1698
m, k = x.size()
1699
1699
k2, n = y.size()
1700
1700
assert k == k2, f'size mismatch {k} != {k2}'
@@ -1741,7 +1741,7 @@ def _matmul_kernel(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.con
1741
1741
v_2 = v_1.to(tl.float16)
1742
1742
tl.store(tl.make_block_ptr(out, [1024, 1024], [1024, 1], [offset_0, offset_1], [_BLOCK_SIZE_0, _BLOCK_SIZE_1], [1, 0]), v_2, boundary_check=[0, 1])
1743
1743
1744
- def matmul(x: Tensor, y: Tensor, epilogue: Callable[[Tensor, list [Tensor]], Tensor]=lambda acc, tile: acc, *, _launcher=_default_launcher):
1744
+ def matmul(x: Tensor, y: Tensor, epilogue: Callable[[Tensor, tuple [Tensor, ... ]], Tensor]=lambda acc, tile: acc, *, _launcher=_default_launcher):
1745
1745
m, k = x.size()
1746
1746
k2, n = y.size()
1747
1747
assert k == k2, f'size mismatch {k} != {k2}'
0 commit comments