77
88
99@triton .jit
10- def block_copy_kernel (a_ptr , b_ptr , N , BLOCK_SIZE : tl .constexpr , padding_option : tl .constexpr ):
10+ def block_copy_kernel (a_ptr , b_ptr , N , BLOCK_SIZE : tl .constexpr , PADDING_OPTION : tl .constexpr ,
11+ TEST_LOWER_BOUND : tl .constexpr , TEST_UPPER_BOUND : tl .constexpr ):
1112 pid = tl .program_id (0 )
13+ offset = pid * BLOCK_SIZE
14+ if TEST_LOWER_BOUND :
15+ offset = - N
16+ elif TEST_UPPER_BOUND :
17+ offset = N
1218 # We only copy half of the data to see if the padding works
13- a_block_ptr = tl .make_block_ptr (base = a_ptr , shape = (N // 2 , ), strides = (1 , ), offsets = (pid * BLOCK_SIZE , ),
19+ a_block_ptr = tl .make_block_ptr (base = a_ptr , shape = (N // 2 , ), strides = (1 , ), offsets = (offset , ),
1420 block_shape = (BLOCK_SIZE , ), order = (0 , ))
15- b_block_ptr = tl .make_block_ptr (base = b_ptr , shape = (N , ), strides = (1 , ), offsets = (pid * BLOCK_SIZE , ),
21+ b_block_ptr = tl .make_block_ptr (base = b_ptr , shape = (N , ), strides = (1 , ), offsets = (offset , ),
1622 block_shape = (BLOCK_SIZE , ), order = (0 , ))
17- if padding_option is None :
23+ if PADDING_OPTION is None :
1824 a = tl .load (a_block_ptr , boundary_check = (0 , ))
1925 else :
20- a = tl .load (a_block_ptr , boundary_check = (0 , ), padding_option = padding_option )
26+ a = tl .load (a_block_ptr , boundary_check = (0 , ), padding_option = PADDING_OPTION )
2127 tl .store (b_block_ptr , a , boundary_check = (0 , ))
2228
2329
2430@pytest .mark .interpreter
25- @pytest .mark .parametrize ("dtypes_str, n, padding_option" , [ #
26- (dtypes_str , n , padding )
31+ @pytest .mark .parametrize ("dtypes_str, n, padding_option, boundary_check " , [ #
32+ (dtypes_str , n , padding , boundary_check ) #
2733 for dtypes_str in (("bool" , "bool" ), ("int16" , "int16" ), ("int32" , "int32" ), ("float16" , "float16" ),
2834 ("float32" , "float32" ), ("bfloat16" , "bfloat16" ))
2935 for n in (64 , 128 , 256 , 512 , 1024 )
3036 for padding in (None , "zero" , "nan" ) #
37+ for boundary_check in (None , "lower" , "upper" )
3138])
32- def test_block_copy (dtypes_str , n , padding_option , device ):
39+ def test_block_copy (dtypes_str , n , padding_option , boundary_check , device ):
3340 src_dtype_str = dtypes_str [0 ]
3441 dst_dtype_str = dtypes_str [1 ]
3542 src_dtype = getattr (torch , src_dtype_str )
@@ -45,13 +52,17 @@ def test_block_copy(dtypes_str, n, padding_option, device):
4552 b = torch .zeros ((n , ), device = device , dtype = dst_dtype )
4653
4754 grid = lambda meta : (triton .cdiv (n , meta ["BLOCK_SIZE" ]), )
48- block_copy_kernel [grid ](a_ptr = a , b_ptr = b , N = n , BLOCK_SIZE = 64 , padding_option = padding_option )
55+ block_copy_kernel [grid ](a_ptr = a , b_ptr = b , N = n , BLOCK_SIZE = 64 , PADDING_OPTION = padding_option ,
56+ TEST_LOWER_BOUND = boundary_check == "lower" , TEST_UPPER_BOUND = boundary_check == "upper" )
4957 a .to (dst_dtype )
50- assert torch .all (a [0 :n // 2 ] == b [0 :n // 2 ])
51- if padding_option == "zero" :
52- assert torch .all (b [n // 2 :n ] == 0 )
53- elif padding_option == "nan" :
54- assert torch .all (torch .isnan (b [n // 2 :n ]))
58+ if (boundary_check == "lower" ) or (boundary_check == "upper" ):
59+ assert torch .all (b == 0 )
60+ else :
61+ assert torch .all (a [0 :n // 2 ] == b [0 :n // 2 ])
62+ if padding_option == "zero" :
63+ assert torch .all (b [n // 2 :n ] == 0 )
64+ elif padding_option == "nan" :
65+ assert torch .all (torch .isnan (b [n // 2 :n ]))
5566
5667
5768@triton .jit
0 commit comments