
Commit 7d54ca4

Fix non-tuple indexing warning (#411)
```
UserWarning: Using a non-tuple sequence for multidimensional indexing is deprecated and will be changed in pytorch 2.9; use x[tuple(seq)] instead of x[seq]. In pytorch 2.9 this will be interpreted as tensor index, x[torch.tensor(seq)], which will result either in an error or a different result
```
1 parent f5594ca commit 7d54ca4
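For reference, the deprecation is easy to trigger with plain PyTorch indexing. The sketch below is illustrative only, not the exact code path this commit touches, and assumes a PyTorch build recent enough to emit the quoted warning:

```python
import torch

x = torch.arange(16).reshape(4, 4)
seq = [slice(0, 2), slice(0, 2)]  # a non-tuple sequence of index objects

y_old = x[seq]         # emits the UserWarning quoted above
y_new = x[tuple(seq)]  # the recommended spelling: x[tuple(seq)]

assert torch.equal(y_old, y_new)  # today both select the top-left 2x2 block
```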

File tree

6 files changed (+49, -48 lines)

examples/matmul.py

Lines changed: 3 additions & 3 deletions
@@ -20,7 +20,7 @@
 def matmul(
     x: Tensor,
     y: Tensor,
-    epilogue: Callable[[Tensor, list[Tensor]], Tensor] = lambda acc, tile: acc,
+    epilogue: Callable[[Tensor, tuple[Tensor, ...]], Tensor] = lambda acc, tile: acc,
 ) -> Tensor:
     m, k = x.size()
     k2, n = y.size()
@@ -32,7 +32,7 @@ def matmul(
         acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
         for tile_k in hl.tile(k):
             acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n])
-        out[tile_m, tile_n] = epilogue(acc, [tile_m, tile_n])
+        out[tile_m, tile_n] = epilogue(acc, (tile_m, tile_n))
     return out


@@ -64,7 +64,7 @@ def baseline_linear(x: Tensor, y: Tensor, bias: Tensor) -> Tensor:
     run_example(helion_linear, baseline_linear, (x, y, bias))

     # Test more complex epilogue
-    def epilogue(acc: Tensor, tile: list[Tensor]) -> Tensor:
+    def epilogue(acc: Tensor, tile: tuple[Tensor, ...]) -> Tensor:
         # The epilogue can use the captured bias tensor that is implicitly lifted to a kernel arg
         return torch.relu(acc + bias[tile[1]])

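The change above narrows the epilogue contract from a list to a tuple of tile indices. Below is a minimal sketch of that contract using plain tensors as stand-ins for Helion tile objects; `bias`, the shapes, and the index tensors are made up for illustration, and Helion itself passes real tile objects on device:

```python
import torch
from torch import Tensor

bias = torch.randn(8)  # hypothetical bias vector matching the output's column count

def epilogue(acc: Tensor, tile: tuple[Tensor, ...]) -> Tensor:
    # tile[1] plays the role of tile_n: the column indices of the output tile
    return torch.relu(acc + bias[tile[1]])

acc = torch.randn(4, 8)                    # a 4x8 accumulator tile
tile = (torch.arange(4), torch.arange(8))  # stand-ins for (tile_m, tile_n)
out = epilogue(acc, tile)
assert out.shape == acc.shape
```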
examples/matmul_split_k.py

Lines changed: 4 additions & 3 deletions
@@ -18,8 +18,9 @@
 def matmul_split_k(
     x: torch.Tensor,
     y: torch.Tensor,
-    epilogue: Callable[[torch.Tensor, list[torch.Tensor]], torch.Tensor] = lambda acc,
-    tile: acc,
+    epilogue: Callable[
+        [torch.Tensor, tuple[torch.Tensor, ...]], torch.Tensor
+    ] = lambda acc, tile: acc,
 ) -> torch.Tensor:
     m, k = x.size()
     k2, n = y.size()
@@ -35,7 +36,7 @@ def matmul_split_k(
             acc = torch.addmm(acc, x[tile_m, inner_k], y[inner_k, tile_n])
         # Apply epilogue only on the first k-split iteration
         if outer_k.begin == 0:
-            acc = epilogue(acc, [tile_m, tile_n])
+            acc = epilogue(acc, (tile_m, tile_n))
         hl.atomic_add(out, [tile_m, tile_n], acc)
     return out

helion/language/loops.py

Lines changed: 2 additions & 2 deletions
@@ -364,7 +364,7 @@ def _(
     if unpack:
         (result,) = results
     else:
-        result = SequenceType(origin, results)
+        result = SequenceType(origin, tuple(results))
     return IterType(origin, result)


@@ -712,7 +712,7 @@ def _(
     if unpack:
         (result,) = results
     else:
-        result = SequenceType(origin, results)
+        result = SequenceType(origin, tuple(results))
     return IterType(origin, result)

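This mirrors the example changes above: sequences of tile indices are now tuples, matching the `x[tuple(seq)]` spelling the warning recommends. Below is a minimal, non-Helion sketch of the semantic gap the warning describes once `x[seq]` is reinterpreted as `x[torch.tensor(seq)]`; integer indices are used only for clarity:

```python
import torch

x = torch.arange(16).reshape(4, 4)
idx = [0, 1]

a = x[tuple(idx)]         # multidimensional index: the single element at row 0, column 1
b = x[torch.tensor(idx)]  # tensor index: rows 0 and 1, shape (2, 4)

assert a.ndim == 0
assert b.shape == (2, 4)
```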
test/test_examples.expected

Lines changed: 5 additions & 5 deletions
@@ -964,7 +964,7 @@ def _matmul_kernel(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.con
         acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32')
     tl.store(out + (indices_0[:, None] * 128 + indices_1[None, :] * 1), acc, None)

-def matmul(x: Tensor, y: Tensor, epilogue: Callable[[Tensor, list[Tensor]], Tensor]=lambda acc, tile: acc, *, _launcher=_default_launcher):
+def matmul(x: Tensor, y: Tensor, epilogue: Callable[[Tensor, tuple[Tensor, ...]], Tensor]=lambda acc, tile: acc, *, _launcher=_default_launcher):
     m, k = x.size()
     k2, n = y.size()
     assert k == k2, f'size mismatch {k} != {k2}'
@@ -1137,7 +1137,7 @@ def _matmul_split_k_kernel(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1
         acc = acc_copy_1
     tl.atomic_add(out + (indices_0[:, None] * 64 + indices_1[None, :] * 1), acc, mask=None, sem='relaxed')

-def matmul_split_k(x: torch.Tensor, y: torch.Tensor, epilogue: Callable[[torch.Tensor, list[torch.Tensor]], torch.Tensor]=lambda acc, tile: acc, *, _launcher=_default_launcher):
+def matmul_split_k(x: torch.Tensor, y: torch.Tensor, epilogue: Callable[[torch.Tensor, tuple[torch.Tensor, ...]], torch.Tensor]=lambda acc, tile: acc, *, _launcher=_default_launcher):
     m, k = x.size()
     k2, n = y.size()
     assert k == k2, f'size mismatch {k} != {k2}'
@@ -1644,7 +1644,7 @@ def _matmul_kernel(x, y, epilogue_closure_0, out, _BLOCK_SIZE_0: tl.constexpr, _
     v_4 = v_3.to(tl.float16)
     tl.store(out + (indices_0[:, None] * 1024 + indices_1[None, :] * 1), v_4, None)

-def matmul(x: Tensor, y: Tensor, epilogue: Callable[[Tensor, list[Tensor]], Tensor]=lambda acc, tile: acc, *, _launcher=_default_launcher):
+def matmul(x: Tensor, y: Tensor, epilogue: Callable[[Tensor, tuple[Tensor, ...]], Tensor]=lambda acc, tile: acc, *, _launcher=_default_launcher):
     m, k = x.size()
     k2, n = y.size()
     assert k == k2, f'size mismatch {k} != {k2}'
@@ -1694,7 +1694,7 @@ def _matmul_kernel(x, y, epilogue_closure_0, out, _BLOCK_SIZE_0: tl.constexpr, _
     v_4 = v_3.to(tl.float16)
     tl.store(tl.make_block_ptr(out, [1024, 1024], [1024, 1], [offset_0, offset_1], [_BLOCK_SIZE_0, _BLOCK_SIZE_1], [1, 0]), v_4, boundary_check=[0, 1])

-def matmul(x: Tensor, y: Tensor, epilogue: Callable[[Tensor, list[Tensor]], Tensor]=lambda acc, tile: acc, *, _launcher=_default_launcher):
+def matmul(x: Tensor, y: Tensor, epilogue: Callable[[Tensor, tuple[Tensor, ...]], Tensor]=lambda acc, tile: acc, *, _launcher=_default_launcher):
     m, k = x.size()
     k2, n = y.size()
     assert k == k2, f'size mismatch {k} != {k2}'
@@ -1741,7 +1741,7 @@ def _matmul_kernel(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.con
     v_2 = v_1.to(tl.float16)
     tl.store(tl.make_block_ptr(out, [1024, 1024], [1024, 1], [offset_0, offset_1], [_BLOCK_SIZE_0, _BLOCK_SIZE_1], [1, 0]), v_2, boundary_check=[0, 1])

-def matmul(x: Tensor, y: Tensor, epilogue: Callable[[Tensor, list[Tensor]], Tensor]=lambda acc, tile: acc, *, _launcher=_default_launcher):
+def matmul(x: Tensor, y: Tensor, epilogue: Callable[[Tensor, tuple[Tensor, ...]], Tensor]=lambda acc, tile: acc, *, _launcher=_default_launcher):
     m, k = x.size()
     k2, n = y.size()
     assert k == k2, f'size mismatch {k} != {k2}'

test/test_matmul.expected

Lines changed: 3 additions & 3 deletions
@@ -75,7 +75,7 @@ def _matmul_kernel(x, y, out, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_0: tl.con
         acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32')
     tl.store(out + (indices_0[:, None] * 128 + indices_1[None, :] * 1), acc, None)

-def matmul(x: Tensor, y: Tensor, epilogue: Callable[[Tensor, list[Tensor]], Tensor]=lambda acc, tile: acc, *, _launcher=_default_launcher):
+def matmul(x: Tensor, y: Tensor, epilogue: Callable[[Tensor, tuple[Tensor, ...]], Tensor]=lambda acc, tile: acc, *, _launcher=_default_launcher):
     m, k = x.size()
     k2, n = y.size()
     assert k == k2, f'size mismatch {k} != {k2}'
@@ -162,7 +162,7 @@ def _matmul_kernel(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.con
         acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32')
     tl.store(tl.make_block_ptr(out, [128, 128], [128, 1], [offset_0, offset_1], [_BLOCK_SIZE_0, _BLOCK_SIZE_1], [1, 0]), acc, boundary_check=[0, 1])

-def matmul(x: Tensor, y: Tensor, epilogue: Callable[[Tensor, list[Tensor]], Tensor]=lambda acc, tile: acc, *, _launcher=_default_launcher):
+def matmul(x: Tensor, y: Tensor, epilogue: Callable[[Tensor, tuple[Tensor, ...]], Tensor]=lambda acc, tile: acc, *, _launcher=_default_launcher):
     m, k = x.size()
     k2, n = y.size()
     assert k == k2, f'size mismatch {k} != {k2}'
@@ -435,7 +435,7 @@ def _matmul_kernel(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.con
         acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32')
     out_desc.store([offset_0, offset_1], acc)

-def matmul(x: Tensor, y: Tensor, epilogue: Callable[[Tensor, list[Tensor]], Tensor]=lambda acc, tile: acc, *, _launcher=_default_launcher):
+def matmul(x: Tensor, y: Tensor, epilogue: Callable[[Tensor, tuple[Tensor, ...]], Tensor]=lambda acc, tile: acc, *, _launcher=_default_launcher):
     m, k = x.size()
     k2, n = y.size()
     assert k == k2, f'size mismatch {k} != {k2}'
