Skip to content

Commit 8819331

Browse files
authored
Add SequenceType Eq comparison (#482)
1 parent 3bf9414 commit 8819331

File tree

4 files changed

+109
-0
lines changed

4 files changed

+109
-0
lines changed

helion/_compiler/type_propagation.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1654,6 +1654,39 @@ def _compare(self, op: ast.cmpop, left: TypeInfo, right: TypeInfo) -> TypeInfo:
16541654
raise
16551655
except Exception as e:
16561656
raise exc.TorchOpTracingError(e) from e
1657+
if (
1658+
isinstance(left, SequenceType)
1659+
and isinstance(right, SequenceType)
1660+
and isinstance(op, ast.Eq)
1661+
):
1662+
if len(left.element_types) != len(right.element_types):
1663+
return LiteralType(origin=self.origin(), value=False)
1664+
1665+
can_determine_statically = True
1666+
all_elements_equal = True
1667+
for left_elem, right_elem in zip(
1668+
left.element_types, right.element_types, strict=False
1669+
):
1670+
if isinstance(left_elem, LiteralType) and isinstance(
1671+
right_elem, LiteralType
1672+
):
1673+
if left_elem.value != right_elem.value:
1674+
all_elements_equal = False
1675+
break
1676+
elif isinstance(left_elem, (NumericType, LiteralType)) and isinstance(
1677+
right_elem, (NumericType, LiteralType)
1678+
):
1679+
if NumericType.known_equal(left_elem.value, right_elem.value):
1680+
continue
1681+
can_determine_statically = False
1682+
break
1683+
else:
1684+
can_determine_statically = False
1685+
break
1686+
1687+
if can_determine_statically:
1688+
return LiteralType(origin=self.origin(), value=all_elements_equal)
1689+
return SymBoolType.new_unbacked(self.origin())
16571690
raise exc.TypeInferenceError(
16581691
f"{type(op).__name__} not supported on {left!s} and {right!s}"
16591692
)

requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,5 @@ pytest
22
typing_extensions
33
pre-commit
44
filecheck
5+
expecttest
6+
numpy

test/test_misc.expected

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,59 @@ def kernel_with_scalar_item(x: torch.Tensor, scalar_tensor: torch.Tensor, *, _la
9191
_launcher(_kernel_with_scalar_item_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, result, x.size(0), result.stride(0), x.stride(0), scalar_val, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
9292
return result
9393

94+
--- assertExpectedJournal(TestMisc.test_sequence_assert_static_shapes_False)
95+
from __future__ import annotations
96+
97+
import torch
98+
import triton
99+
import triton.language as tl
100+
from helion.runtime import default_launcher as _default_launcher
101+
102+
@triton.jit
103+
def _kernel_kernel(a, out, a_size_0, a_stride_0, out_stride_0, _BLOCK_SIZE_0: tl.constexpr):
104+
num_blocks_0 = tl.cdiv(a_size_0, _BLOCK_SIZE_0)
105+
pid_0 = tl.program_id(0) % num_blocks_0
106+
offset_0 = pid_0 * _BLOCK_SIZE_0
107+
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
108+
mask_0 = indices_0 < a_size_0
109+
load = tl.load(a + indices_0[:, None] * a_stride_0, mask_0[:, None], other=0)
110+
load_1 = tl.load(a + indices_0[:, None] * a_stride_0, mask_0[:, None], other=0)
111+
v_0 = load + load_1
112+
tl.store(out + indices_0[:, None] * out_stride_0, v_0, mask_0[:, None])
113+
114+
def kernel(a: torch.Tensor, b: torch.Tensor, *, _launcher=_default_launcher):
115+
assert a.size() == b.size()
116+
out = torch.empty_like(a)
117+
_BLOCK_SIZE_0 = 16
118+
_launcher(_kernel_kernel, (triton.cdiv(a.size(0), _BLOCK_SIZE_0) * 1,), a, out, a.size(0), a.stride(0), out.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
119+
return out
120+
121+
--- assertExpectedJournal(TestMisc.test_sequence_assert_static_shapes_True)
122+
from __future__ import annotations
123+
124+
import torch
125+
import triton
126+
import triton.language as tl
127+
from helion.runtime import default_launcher as _default_launcher
128+
129+
@triton.jit
130+
def _kernel_kernel(a, b, out, _BLOCK_SIZE_0: tl.constexpr):
131+
num_blocks_0 = tl.cdiv(16, _BLOCK_SIZE_0)
132+
pid_0 = tl.program_id(0) % num_blocks_0
133+
offset_0 = pid_0 * _BLOCK_SIZE_0
134+
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
135+
load = tl.load(a + indices_0[:, None] * 1, None)
136+
load_1 = tl.load(b + indices_0[:, None] * 1, None)
137+
v_0 = load + load_1
138+
tl.store(out + indices_0[:, None] * 1, v_0, None)
139+
140+
def kernel(a: torch.Tensor, b: torch.Tensor, *, _launcher=_default_launcher):
141+
assert a.size() == b.size()
142+
out = torch.empty_like(a)
143+
_BLOCK_SIZE_0 = 16
144+
_launcher(_kernel_kernel, (triton.cdiv(16, _BLOCK_SIZE_0) * 1,), a, b, out, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
145+
return out
146+
94147
--- assertExpectedJournal(TestMisc.test_tile_block_size_constexpr_fix)
95148
from __future__ import annotations
96149

test/test_misc.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from packaging import version
88
import pytest
99
import torch
10+
from torch.testing._internal.common_utils import instantiate_parametrized_tests
11+
from torch.testing._internal.common_utils import parametrize
1012

1113
import helion
1214
from helion._compat import supports_tensor_descriptor
@@ -415,6 +417,25 @@ def copy_kernel(a: torch.Tensor) -> torch.Tensor:
415417
torch.testing.assert_close(result, args[0])
416418
self.assertExpectedJournal(code)
417419

420+
@parametrize("static_shapes", (True, False))
421+
def test_sequence_assert(self, static_shapes):
422+
@helion.kernel(static_shapes=static_shapes)
423+
def kernel(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
424+
assert a.size() == b.size()
425+
out = torch.empty_like(a)
426+
427+
for tile in hl.tile(a.size()):
428+
out[tile] = a[tile] + b[tile]
429+
return out
430+
431+
a = torch.randn(16, 1, device=DEVICE)
432+
code, result = code_and_output(kernel, (a, a))
433+
torch.testing.assert_close(result, a + a)
434+
self.assertExpectedJournal(code)
435+
436+
437+
instantiate_parametrized_tests(TestMisc)
438+
418439

419440
if __name__ == "__main__":
420441
unittest.main()

0 commit comments

Comments
 (0)