 
 
 @triton.jit
-def block_copy_kernel(a_ptr, b_ptr, N, BLOCK_SIZE: tl.constexpr):
+def block_copy_kernel(a_ptr, b_ptr, N, BLOCK_SIZE: tl.constexpr, padding_option: tl.constexpr):
     pid = tl.program_id(0)
     # We only copy half of the data to see if the padding works
     a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(N // 2, ), strides=(1, ), offsets=(pid * BLOCK_SIZE, ),
                                     block_shape=(BLOCK_SIZE, ), order=(0, ))
     b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(N, ), strides=(1, ), offsets=(pid * BLOCK_SIZE, ),
                                     block_shape=(BLOCK_SIZE, ), order=(0, ))
-    # if padding_option is None:
-    a = tl.load(a_block_ptr, boundary_check=(0, ))
-    # else:
-    #     a = tl.load(a_block_ptr, boundary_check=(0, ), padding_option=padding_option)
+    if padding_option is None:
+        a = tl.load(a_block_ptr, boundary_check=(0, ))
+    else:
+        a = tl.load(a_block_ptr, boundary_check=(0, ), padding_option=padding_option)
     tl.store(b_block_ptr, a, boundary_check=(0, ))
 
 
 @pytest.mark.interpreter
-@pytest.mark.parametrize("dtypes_str, n", [  #
-    (dtypes_str, n)
-    # for dtypes_str in (("bool", "bool"), ("int16", "int16"), ("int32", "int32"), ("float16", "float16"),
-    #                    ("float32", "float32"), ("bfloat16", "bfloat16"))
-    for dtypes_str in [("float16", "float16")]
-    for n in [64]
+@pytest.mark.parametrize("dtypes_str, n, padding_option", [  #
+    (dtypes_str, n, padding)
+    for dtypes_str in (("bool", "bool"), ("int16", "int16"), ("int32", "int32"), ("float16", "float16"),
+                       ("float32", "float32"), ("bfloat16", "bfloat16"))
+    for n in (64, 128, 256, 512, 1024)
+    for padding in (None, "zero", "nan")  #
 ])
-def test_block_copy(dtypes_str, n, device):
+def test_block_copy(dtypes_str, n, padding_option, device):
     src_dtype_str = dtypes_str[0]
     dst_dtype_str = dtypes_str[1]
     src_dtype = getattr(torch, src_dtype_str)
     dst_dtype = getattr(torch, dst_dtype_str)
     check_type_supported(src_dtype, device)
     check_type_supported(dst_dtype, device)
     if src_dtype_str in ("bool", "int16", "int32"):
-        # if padding_option == "nan":
-        #     pytest.xfail("Padding with NaN is not supported for integer types")
+        if padding_option == "nan":
+            pytest.xfail("Padding with NaN is not supported for integer types")
         a = torch.randint(0, 2, (n, ), device=device, dtype=src_dtype)
     else:
         a = torch.randn((n, ), device=device, dtype=src_dtype)
     b = torch.zeros((n, ), device=device, dtype=dst_dtype)
 
     grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]), )
-    block_copy_kernel[grid](a_ptr=a, b_ptr=b, N=n, BLOCK_SIZE=64)
+    block_copy_kernel[grid](a_ptr=a, b_ptr=b, N=n, BLOCK_SIZE=64, padding_option=padding_option)
     a.to(dst_dtype)
     assert torch.all(a[0:n // 2] == b[0:n // 2])
-
-
-    # if padding_option == "zero":
-    #     assert torch.all(b[n // 2:n] == 0)
-    # elif padding_option == "nan":
-    #     assert torch.all(torch.isnan(b[n // 2:n]))
+    if padding_option == "zero":
+        assert torch.all(b[n // 2:n] == 0)
+    elif padding_option == "nan":
+        assert torch.all(torch.isnan(b[n // 2:n]))
 
 
 @triton.jit
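
For reference, a minimal standalone sketch of the behaviour the re-enabled parametrization exercises: a block-pointer load whose out-of-bounds tail is filled according to padding_option. The kernel and tensor names below are hypothetical, the snippet assumes a CUDA-capable device, and it is not part of this change.

# Minimal sketch, not part of this diff: hypothetical names, assumes a CUDA device.
import torch
import triton
import triton.language as tl


@triton.jit
def padded_copy_demo(x_ptr, out_ptr, N, BLOCK: tl.constexpr):
    # The source block pointer only covers the first half of x, so the upper
    # half of the block falls out of bounds and is filled per padding_option.
    x_bp = tl.make_block_ptr(base=x_ptr, shape=(N // 2, ), strides=(1, ), offsets=(0, ),
                             block_shape=(BLOCK, ), order=(0, ))
    out_bp = tl.make_block_ptr(base=out_ptr, shape=(N, ), strides=(1, ), offsets=(0, ),
                               block_shape=(BLOCK, ), order=(0, ))
    x = tl.load(x_bp, boundary_check=(0, ), padding_option="nan")
    tl.store(out_bp, x, boundary_check=(0, ))


x = torch.randn(64, device="cuda")
out = torch.zeros_like(x)
padded_copy_demo[(1, )](x, out, 64, BLOCK=64)
assert torch.equal(out[:32], x[:32])  # in-bounds half copied verbatim
assert torch.isnan(out[32:]).all()  # out-of-bounds half padded with NaN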
|