@@ -1615,22 +1615,30 @@ def kernel(X, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr):
16151615@pytest .mark .interpreter
16161616@pytest .mark .parametrize ("sem" , [None , 'acquire' , 'release' , 'acq_rel' , 'relaxed' ])
16171617@pytest .mark .parametrize ("num_ctas" , num_ctas_list )
1618- def test_atomic_cas (sem , num_ctas , device ):
1618+ @pytest .mark .parametrize ("dtype_str" , ["int32" , "int64" ])
1619+ def test_atomic_cas (sem , num_ctas , dtype_str , device ):
16191620 # 1. make sure that atomic_cas changes the original value (Lock)
16201621 @triton .jit
1621- def change_value (Lock ):
1622- tl .atomic_cas (Lock , 0 , 1 )
1622+ def change_value (Lock , triton_dtype : tl .constexpr ):
1623+ num0 = tl .full ((1 , ), 0 , dtype = triton_dtype ).item ()
1624+ num1 = tl .full ((1 , ), 1 , dtype = triton_dtype ).item ()
1625+ tl .atomic_cas (Lock , num0 , num1 )
16231626
1624- Lock = torch .zeros ((1 , ), device = device , dtype = torch .int32 )
1625- change_value [(1 , )](Lock )
1627+ torch_dtype = getattr (torch , dtype_str )
1628+ triton_dtype = getattr (tl , dtype_str )
1629+ Lock = torch .zeros ((1 , ), device = device , dtype = torch_dtype )
1630+ change_value [(1 , )](Lock , triton_dtype )
16261631
16271632 assert (Lock [0 ] == 1 )
16281633
16291634 # 2. only one block enters the critical section
16301635 @triton .jit
1631- def serialized_add (data , Lock , SEM : tl .constexpr ):
1636+ def serialized_add (data , Lock , triton_dtype : tl .constexpr , SEM : tl .constexpr ):
1637+ num0 = tl .full ((1 , ), 0 , dtype = triton_dtype ).item ()
1638+ num1 = tl .full ((1 , ), 1 , dtype = triton_dtype ).item ()
1639+
16321640 ptrs = data + tl .arange (0 , 128 )
1633- while tl .atomic_cas (Lock , 0 , 1 , SEM ) == 1 :
1641+ while tl .atomic_cas (Lock , num0 , num1 , SEM ) == 1 :
16341642 pass
16351643
16361644 tl .store (ptrs , tl .load (ptrs ) + 1.0 )
@@ -1640,12 +1648,12 @@ def serialized_add(data, Lock, SEM: tl.constexpr):
16401648 tl .debug_barrier ()
16411649
16421650 # release lock
1643- tl .atomic_xchg (Lock , 0 )
1651+ tl .atomic_xchg (Lock , num0 )
16441652
1645- Lock = torch .zeros ((1 , ), device = device , dtype = torch . int32 )
1653+ Lock = torch .zeros ((1 , ), device = device , dtype = torch_dtype )
16461654 data = torch .zeros ((128 , ), device = device , dtype = torch .float32 )
16471655 ref = torch .full ((128 , ), 2000.0 )
1648- h = serialized_add [(2000 , )](data , Lock , SEM = sem , num_ctas = num_ctas )
1656+ h = serialized_add [(2000 , )](data , Lock , triton_dtype = triton_dtype , SEM = sem , num_ctas = num_ctas )
16491657 sem_str = "acq_rel" if sem is None else sem
16501658 np .testing .assert_allclose (to_numpy (data ), to_numpy (ref ))
16511659 if not is_cuda ():
0 commit comments