Commit 0b21a82

Address code review comments
Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent 17f5b25 commit 0b21a82

1 file changed

python/test/unit/language/test_block_pointer.py

Lines changed: 19 additions & 21 deletions
@@ -7,53 +7,51 @@
 
 
 @triton.jit
-def block_copy_kernel(a_ptr, b_ptr, N, BLOCK_SIZE: tl.constexpr):
+def block_copy_kernel(a_ptr, b_ptr, N, BLOCK_SIZE: tl.constexpr, padding_option: tl.constexpr):
     pid = tl.program_id(0)
     # We only copy half of the data to see if the padding works
     a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(N // 2, ), strides=(1, ), offsets=(pid * BLOCK_SIZE, ),
                                     block_shape=(BLOCK_SIZE, ), order=(0, ))
     b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(N, ), strides=(1, ), offsets=(pid * BLOCK_SIZE, ),
                                     block_shape=(BLOCK_SIZE, ), order=(0, ))
-    # if padding_option is None:
-    a = tl.load(a_block_ptr, boundary_check=(0, ))
-    # else:
-    #     a = tl.load(a_block_ptr, boundary_check=(0, ), padding_option=padding_option)
+    if padding_option is None:
+        a = tl.load(a_block_ptr, boundary_check=(0, ))
+    else:
+        a = tl.load(a_block_ptr, boundary_check=(0, ), padding_option=padding_option)
     tl.store(b_block_ptr, a, boundary_check=(0, ))
 
 
 @pytest.mark.interpreter
-@pytest.mark.parametrize("dtypes_str, n", [ #
-    (dtypes_str, n)
-    # for dtypes_str in (("bool", "bool"), ("int16", "int16"), ("int32", "int32"), ("float16", "float16"),
-    #                    ("float32", "float32"), ("bfloat16", "bfloat16"))
-    for dtypes_str in [("float16", "float16")]
-    for n in [64]
+@pytest.mark.parametrize("dtypes_str, n, padding_option", [ #
+    (dtypes_str, n, padding)
+    for dtypes_str in (("bool", "bool"), ("int16", "int16"), ("int32", "int32"), ("float16", "float16"),
+                       ("float32", "float32"), ("bfloat16", "bfloat16"))
+    for n in (64, 128, 256, 512, 1024)
+    for padding in (None, "zero", "nan") #
 ])
-def test_block_copy(dtypes_str, n, device):
+def test_block_copy(dtypes_str, n, padding_option, device):
     src_dtype_str = dtypes_str[0]
     dst_dtype_str = dtypes_str[1]
     src_dtype = getattr(torch, src_dtype_str)
     dst_dtype = getattr(torch, dst_dtype_str)
     check_type_supported(src_dtype, device)
     check_type_supported(dst_dtype, device)
     if src_dtype_str in ("bool", "int16", "int32"):
-        # if padding_option == "nan":
-        #     pytest.xfail("Padding with NaN is not supported for integer types")
+        if padding_option == "nan":
+            pytest.xfail("Padding with NaN is not supported for integer types")
         a = torch.randint(0, 2, (n, ), device=device, dtype=src_dtype)
     else:
         a = torch.randn((n, ), device=device, dtype=src_dtype)
     b = torch.zeros((n, ), device=device, dtype=dst_dtype)
 
     grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]), )
-    block_copy_kernel[grid](a_ptr=a, b_ptr=b, N=n, BLOCK_SIZE=64)
+    block_copy_kernel[grid](a_ptr=a, b_ptr=b, N=n, BLOCK_SIZE=64, padding_option=padding_option)
     a.to(dst_dtype)
     assert torch.all(a[0:n // 2] == b[0:n // 2])
-
-
-    # if padding_option == "zero":
-    #     assert torch.all(b[n // 2:n] == 0)
-    # elif padding_option == "nan":
-    #     assert torch.all(torch.isnan(b[n // 2:n]))
+    if padding_option == "zero":
+        assert torch.all(b[n // 2:n] == 0)
+    elif padding_option == "nan":
+        assert torch.all(torch.isnan(b[n // 2:n]))
 
 
 @triton.jit
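
For reference, a minimal standalone sketch of the behaviour the updated test exercises: a boundary-checked load through a block pointer whose logical shape covers only the first half of the buffer, so the out-of-range lanes are filled according to padding_option. The kernel mirrors block_copy_kernel above, but the kernel name, the single-program launch, n = 64, float32 tensors, and the "cuda" device are illustrative assumptions, not part of this commit.

import torch
import triton
import triton.language as tl


@triton.jit
def padded_copy_kernel(src_ptr, dst_ptr, N, BLOCK: tl.constexpr, padding: tl.constexpr):
    # The source block pointer spans only N // 2 elements, so the upper half of
    # the block is out of bounds and is filled with the requested padding on load.
    src = tl.make_block_ptr(base=src_ptr, shape=(N // 2, ), strides=(1, ), offsets=(0, ),
                            block_shape=(BLOCK, ), order=(0, ))
    dst = tl.make_block_ptr(base=dst_ptr, shape=(N, ), strides=(1, ), offsets=(0, ),
                            block_shape=(BLOCK, ), order=(0, ))
    x = tl.load(src, boundary_check=(0, ), padding_option=padding)
    tl.store(dst, x, boundary_check=(0, ))


n = 64
a = torch.randn(n, device="cuda", dtype=torch.float32)
b = torch.zeros(n, device="cuda", dtype=torch.float32)
padded_copy_kernel[(1, )](a, b, n, BLOCK=64, padding="nan")
assert torch.all(a[:n // 2] == b[:n // 2])   # copied half matches the source
assert torch.all(torch.isnan(b[n // 2:]))    # padded half is NaN

With padding="zero" the same launch leaves the upper half of b at zero instead, which is what the new assertions in test_block_copy verify for each parametrized padding_option.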
