@@ -12,8 +12,11 @@ def patch_kernel(template, to_replace):
     for key, value in to_replace.items():
         kernel.src = kernel.src.replace(key, value)
     return kernel
-
-def test_divide(device):
+
+@pytest.mark.parametrize("float_div", [True, False])
+@pytest.mark.parametrize("floor", [True, False])
+@pytest.mark.parametrize("trunc", [True, False])
+def test_divide(float_div, floor, trunc, device):
     # regression test for various division cases

     @triton.jit
@@ -30,36 +33,38 @@ def divide_kernel(a, b, out_ptr0, out_ptr1, out_ptr2, out_ptr3, out_ptr4, xnumel
         tmp4 = tmp1 / tmp3
         tmp5 = tl.where((tmp0 < 0) != (tmp2 < 0), tl.where(tmp0 % tmp2 != 0, tmp0 // tmp2 - 1, tmp0 // tmp2), tmp0 // tmp2)
         tmp6 = tmp0 // tmp2
-        tl.store(out_ptr0 + (x0), tmp4, xmask)
-        tl.store(out_ptr1 + (x0), tmp5, xmask)
-        tl.store(out_ptr2 + (x0), tmp6, xmask)
-        tl.store(out_ptr3 + (x0), tmp4, xmask)
-        tl.store(out_ptr4 + (x0), tmp5, xmask)
+        GENERATE_OUTPUTS_HERE

     torch.manual_seed(0)

+    outputs_float_div = "tl.store(out_ptr0 + (x0), tmp4, xmask)\n        tl.store(out_ptr3 + (x0), tmp4, xmask)" if float_div is True else ""
+    outputs_floor = "\n        tl.store(out_ptr1 + (x0), tmp5, xmask)\n        tl.store(out_ptr4 + (x0), tmp5, xmask)" if floor is True else ""
+    outputs_trunc = "\n        tl.store(out_ptr2 + (x0), tmp6, xmask)" if trunc is True else ""
+
+    divide_kernel = patch_kernel(divide_kernel, {"GENERATE_OUTPUTS_HERE": f"{outputs_float_div}\n{outputs_floor}\n{outputs_trunc}"})
+
     def launch_triton(a, b):
-        output0 = torch.empty_like(a)
-        output1 = torch.empty_like(a)
-        output2 = torch.empty_like(a)
-        output3 = torch.empty_like(a)
-        output4 = torch.empty_like(a)
+        output0 = torch.zeros_like(a)
+        output1 = torch.zeros_like(a)
+        output2 = torch.zeros_like(a)
+        output3 = torch.zeros_like(a)
+        output4 = torch.zeros_like(a)

         n_elements = output0.numel()

         grid = lambda meta: (triton.cdiv(n_elements, meta['XBLOCK']), )
-
+
         divide_kernel[grid](a, b, output0, output1, output2, output3, output4, n_elements, XBLOCK=128)

         return (output0, output1, output2, output3, output4)

     def launch_torch(a, b):
         return (
-            aten.div(a, b, rounding_mode=None),
-            aten.div(a, b, rounding_mode="floor"),
-            aten.div(a, b, rounding_mode="trunc"),
-            a / b,
-            a // b,
+            aten.div(a, b, rounding_mode=None) if float_div is True else torch.zeros_like(a),
+            aten.div(a, b, rounding_mode="floor") if floor is True else torch.zeros_like(a),
+            aten.div(a, b, rounding_mode="trunc") if trunc is True else torch.zeros_like(a),
+            a / b if float_div is True else torch.zeros_like(a),
+            a // b if floor is True else torch.zeros_like(a),
         )

     a = torch.randint(2**32, 2**40, [100, 100], device=device)
@@ -70,6 +75,6 @@ def launch_torch(a, b):
     torch_result = launch_torch(a, b)

     for i in range(5):
-        torch.testing.assert_close(triton_result[i], torch_result[i], check_dtype=False, msg=lambda msg: f"Iteration {iter}, {i} failed\n{msg}")
+        torch.testing.assert_close(triton_result[i], torch_result[i], check_dtype=False, msg=lambda msg: f"Float: {float_div}, Floor: {floor}, Trunc: {trunc}\nIteration {iter}, {i} failed\n{msg}")

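Note (reviewer sketch, not part of the diff): the parametrization works by assembling the `GENERATE_OUTPUTS_HERE` text from the three flags and letting `patch_kernel` splice it into the kernel source by plain string replacement, so disabled outputs are never stored and stay at their `torch.zeros_like` initial values on both sides of the comparison. Below is a minimal, standalone sketch of that assembly step, assuming the 8-space kernel-body indentation shown above; `build_outputs` is a hypothetical helper, not part of the PR.

```python
# Minimal sketch: build the text that replaces the GENERATE_OUTPUTS_HERE token.
# It mirrors the outputs_* strings in the test; patch_kernel() would then
# substitute the result into the kernel source before compilation.
def build_outputs(float_div: bool, floor: bool, trunc: bool) -> str:
    outputs_float_div = (
        "tl.store(out_ptr0 + (x0), tmp4, xmask)\n"
        "        tl.store(out_ptr3 + (x0), tmp4, xmask)"
    ) if float_div else ""
    outputs_floor = (
        "\n        tl.store(out_ptr1 + (x0), tmp5, xmask)\n"
        "        tl.store(out_ptr4 + (x0), tmp5, xmask)"
    ) if floor else ""
    outputs_trunc = "\n        tl.store(out_ptr2 + (x0), tmp6, xmask)" if trunc else ""
    return f"{outputs_float_div}\n{outputs_floor}\n{outputs_trunc}"


# Example: with only floor=True, only the floor-division stores remain in the
# patched kernel source; the other output buffers keep their zero initial values.
print(build_outputs(float_div=False, floor=True, trunc=False))
```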