Commit 206c410

[INTERPRETER] Fix lower bound check for block pointers (#5201)
We previously forgot to check `offset >= 0`. With this fix, the interpreter matches the semantics of the GPU backend: https://github.com/triton-lang/triton/blob/7bce3613755e26953518962d02315dfd343dc50c/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp#L136
1 parent d5d878f · commit 206c410
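For illustration, here is a minimal NumPy sketch of the one-dimensional mask logic this commit fixes. It is not the interpreter's actual code; the helper `boundary_mask` and its parameter names are made up for this example:

```python
import numpy as np

# Hypothetical helper, for illustration only: compute the validity mask
# for one dimension of a block pointer. `offset` is the block's starting
# offset, `block_shape` its extent, and `shape` the tensor's size along
# that dimension.
def boundary_mask(offset, block_shape, shape):
    off = offset + np.arange(block_shape)
    # Before this commit the interpreter only checked the upper bound
    # (off < shape), so a negative offset slipped through unmasked.
    return (off < shape) & (off >= 0)

print(boundary_mask(-4, 8, 16))  # first 4 lanes below the lower bound
print(boundary_mask(12, 8, 16))  # last 4 lanes past the upper bound
```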

File tree: 2 files changed, +26 −15 lines

python/test/unit/language/test_block_pointer.py

Lines changed: 25 additions & 14 deletions
@@ -7,29 +7,36 @@


 @triton.jit
-def block_copy_kernel(a_ptr, b_ptr, N, BLOCK_SIZE: tl.constexpr, padding_option: tl.constexpr):
+def block_copy_kernel(a_ptr, b_ptr, N, BLOCK_SIZE: tl.constexpr, PADDING_OPTION: tl.constexpr,
+                      TEST_LOWER_BOUND: tl.constexpr, TEST_UPPER_BOUND: tl.constexpr):
     pid = tl.program_id(0)
+    offset = pid * BLOCK_SIZE
+    if TEST_LOWER_BOUND:
+        offset = -N
+    elif TEST_UPPER_BOUND:
+        offset = N
     # We only copy half of the data to see if the padding works
-    a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(N // 2, ), strides=(1, ), offsets=(pid * BLOCK_SIZE, ),
+    a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(N // 2, ), strides=(1, ), offsets=(offset, ),
                                     block_shape=(BLOCK_SIZE, ), order=(0, ))
-    b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(N, ), strides=(1, ), offsets=(pid * BLOCK_SIZE, ),
+    b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(N, ), strides=(1, ), offsets=(offset, ),
                                     block_shape=(BLOCK_SIZE, ), order=(0, ))
-    if padding_option is None:
+    if PADDING_OPTION is None:
         a = tl.load(a_block_ptr, boundary_check=(0, ))
     else:
-        a = tl.load(a_block_ptr, boundary_check=(0, ), padding_option=padding_option)
+        a = tl.load(a_block_ptr, boundary_check=(0, ), padding_option=PADDING_OPTION)
     tl.store(b_block_ptr, a, boundary_check=(0, ))


 @pytest.mark.interpreter
-@pytest.mark.parametrize("dtypes_str, n, padding_option", [ #
-    (dtypes_str, n, padding)
+@pytest.mark.parametrize("dtypes_str, n, padding_option, boundary_check", [ #
+    (dtypes_str, n, padding, boundary_check) #
     for dtypes_str in (("bool", "bool"), ("int16", "int16"), ("int32", "int32"), ("float16", "float16"),
                        ("float32", "float32"), ("bfloat16", "bfloat16"))
     for n in (64, 128, 256, 512, 1024)
     for padding in (None, "zero", "nan") #
+    for boundary_check in (None, "lower", "upper")
 ])
-def test_block_copy(dtypes_str, n, padding_option, device):
+def test_block_copy(dtypes_str, n, padding_option, boundary_check, device):
     src_dtype_str = dtypes_str[0]
     dst_dtype_str = dtypes_str[1]
     src_dtype = getattr(torch, src_dtype_str)
@@ -45,13 +52,17 @@ def test_block_copy(dtypes_str, n, padding_option, device):
     b = torch.zeros((n, ), device=device, dtype=dst_dtype)

     grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]), )
-    block_copy_kernel[grid](a_ptr=a, b_ptr=b, N=n, BLOCK_SIZE=64, padding_option=padding_option)
+    block_copy_kernel[grid](a_ptr=a, b_ptr=b, N=n, BLOCK_SIZE=64, PADDING_OPTION=padding_option,
+                            TEST_LOWER_BOUND=boundary_check == "lower", TEST_UPPER_BOUND=boundary_check == "upper")
     a.to(dst_dtype)
-    assert torch.all(a[0:n // 2] == b[0:n // 2])
-    if padding_option == "zero":
-        assert torch.all(b[n // 2:n] == 0)
-    elif padding_option == "nan":
-        assert torch.all(torch.isnan(b[n // 2:n]))
+    if (boundary_check == "lower") or (boundary_check == "upper"):
+        assert torch.all(b == 0)
+    else:
+        assert torch.all(a[0:n // 2] == b[0:n // 2])
+        if padding_option == "zero":
+            assert torch.all(b[n // 2:n] == 0)
+        elif padding_option == "nan":
+            assert torch.all(torch.isnan(b[n // 2:n]))


 @triton.jit
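To see why the new assertions hold, consider the `TEST_LOWER_BOUND` case: the kernel sets `offset = -N`, so with `BLOCK_SIZE <= N` every element offset in the block is negative, the whole load/store is masked out, and `b` keeps its zero initialization. A small self-contained check of that arithmetic (plain NumPy; the values match the test's smallest configuration):

```python
import numpy as np

# Smallest test configuration: n = 64, BLOCK_SIZE = 64.
N, BLOCK_SIZE = 64, 64

# Element offsets produced by the TEST_LOWER_BOUND path (offset = -N).
off = -N + np.arange(BLOCK_SIZE)
assert np.all(off < 0)  # entirely below the lower bound

# Element offsets produced by the TEST_UPPER_BOUND path (offset = N).
off = N + np.arange(BLOCK_SIZE)
assert np.all(off >= N)  # entirely past the upper bound
```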

python/triton/runtime/interpreter.py

Lines changed: 1 addition & 1 deletion
@@ -65,7 +65,7 @@ def materialize_pointers(self, boundary_check):
             off = (self.offsets[dim].data + np.arange(self.block_shape[dim])).reshape(bcast_dims)
             ptrs = ptrs + (n_bytes * off * self.strides[dim].data).astype(np.uint64)
             if dim in boundary_check:
-                masks = np.logical_and(masks, off < self.shape[dim].data)
+                masks = masks & (off < self.shape[dim].data) & (off >= 0)
         ptrs = TensorHandle(ptrs, self.base.dtype.scalar)
         return ptrs, masks
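As a sanity check of the one-line fix, here is a standalone NumPy sketch of how the two-sided mask keeps a load with a negative offset from reading out of range. The helper `masked_load` and its padding behavior are illustrative, not the interpreter's API:

```python
import numpy as np

# Illustrative helper, not the interpreter's API: load `data` at the
# element offsets `off`, masking anything outside [0, len(data)).
def masked_load(data, off, padding_value=0.0):
    mask = (off >= 0) & (off < data.shape[0])  # the fixed, two-sided check
    out = np.full(off.shape, padding_value, dtype=data.dtype)
    out[mask] = data[off[mask]]  # boolean indexing never touches off < 0
    return out

data = np.arange(16, dtype=np.float32)
print(masked_load(data, np.arange(-4, 4)))  # [0. 0. 0. 0. 0. 1. 2. 3.]
```

Note that in this sketch, dropping the `off >= 0` term would make `data[off]` wrap around via Python's negative indexing and silently return values from the end of the array instead of padding, which mirrors the kind of lower-bound bug the commit fixes.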
