import triton
import triton.language as tl

-from triton._internal_testing import is_cuda, is_ampere_or_newer, is_hip_cdna3, is_hip_cdna4, is_hopper_or_newer, is_hopper
+from triton._internal_testing import is_ampere_or_newer, is_hip_cdna3, is_hip_cdna4, is_hopper_or_newer, is_hopper
from triton.experimental import gluon
from triton.experimental.gluon import language as ttgl
from triton.experimental.gluon.language.nvidia.ampere import async_copy, mbarrier
from triton.experimental.gluon.language.amd.cdna4 import async_copy as cdna4_async_copy
from triton.experimental.gluon.language.extra import libdevice

+THREADS_PER_WARP = triton.runtime.driver.active.get_current_target().warp_size
+

@gluon.jit
def copy_kernel(Out, In, numel, XBLOCK: ttgl.constexpr, layout: ttgl.constexpr):
@@ -24,18 +26,15 @@ def copy_kernel(Out, In, numel, XBLOCK: ttgl.constexpr, layout: ttgl.constexpr):
    ttgl.store(Out + xoffset, data, xmask)


-copy_kernel_tpw = [32] if is_cuda() else [64]
-
-
@pytest.mark.parametrize("layout", [
-    ttgl.BlockedLayout(size_per_thread=[1], threads_per_warp=copy_kernel_tpw, warps_per_cta=[4], order=[0]),
-    ttgl.BlockedLayout(size_per_thread=[2], threads_per_warp=copy_kernel_tpw, warps_per_cta=[4], order=[0]),
-    ttgl.BlockedLayout(size_per_thread=[4], threads_per_warp=copy_kernel_tpw, warps_per_cta=[4], order=[0]),
-    ttgl.BlockedLayout(size_per_thread=[8], threads_per_warp=copy_kernel_tpw, warps_per_cta=[4], order=[0]),
-    ttgl.BlockedLayout(size_per_thread=[1], threads_per_warp=copy_kernel_tpw, warps_per_cta=[8], order=[0]),
-    ttgl.BlockedLayout(size_per_thread=[2], threads_per_warp=copy_kernel_tpw, warps_per_cta=[8], order=[0]),
-    ttgl.BlockedLayout(size_per_thread=[4], threads_per_warp=copy_kernel_tpw, warps_per_cta=[8], order=[0]),
-    ttgl.BlockedLayout(size_per_thread=[8], threads_per_warp=copy_kernel_tpw, warps_per_cta=[8], order=[0]),
+    ttgl.BlockedLayout(size_per_thread=[1], threads_per_warp=[THREADS_PER_WARP], warps_per_cta=[4], order=[0]),
+    ttgl.BlockedLayout(size_per_thread=[2], threads_per_warp=[THREADS_PER_WARP], warps_per_cta=[4], order=[0]),
+    ttgl.BlockedLayout(size_per_thread=[4], threads_per_warp=[THREADS_PER_WARP], warps_per_cta=[4], order=[0]),
+    ttgl.BlockedLayout(size_per_thread=[8], threads_per_warp=[THREADS_PER_WARP], warps_per_cta=[4], order=[0]),
+    ttgl.BlockedLayout(size_per_thread=[1], threads_per_warp=[THREADS_PER_WARP], warps_per_cta=[8], order=[0]),
+    ttgl.BlockedLayout(size_per_thread=[2], threads_per_warp=[THREADS_PER_WARP], warps_per_cta=[8], order=[0]),
+    ttgl.BlockedLayout(size_per_thread=[4], threads_per_warp=[THREADS_PER_WARP], warps_per_cta=[8], order=[0]),
+    ttgl.BlockedLayout(size_per_thread=[8], threads_per_warp=[THREADS_PER_WARP], warps_per_cta=[8], order=[0]),
])
@pytest.mark.parametrize("XBLOCK", [128, 256, 512, 1024, 2048])
def test_copy_kernel(layout, XBLOCK):
@@ -403,13 +402,12 @@ def fast_expf_kernel(x_ptr, y_ptr, warp_size: ttgl.constexpr, num_warps: ttgl.co
        y = libdevice.fast_expf(x)
        ttgl.store(y_ptr + offs, y)

-    warp_size = 32 if is_cuda() else 64
    num_warps = 4

    torch.manual_seed(0)
-    x = torch.randn(warp_size * num_warps, device="cuda", dtype=torch.float32)
+    x = torch.randn(THREADS_PER_WARP * num_warps, device="cuda", dtype=torch.float32)
    y = torch.empty_like(x)
-    fast_expf_kernel[(1, )](x, y, warp_size, num_warps)
+    fast_expf_kernel[(1, )](x, y, THREADS_PER_WARP, num_warps)
    torch.testing.assert_close(y, torch.exp(x), atol=1e-5, rtol=1e-4)


@@ -425,13 +423,12 @@ def fast_dividef_kernel(x_ptr, y_ptr, z_ptr, warp_size: ttgl.constexpr, num_warp
        z = libdevice.fast_dividef(x, y)
        ttgl.store(z_ptr + offs, z)

-    warp_size = 32 if is_cuda() else 64
    num_warps = 4

    torch.manual_seed(0)
-    x = torch.randn(warp_size * num_warps, device="cuda", dtype=torch.float32)
+    x = torch.randn(THREADS_PER_WARP * num_warps, device="cuda", dtype=torch.float32)
    y = torch.randn_like(x)
    z = torch.empty_like(x)
    y[y == 0] = 1.0
-    fast_dividef_kernel[(1, )](x, y, z, warp_size, num_warps)
+    fast_dividef_kernel[(1, )](x, y, z, THREADS_PER_WARP, num_warps)
    torch.testing.assert_close(z, torch.div(x, y), atol=1e-5, rtol=1e-4)