
Commit a352384

Add implicit broadcasting tests (#285)
1 parent c90a4ef commit a352384

3 files changed: +205 -4 lines changed

helion/_compiler/type_propagation.py

Lines changed: 6 additions & 4 deletions
@@ -1209,10 +1209,12 @@ def merge(self, other: TypeInfo) -> TypeInfo:
         if len(self_elements) == len(other_elements):
             return SequenceType(
                 origin=other.origin,
-                element_types=[
-                    self_elements[i].merge(other_elements[i])
-                    for i in range(len(self_elements))
-                ],
+                element_types=self._maybe_tuple(
+                    [
+                        self_elements[i].merge(other_elements[i])
+                        for i in range(len(self_elements))
+                    ]
+                ),
             )
         return super().merge(other)
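Note on this change (commentary, not part of the commit): merging two SequenceType values now routes the merged element types through self._maybe_tuple(...), presumably so that merging two tuple-typed sequences stays a tuple instead of always degrading to a list. A standalone sketch of that idea, with no claim that it matches helion's actual helper:

def maybe_tuple(original, merged):
    # Keep the merged element types in the same container kind as the
    # original sequence: tuple in, tuple out; list in, list out.
    return tuple(merged) if isinstance(original, tuple) else merged

assert maybe_tuple((1, 2), [3, 4]) == (3, 4)
assert maybe_tuple([1, 2], [3, 4]) == [3, 4]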

test/test_indexing.expected

Lines changed: 139 additions & 0 deletions
@@ -58,6 +58,145 @@ def _arange_three_args_step_make_precompiler(x: torch.Tensor):
     from helion.runtime.precompile_shim import make_precompiler
     return make_precompiler(_arange_three_args_step_kernel)(out, out.size(0), out.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
 
+--- assertExpectedJournal(TestIndexing.test_broadcasting_block_ptr_indexing)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+
+@triton.jit
+def _broadcast_add_3d_kernel(x, bias1, bias2, out, bias1_size_1, bias1_size_2, bias2_size_0, bias2_size_2, out_size_0, out_size_1, out_size_2, x_size_0, x_size_1, x_size_2, bias1_stride_0, bias1_stride_1, bias1_stride_2, bias2_stride_0, bias2_stride_1, bias2_stride_2, out_stride_0, out_stride_1, out_stride_2, x_stride_0, x_stride_1, x_stride_2, d0, d1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
+    num_blocks_0 = tl.cdiv(d0, _BLOCK_SIZE_0)
+    num_blocks_1 = tl.cdiv(d1, _BLOCK_SIZE_1)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0 % num_blocks_1
+    pid_2 = tl.program_id(0) // (num_blocks_0 * num_blocks_1)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    offset_2 = pid_2 * _BLOCK_SIZE_2
+    load = tl.load(tl.make_block_ptr(x, [x_size_0, x_size_1, x_size_2], [x_stride_0, x_stride_1, x_stride_2], [offset_0, offset_1, offset_2], [_BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2], [2, 1, 0]), boundary_check=[0, 1, 2], padding_option='zero')
+    load_1 = tl.load(tl.make_block_ptr(bias1, [1, bias1_size_1, bias1_size_2], [bias1_stride_0, bias1_stride_1, bias1_stride_2], [0, offset_1, offset_2], [1, _BLOCK_SIZE_1, _BLOCK_SIZE_2], [2, 1, 0]), boundary_check=[1, 2], padding_option='zero')
+    v_0 = load + load_1
+    load_2 = tl.load(tl.make_block_ptr(bias2, [bias2_size_0, 1, bias2_size_2], [bias2_stride_0, bias2_stride_1, bias2_stride_2], [offset_0, 0, offset_2], [_BLOCK_SIZE_0, 1, _BLOCK_SIZE_2], [2, 1, 0]), boundary_check=[0, 2], padding_option='zero')
+    v_1 = v_0 + load_2
+    tl.store(tl.make_block_ptr(out, [out_size_0, out_size_1, out_size_2], [out_stride_0, out_stride_1, out_stride_2], [offset_0, offset_1, offset_2], [_BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2], [2, 1, 0]), v_1, boundary_check=[0, 1, 2])
+
+def broadcast_add_3d(x: torch.Tensor, bias1: torch.Tensor, bias2: torch.Tensor):
+    d0, d1, d2 = x.size()
+    out = torch.empty_like(x)
+    _BLOCK_SIZE_0 = 8
+    _BLOCK_SIZE_1 = 8
+    _BLOCK_SIZE_2 = 8
+    _broadcast_add_3d_kernel[triton.cdiv(d0, _BLOCK_SIZE_0) * triton.cdiv(d1, _BLOCK_SIZE_1) * triton.cdiv(d2, _BLOCK_SIZE_2),](x, bias1, bias2, out, bias1.size(1), bias1.size(2), bias2.size(0), bias2.size(2), out.size(0), out.size(1), out.size(2), x.size(0), x.size(1), x.size(2), bias1.stride(0), bias1.stride(1), bias1.stride(2), bias2.stride(0), bias2.stride(1), bias2.stride(2), out.stride(0), out.stride(1), out.stride(2), x.stride(0), x.stride(1), x.stride(2), d0, d1, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
+    return out
+
+def _broadcast_add_3d_make_precompiler(x: torch.Tensor, bias1: torch.Tensor, bias2: torch.Tensor):
+    d0, d1, d2 = x.size()
+    out = torch.empty_like(x)
+    _BLOCK_SIZE_0 = 8
+    _BLOCK_SIZE_1 = 8
+    _BLOCK_SIZE_2 = 8
+    from helion.runtime.precompile_shim import make_precompiler
+    return make_precompiler(_broadcast_add_3d_kernel)(x, bias1, bias2, out, bias1.size(1), bias1.size(2), bias2.size(0), bias2.size(2), out.size(0), out.size(1), out.size(2), x.size(0), x.size(1), x.size(2), bias1.stride(0), bias1.stride(1), bias1.stride(2), bias2.stride(0), bias2.stride(1), bias2.stride(2), out.stride(0), out.stride(1), out.stride(2), x.stride(0), x.stride(1), x.stride(2), d0, d1, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
+
+--- assertExpectedJournal(TestIndexing.test_broadcasting_pointer_indexing)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+
+@triton.jit
+def _broadcast_add_3d_kernel(x, bias1, bias2, out, bias1_stride_1, bias1_stride_2, bias2_stride_0, bias2_stride_2, out_stride_0, out_stride_1, out_stride_2, x_stride_0, x_stride_1, x_stride_2, d0, d1, d2, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
+    num_blocks_0 = tl.cdiv(d0, _BLOCK_SIZE_0)
+    num_blocks_1 = tl.cdiv(d1, _BLOCK_SIZE_1)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0 % num_blocks_1
+    pid_2 = tl.program_id(0) // (num_blocks_0 * num_blocks_1)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < d0
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    mask_1 = indices_1 < d1
+    offset_2 = pid_2 * _BLOCK_SIZE_2
+    indices_2 = (offset_2 + tl.arange(0, _BLOCK_SIZE_2)).to(tl.int32)
+    mask_2 = indices_2 < d2
+    load = tl.load(x + (indices_0[:, None, None] * x_stride_0 + indices_1[None, :, None] * x_stride_1 + indices_2[None, None, :] * x_stride_2), mask_0[:, None, None] & mask_1[None, :, None] & mask_2[None, None, :], other=0)
+    load_1 = tl.load(bias1 + (indices_1[None, :, None] * bias1_stride_1 + indices_2[None, None, :] * bias1_stride_2), mask_1[None, :, None] & mask_2[None, None, :], other=0)
+    v_0 = load + load_1
+    load_2 = tl.load(bias2 + (indices_0[:, None, None] * bias2_stride_0 + indices_2[None, None, :] * bias2_stride_2), mask_0[:, None, None] & mask_2[None, None, :], other=0)
+    v_1 = v_0 + load_2
+    tl.store(out + (indices_0[:, None, None] * out_stride_0 + indices_1[None, :, None] * out_stride_1 + indices_2[None, None, :] * out_stride_2), v_1, mask_0[:, None, None] & mask_1[None, :, None] & mask_2[None, None, :])
+
+def broadcast_add_3d(x: torch.Tensor, bias1: torch.Tensor, bias2: torch.Tensor):
+    d0, d1, d2 = x.size()
+    out = torch.empty_like(x)
+    _BLOCK_SIZE_0 = 8
+    _BLOCK_SIZE_1 = 8
+    _BLOCK_SIZE_2 = 8
+    _broadcast_add_3d_kernel[triton.cdiv(d0, _BLOCK_SIZE_0) * triton.cdiv(d1, _BLOCK_SIZE_1) * triton.cdiv(d2, _BLOCK_SIZE_2),](x, bias1, bias2, out, bias1.stride(1), bias1.stride(2), bias2.stride(0), bias2.stride(2), out.stride(0), out.stride(1), out.stride(2), x.stride(0), x.stride(1), x.stride(2), d0, d1, d2, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
+    return out
+
+def _broadcast_add_3d_make_precompiler(x: torch.Tensor, bias1: torch.Tensor, bias2: torch.Tensor):
+    d0, d1, d2 = x.size()
+    out = torch.empty_like(x)
+    _BLOCK_SIZE_0 = 8
+    _BLOCK_SIZE_1 = 8
+    _BLOCK_SIZE_2 = 8
+    from helion.runtime.precompile_shim import make_precompiler
+    return make_precompiler(_broadcast_add_3d_kernel)(x, bias1, bias2, out, bias1.stride(1), bias1.stride(2), bias2.stride(0), bias2.stride(2), out.stride(0), out.stride(1), out.stride(2), x.stride(0), x.stride(1), x.stride(2), d0, d1, d2, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
+
+--- assertExpectedJournal(TestIndexing.test_broadcasting_tensor_descriptor_indexing)
+from __future__ import annotations
+
+import torch
+import helion
+import triton
+import triton.language as tl
+
+helion.runtime.set_triton_allocator()
+
+@triton.jit
+def _broadcast_add_3d_kernel(x, bias1, bias2, out, bias1_size_1, bias1_size_2, bias2_size_0, bias2_size_2, out_size_0, out_size_1, out_size_2, x_size_0, x_size_1, x_size_2, bias1_stride_0, bias1_stride_1, bias1_stride_2, bias2_stride_0, bias2_stride_1, bias2_stride_2, out_stride_0, out_stride_1, out_stride_2, x_stride_0, x_stride_1, x_stride_2, d0, d1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
+    x_desc = tl.make_tensor_descriptor(x, [x_size_0, x_size_1, x_size_2], [x_stride_0, x_stride_1, x_stride_2], [_BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2])
+    bias1_desc = tl.make_tensor_descriptor(bias1, [1, bias1_size_1, bias1_size_2], [bias1_stride_0, bias1_stride_1, bias1_stride_2], [1, _BLOCK_SIZE_1, _BLOCK_SIZE_2])
+    bias2_desc = tl.make_tensor_descriptor(bias2, [bias2_size_0, 1, bias2_size_2], [bias2_stride_0, bias2_stride_1, bias2_stride_2], [_BLOCK_SIZE_0, 1, _BLOCK_SIZE_2])
+    out_desc = tl.make_tensor_descriptor(out, [out_size_0, out_size_1, out_size_2], [out_stride_0, out_stride_1, out_stride_2], [_BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2])
+    num_blocks_0 = tl.cdiv(d0, _BLOCK_SIZE_0)
+    num_blocks_1 = tl.cdiv(d1, _BLOCK_SIZE_1)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0 % num_blocks_1
+    pid_2 = tl.program_id(0) // (num_blocks_0 * num_blocks_1)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    offset_2 = pid_2 * _BLOCK_SIZE_2
+    load = x_desc.load([offset_0, offset_1, offset_2])
+    load_1 = bias1_desc.load([0, offset_1, offset_2])
+    v_0 = load + load_1
+    load_2 = bias2_desc.load([offset_0, 0, offset_2])
+    v_1 = v_0 + load_2
+    out_desc.store([offset_0, offset_1, offset_2], v_1)
+
+def broadcast_add_3d(x: torch.Tensor, bias1: torch.Tensor, bias2: torch.Tensor):
+    d0, d1, d2 = x.size()
+    out = torch.empty_like(x)
+    _BLOCK_SIZE_0 = 8
+    _BLOCK_SIZE_1 = 8
+    _BLOCK_SIZE_2 = 8
+    _broadcast_add_3d_kernel[triton.cdiv(d0, _BLOCK_SIZE_0) * triton.cdiv(d1, _BLOCK_SIZE_1) * triton.cdiv(d2, _BLOCK_SIZE_2),](x, bias1, bias2, out, bias1.size(1), bias1.size(2), bias2.size(0), bias2.size(2), out.size(0), out.size(1), out.size(2), x.size(0), x.size(1), x.size(2), bias1.stride(0), bias1.stride(1), bias1.stride(2), bias2.stride(0), bias2.stride(1), bias2.stride(2), out.stride(0), out.stride(1), out.stride(2), x.stride(0), x.stride(1), x.stride(2), d0, d1, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
+    return out
+
+def _broadcast_add_3d_make_precompiler(x: torch.Tensor, bias1: torch.Tensor, bias2: torch.Tensor):
+    d0, d1, d2 = x.size()
+    out = torch.empty_like(x)
+    _BLOCK_SIZE_0 = 8
+    _BLOCK_SIZE_1 = 8
+    _BLOCK_SIZE_2 = 8
+    from helion.runtime.precompile_shim import make_precompiler
+    return make_precompiler(_broadcast_add_3d_kernel)(x, bias1, bias2, out, bias1.size(1), bias1.size(2), bias2.size(0), bias2.size(2), out.size(0), out.size(1), out.size(2), x.size(0), x.size(1), x.size(2), bias1.stride(0), bias1.stride(1), bias1.stride(2), bias2.stride(0), bias2.stride(1), bias2.stride(2), out.stride(0), out.stride(1), out.stride(2), x.stride(0), x.stride(1), x.stride(2), d0, d1, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
+
 --- assertExpectedJournal(TestIndexing.test_mask_load)
 from __future__ import annotations
 
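The three journal entries above are the pointer, block_ptr, and tensor_descriptor code generated for the same kernel. For reference (not part of the expected file), the eager PyTorch computation they are all checked against is simply:

import torch

# Shapes from the tests: bias1 broadcasts over dim 0, bias2 over dim 1.
x = torch.randn(16, 24, 32)
bias1 = torch.randn(1, 24, 32)
bias2 = torch.randn(16, 1, 32)
expected = x + bias1 + bias2  # reference result the generated kernels must match
assert expected.shape == (16, 24, 32)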

test/test_indexing.py

Lines changed: 60 additions & 0 deletions
@@ -5,12 +5,29 @@
 import torch
 
 import helion
+from helion._compat import supports_tensor_descriptor
 from helion._testing import DEVICE
 from helion._testing import TestCase
 from helion._testing import code_and_output
 import helion.language as hl
 
 
+@helion.kernel
+def broadcast_add_3d(
+    x: torch.Tensor, bias1: torch.Tensor, bias2: torch.Tensor
+) -> torch.Tensor:
+    d0, d1, d2 = x.size()
+    out = torch.empty_like(x)
+    for tile_l, tile_m, tile_n in hl.tile([d0, d1, d2]):
+        # bias1 has shape [1, d1, d2], bias2 has shape [d0, 1, d2]
+        out[tile_l, tile_m, tile_n] = (
+            x[tile_l, tile_m, tile_n]
+            + bias1[tile_l, tile_m, tile_n]
+            + bias2[tile_l, tile_m, tile_n]
+        )
+    return out
+
+
 class TestIndexing(TestCase):
     def test_arange(self):
         @helion.kernel
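The hunk above adds the broadcast_add_3d helion kernel used by the new tests. Outside the test harness it can be invoked like any helion kernel; a minimal sketch (not part of the commit, assuming a CUDA device is available and leaving block sizes to the autotuner):

import torch

x = torch.randn(16, 24, 32, device="cuda")
bias1 = torch.randn(1, 24, 32, device="cuda")
bias2 = torch.randn(16, 1, 32, device="cuda")
result = broadcast_add_3d(x, bias1, bias2)  # helion compiles/autotunes on first call
torch.testing.assert_close(result, x + bias1 + bias2)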
@@ -320,6 +337,49 @@ def arange_three_args_step(x: torch.Tensor) -> torch.Tensor:
         expected = torch.arange(0, 64, step=2, dtype=torch.int32, device=DEVICE)
         torch.testing.assert_close(result, expected)
 
+    def test_broadcasting_pointer_indexing(self):
+        x = torch.randn([16, 24, 32], device=DEVICE)
+        bias1 = torch.randn([1, 24, 32], device=DEVICE)
+        bias2 = torch.randn([16, 1, 32], device=DEVICE)
+        code, result = code_and_output(
+            broadcast_add_3d,
+            (x, bias1, bias2),
+            indexing="pointer",
+            block_size=[8, 8, 8],
+        )
+        expected = x + bias1 + bias2
+        torch.testing.assert_close(result, expected)
+        self.assertExpectedJournal(code)
+
+    def test_broadcasting_block_ptr_indexing(self):
+        x = torch.randn([16, 24, 32], device=DEVICE)
+        bias1 = torch.randn([1, 24, 32], device=DEVICE)
+        bias2 = torch.randn([16, 1, 32], device=DEVICE)
+        code, result = code_and_output(
+            broadcast_add_3d,
+            (x, bias1, bias2),
+            indexing="block_ptr",
+            block_size=[8, 8, 8],
+        )
+        expected = x + bias1 + bias2
+        torch.testing.assert_close(result, expected)
+        self.assertExpectedJournal(code)
+
+    @unittest.skipIf(not supports_tensor_descriptor(), "TensorDescriptor not supported")
+    def test_broadcasting_tensor_descriptor_indexing(self):
+        x = torch.randn([16, 24, 32], device=DEVICE)
+        bias1 = torch.randn([1, 24, 32], device=DEVICE)
+        bias2 = torch.randn([16, 1, 32], device=DEVICE)
+        code, result = code_and_output(
+            broadcast_add_3d,
+            (x, bias1, bias2),
+            indexing="tensor_descriptor",
+            block_size=[8, 8, 8],
+        )
+        expected = x + bias1 + bias2
+        torch.testing.assert_close(result, expected)
+        self.assertExpectedJournal(code)
+
 
 if __name__ == "__main__":
     unittest.main()
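Note: the tensor descriptor variant is guarded by unittest.skipIf, so on hardware where supports_tensor_descriptor() is false only the pointer and block_ptr tests run. To run just the new tests locally, a standard unittest selection such as python -m unittest -k broadcasting test.test_indexing should work (the tests allocate tensors on helion's test DEVICE, typically a GPU).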
