11import torch
22import pytest
33
4- from triton ._internal_testing import is_cuda
4+ from triton ._internal_testing import is_ampere_or_newer , is_hopper
55from triton .experimental import gluon
66from triton .experimental .gluon import language as ttgl
7+ from triton .experimental .gluon .language .nvidia .ampere import async_copy , mbarrier
78from triton .experimental .gluon .language .nvidia .hopper import tma
89
910
@@ -45,7 +46,7 @@ def tma_kernel(desc):
4546 alloc ._keep_alive ()
4647
4748
48- @pytest .mark .skipif (not is_cuda () or torch . cuda . get_device_capability ()[ 0 ] < 9 , reason = "Requires Hopper" )
49+ @pytest .mark .skipif (not is_hopper () , reason = "Requires Hopper" )
4950def test_tma ():
5051 out = torch .ones ((16 , 16 ), dtype = torch .float16 , device = "cuda" )
5152 layout = ttgl .NVMMASharedLayout (
@@ -59,3 +60,36 @@ def test_tma():
5960 desc = gluon .nvidia .hopper .TensorDescriptor .from_tensor (out , [16 , 16 ], layout )
6061 tma_kernel [(1 , )](desc )
6162 torch .testing .assert_close (out , torch .zeros_like (out ))
63+
64+
@gluon.jit
def async_copy_mbarrier_kernel(out, inp, xnumel, XBLOCK: ttgl.constexpr, YBLOCK: ttgl.constexpr):
    """Stage an (XBLOCK, YBLOCK) tile of ``inp`` through shared memory and write it to ``out``.

    Rows at index >= ``xnumel`` are masked out of the global->shared copy.
    Completion of the asynchronous copy is tracked with an mbarrier before the
    tile is read back and stored (unmasked) to ``out``.
    """
    # Shared-memory staging buffer for one tile, row-major swizzled layout.
    smem = ttgl.allocate_shared_memory(inp.dtype.element_ty, [XBLOCK, YBLOCK],
                                       ttgl.SwizzledSharedLayout(1, 1, 1, order=[1, 0]))
    block_layout: ttgl.constexpr = ttgl.BlockedLayout([1, 4], [1, 32], [4, 1], [1, 0])
    # 2D index grids: xindex has shape (XBLOCK, 1), yindex has shape (1, YBLOCK).
    xindex = ttgl.arange(0, XBLOCK, ttgl.SliceLayout(1, block_layout))[:, None]
    yindex = ttgl.arange(0, YBLOCK, ttgl.SliceLayout(0, block_layout))[None, :]
    # Only rows below xnumel are read from global memory.
    mask = xindex < xnumel
    # Kick off the asynchronous global->shared copy of the masked tile.
    # NOTE(review): masked rows appear to come back zero-filled (the caller
    # asserts zeros for them) — confirm against cp.async zero-fill semantics.
    async_copy.async_copy_global_to_shared(
        smem,
        inp + xindex * YBLOCK + yindex,
        mask,
    )
    # Single int64 slot in shared memory used as the mbarrier object.
    mbar = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout())
    mbarrier.init(mbar, count=1)
    # Make the pending async-copy group signal the barrier on completion,
    # then arrive explicitly and block on phase 0 before reading smem.
    async_copy.mbarrier_arrive(mbar)
    mbarrier.arrive(mbar)
    mbarrier.wait(mbar, 0)

    # Read the staged tile and store it to the output (no mask on the store).
    val = smem.load(block_layout)
    ttgl.store(out + xindex * YBLOCK + yindex, val)
86+
87+
@pytest.mark.skipif(not is_ampere_or_newer(), reason="Requires Ampere")
def test_async_copy_mbarrier():
    """Round-trip a partially-masked tile through the async-copy/mbarrier kernel."""
    opts = {"dtype": torch.float, "device": "cuda"}
    inp = torch.randn((20, 32), **opts)
    out = torch.empty((32, 32), **opts)
    xnumel = inp.shape[0]
    async_copy_mbarrier_kernel[(1, )](out, inp, xnumel, XBLOCK=32, YBLOCK=32)
    # Valid rows pass through unchanged; masked rows come back zero-filled.
    torch.testing.assert_close(out[:xnumel], inp)
    torch.testing.assert_close(out[xnumel:], torch.zeros((12, 32), **opts))
0 commit comments