@@ -1,10 +1,11 @@
+# flake8: noqa: F821, F841
 import torch
-aten = torch.ops.aten
-
-import pytest
+import pytest
 
 import triton
-import triton.language as tl
+import triton.language as tl
+
+aten = torch.ops.aten
 
 
 def patch_kernel(template, to_replace):
@@ -13,25 +14,27 @@ def patch_kernel(template, to_replace):
         kernel.src = kernel.src.replace(key, value)
     return kernel
 
+
 @pytest.mark.parametrize("float_div", [True, False])
 @pytest.mark.parametrize("floor", [True, False])
 @pytest.mark.parametrize("trunc", [True, False])
 def test_divide(float_div, floor, trunc, device):
-    # regression test for various division cases
+    # regression test for various division cases
 
-    @triton.jit
+    @triton.jit
     def divide_kernel(a, b, out_ptr0, out_ptr1, out_ptr2, out_ptr3, out_ptr4, xnumel, XBLOCK: tl.constexpr):
         xoffset = tl.program_id(0) * XBLOCK
         xindex = xoffset + tl.arange(0, XBLOCK)[:]
         xmask = xindex < xnumel
         x0 = xindex
         tmp0 = tl.load(a + (x0), xmask)
         tmp2 = tl.load(b + (x0), xmask)
-        # custom bits
+        # custom bits
         tmp1 = tmp0.to(tl.float32)
         tmp3 = tmp2.to(tl.float32)
         tmp4 = tmp1 / tmp3
-        tmp5 = tl.where((tmp0 < 0) != (tmp2 < 0), tl.where(tmp0 % tmp2 != 0, tmp0 // tmp2 - 1, tmp0 // tmp2), tmp0 // tmp2)
+        tmp5 = tl.where((tmp0 < 0) != (tmp2 < 0), tl.where(tmp0 % tmp2 != 0, tmp0 // tmp2 - 1, tmp0 // tmp2),
+                        tmp0 // tmp2)
         tmp6 = tmp0 // tmp2
         GENERATE_OUTPUTS_HERE
 
@@ -41,7 +44,8 @@ def divide_kernel(a, b, out_ptr0, out_ptr1, out_ptr2, out_ptr3, out_ptr4, xnumel
     outputs_floor = "\n        tl.store(out_ptr1 + (x0), tmp5, xmask)\n        tl.store(out_ptr4 + (x0), tmp5, xmask)" if floor is True else ""
     outputs_trunc = "\n        tl.store(out_ptr2 + (x0), tmp6, xmask)" if trunc is True else ""
 
-    divide_kernel = patch_kernel(divide_kernel, {"GENERATE_OUTPUTS_HERE": f"{outputs_float_div}\n{outputs_floor}\n{outputs_trunc}"})
+    divide_kernel = patch_kernel(divide_kernel,
+                                 {"GENERATE_OUTPUTS_HERE": f"{outputs_float_div}\n{outputs_floor}\n{outputs_trunc}"})
 
     def launch_triton(a, b):
         output0 = torch.zeros_like(a)
@@ -57,15 +61,15 @@ def launch_triton(a, b):
         divide_kernel[grid](a, b, output0, output1, output2, output3, output4, n_elements, XBLOCK=128)
 
         return (output0, output1, output2, output3, output4)
-
+
     def launch_torch(a, b):
-        return (
-            aten.div(a, b, rounding_mode=None) if float_div is True else torch.zeros_like(a),
-            aten.div(a, b, rounding_mode="floor") if floor is True else torch.zeros_like(a),
-            aten.div(a, b, rounding_mode="trunc") if trunc is True else torch.zeros_like(a),
-            a / b if float_div is True else torch.zeros_like(a),
-            a // b if floor is True else torch.zeros_like(a),
-        )
+        return (
+            aten.div(a, b, rounding_mode=None) if float_div is True else torch.zeros_like(a),
+            aten.div(a, b, rounding_mode="floor") if floor is True else torch.zeros_like(a),
+            aten.div(a, b, rounding_mode="trunc") if trunc is True else torch.zeros_like(a),
+            a / b if float_div is True else torch.zeros_like(a),
+            a // b if floor is True else torch.zeros_like(a),
+        )
 
     a = torch.randint(2**32, 2**40, [100, 100], device=device)
     b = torch.randint(-10, -1, [100, 100], device=device)
@@ -75,6 +79,6 @@ def launch_torch(a, b):
     torch_result = launch_torch(a, b)
 
     for i in range(5):
-        torch.testing.assert_close(triton_result[i], torch_result[i], check_dtype=False, msg=lambda msg: f"Float: {float_div}, Floor: {floor}, Trunc: {trunc}\nIteration {iter}, {i} failed\n{msg}")
-
-
+        torch.testing.assert_close(
+            triton_result[i], torch_result[i], check_dtype=False, msg=lambda msg:
+            f"Float: {float_div}, Floor: {floor}, Trunc: {trunc}\nIteration {iter}, {i} failed\n{msg}")