Skip to content

Commit 07d87cf

Browse files
Jokeren and guacamoleo
authored and committed
[INTERPRETER] Correct None tensor check logic (triton-lang#5049)
In the interpreter mode, we cannot use `not tensor` to check if `tensor` is None or not because the interpreter directly evaluates the tensor. Also consolidated the test cases for `tl.store`.
1 parent 71145a6 commit 07d87cf

File tree

2 files changed

+18
-31
lines changed

2 files changed

+18
-31
lines changed

python/test/unit/language/test_core.py

Lines changed: 16 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1740,47 +1740,34 @@ def kernel(X, Y, Z, N: tl.constexpr):
17401740

17411741
@pytest.mark.interpreter
17421742
@pytest.mark.parametrize("dtype_str", list(torch_dtypes))
1743+
@pytest.mark.parametrize("constant_field", ["value", "mask"])
17431744
@pytest.mark.parametrize("num_ctas", num_ctas_list)
1744-
def test_store_constant(dtype_str, num_ctas, device):
1745+
def test_store_constant(num_ctas, dtype_str, constant_field, device):
17451746
check_type_supported(dtype_str, device)
1746-
"""Tests that boolean True is stored as 1"""
17471747

17481748
@triton.jit
1749-
def kernel(output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
1749+
def kernel(output_ptr, n_elements, BLOCK_SIZE: tl.constexpr, CONSTANT_FIELD: tl.constexpr):
17501750
offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
1751-
mask = offsets < n_elements
1752-
output = GENERATE_TEST_HERE
1751+
if CONSTANT_FIELD == "value":
1752+
value = 1
1753+
output = tl.full([BLOCK_SIZE], value=value, dtype=value.dtype)
1754+
mask = offsets < n_elements
1755+
elif CONSTANT_FIELD == "mask":
1756+
output = offsets < n_elements
1757+
mask = False
17531758
tl.store(output_ptr + offsets, output, mask=mask)
17541759

1755-
triton_dtype_str = 'uint8' if dtype_str == 'bool' else dtype_str
1756-
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.zeros([BLOCK_SIZE], dtype=tl.{triton_dtype_str}) + 1'})
17571760
block_size = 128
17581761
ref = torch.ones([block_size], dtype=getattr(torch, dtype_str), device=device)
17591762
output = torch.zeros([block_size], dtype=getattr(torch, dtype_str), device=device)
1760-
kernel[(1, )](output, block_size, BLOCK_SIZE=block_size, num_ctas=num_ctas)
1761-
1762-
assert torch.all(output == ref)
17631763

1764+
kernel[(1, )](output, block_size, BLOCK_SIZE=block_size, num_ctas=num_ctas, CONSTANT_FIELD=constant_field)
17641765

1765-
@pytest.mark.interpreter
1766-
@pytest.mark.parametrize("num_ctas", num_ctas_list)
1767-
def test_store_constant_default_dtype(num_ctas, device):
1768-
"""Tests that boolean True is stored as 1"""
1769-
1770-
@triton.jit
1771-
def kernel(output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
1772-
offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
1773-
mask = offsets < n_elements
1774-
value = 1
1775-
output = tl.full([BLOCK_SIZE], value=value, dtype=value.dtype)
1776-
tl.store(output_ptr + offsets, output, mask=mask)
1777-
1778-
block_size = 128
1779-
ref = torch.ones([block_size], dtype=getattr(torch, 'int32'), device=device)
1780-
output = torch.zeros([block_size], dtype=getattr(torch, 'int32'), device=device)
1781-
kernel[(1, )](output, block_size, BLOCK_SIZE=block_size, num_ctas=num_ctas)
1782-
1783-
assert torch.all(output == ref)
1766+
if constant_field == "value":
1767+
print(output, ref)
1768+
assert torch.all(output == ref)
1769+
else:
1770+
assert torch.all(output == 0)
17841771

17851772

17861773
def test_load_store_same_ptr(device):

python/triton/language/semantic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1253,7 +1253,7 @@ def _store_legacy(ptr, val, mask, boundary_check, cache, eviction, builder):
12531253
val = cast(val, elt_ty, builder)
12541254

12551255
# Build IR
1256-
if not mask:
1256+
if mask is None:
12571257
return tl.tensor(builder.create_store(ptr.handle, val.handle, cache, eviction), tl.void)
12581258
if not mask.type.scalar.is_bool():
12591259
raise ValueError("Mask must have boolean scalar type")
@@ -1308,7 +1308,7 @@ def atom_red_typechecking_impl(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor,
13081308
if val is not None:
13091309
val = broadcast_impl_shape(val, ptr.type.get_block_shapes(), builder)
13101310
val = cast(val, ptr.type.scalar.element_ty, builder)
1311-
if not mask:
1311+
if mask is None:
13121312
mask_ir = builder.get_int1(True)
13131313
mask_ty = tl.int1
13141314
if ptr.type.is_block():

0 commit comments

Comments (0)