Skip to content

Commit 607c50c

Browse files
authored
[FRONTEND] Fix __eq__ and __add__ methods in tl.tuple (#6857)
* Added two `test_eq` and `test_add` * Introduced a `_normalize_tuple` function to standardize tuple inputs. Otherwise, if `other` is a constexpr, both `__eq__` and `__add__` methods will fail * Removed redundant imports of `builtins` * Fixed a tuple construction issue in the `flip` function that previously requires the `core.tuple([2] * steps)` workaround
1 parent da83833 commit 607c50c

File tree

3 files changed

+46
-10
lines changed

3 files changed

+46
-10
lines changed

python/test/unit/language/test_tuple.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,3 +162,38 @@ def mul(x, a):
162162
ty = Tensor(y, y.shape, y.stride())
163163
_namedtuple_kernel[(1, )](function, tx, ty, 64, 64)
164164
assert torch.allclose(y, x[:16, :16] * a)
165+
166+
167+
@pytest.mark.interpreter
168+
def test_eq(device):
169+
170+
@triton.jit
171+
def fn(ret_ptrs):
172+
tl.store(ret_ptrs + 0, (1, 2) == (1, 2))
173+
tl.store(ret_ptrs + 1, (1, 2) == (1, 1))
174+
tl.store(ret_ptrs + 2, tl.tuple((1, 2)) == (1, 2))
175+
tl.store(ret_ptrs + 3, tl.tuple((1, 2)) == (1, 3))
176+
177+
rets = torch.zeros((4, ), dtype=torch.int32, device=device)
178+
fn[(1, )](rets)
179+
assert rets[0].item() == 1
180+
assert rets[1].item() == 0
181+
assert rets[2].item() == 1
182+
assert rets[3].item() == 0
183+
184+
185+
@pytest.mark.interpreter
186+
def test_add(device):
187+
188+
@triton.jit
189+
def fn(ret_ptrs):
190+
tuple0 = ((0, 1)) + (2, 3)
191+
for i in tl.static_range(4):
192+
tl.store(ret_ptrs + i, tuple0[i])
193+
tuple1 = tl.tuple((4, 5)) + (6, 7)
194+
for i in tl.static_range(4):
195+
tl.store(ret_ptrs + 4 + i, tuple1[i])
196+
197+
rets = torch.zeros((8, ), dtype=torch.int32, device=device)
198+
fn[(1, )](rets)
199+
torch.testing.assert_close(rets.cpu(), torch.arange(8, dtype=torch.int32))

python/triton/language/core.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,13 @@ def _unwrap_if_constexpr(o):
306306
return o.value if isinstance(o, constexpr) else o
307307

308308

309+
def _normalize_tuple(t):
310+
normalized_tuple = _unwrap_if_constexpr(t)
311+
if isinstance(normalized_tuple, (list, builtins.tuple)):
312+
normalized_tuple = tuple(normalized_tuple)
313+
return normalized_tuple
314+
315+
309316
def check_bit_width(value, shift_value):
310317
if isinstance(value, tensor) and isinstance(shift_value, constexpr):
311318
bitwidth = value.type.scalar.primitive_bitwidth
@@ -1069,7 +1076,6 @@ def __not__(self, _builder=None):
10691076

10701077
@builtin
10711078
def __getitem__(self, slices, _builder=None):
1072-
import builtins
10731079
if isinstance(slices, (builtins.slice, slice, constexpr)) or slices is None:
10741080
slices = [slices]
10751081
if isinstance(slices, tuple):
@@ -1237,7 +1243,7 @@ def flip(self, dim=None) -> tensor:
12371243

12381244
class tuple(base_value):
12391245

1240-
def __init__(self, args: list, type: tuple_type = None):
1246+
def __init__(self, args: Sequence, type: tuple_type = None):
12411247
self.values = [i for i in args]
12421248

12431249
def get_type(x):
@@ -1255,7 +1261,6 @@ def __getitem__(self, idx: constexpr):
12551261
if isinstance(idx, constexpr):
12561262
return self.values[idx]
12571263
else:
1258-
import builtins
12591264
assert isinstance(idx, (slice, builtins.slice))
12601265
return tuple(self.values[idx.start:idx.stop:idx.step])
12611266

@@ -1270,8 +1275,7 @@ def __setitem__(self, idx: constexpr, value):
12701275
self.values[idx] = value
12711276

12721277
def __add__(self, other):
1273-
if isinstance(other, list):
1274-
other = tuple(other)
1278+
other = _normalize_tuple(other)
12751279
return tuple(self.values + other.values)
12761280
# return tuple(a + b for a, b in zip(self.values, other.values))
12771281

@@ -1280,13 +1284,10 @@ def __mul__(self, other):
12801284
return tuple(self.values * other.value)
12811285

12821286
def __eq__(self, other):
1283-
import builtins
1284-
if isinstance(other, (list, builtins.tuple)):
1285-
other = tuple(other)
1287+
other = _normalize_tuple(other)
12861288
return constexpr(self.values == other.values)
12871289

12881290
def __hash__(self):
1289-
import builtins
12901291
return hash(builtins.tuple(self.values))
12911292

12921293
def __str__(self):

python/triton/language/standard.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -498,7 +498,7 @@ def flip(x, dim=None):
498498

499499
# reshape the swap dimension to (2, 2, ..., 2)
500500
idtype = core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)
501-
y = core.reshape(x.to(idtype, bitcast=True), x.shape[:_dim] + core.tuple([2] * steps) + x.shape[_dim + 1:])
501+
y = core.reshape(x.to(idtype, bitcast=True), x.shape[:_dim] + [2] * steps + x.shape[_dim + 1:])
502502
for i in core.static_range(steps):
503503
y = y ^ xor_sum(y, _dim + i, True)
504504
x = core.reshape(y, x.shape).to(x.dtype, bitcast=True)

0 commit comments

Comments
 (0)