@@ -1910,24 +1910,33 @@ def serialized_add(data, Lock, SEM: tl.constexpr):
19101910
19111911
@pytest.mark.interpreter
@pytest.mark.parametrize("sem", [None, "acquire", "release", "acq_rel", "relaxed"])
@pytest.mark.parametrize("num_ctas", num_ctas_list)
@pytest.mark.parametrize("size", [4, 128, 512])
@pytest.mark.parametrize("dtype_str", ['bfloat16', 'float16', 'float32', 'uint64', 'int64', 'float64'])
def test_tensor_atomic_cas(sem, size, dtype_str, num_ctas, device):
    """Exercise ``tl.atomic_cas`` on a block of pointers for several dtypes.

    The input tensor alternates 0 (even indices) and 1 (odd indices). The
    kernel issues a compare-and-swap with compare value 0 and new value 2, so
    the expected result has 2 at every even index and 1 left at every odd one.

    Parameters (all supplied by pytest):
        sem: memory-semantic string forwarded to ``tl.atomic_cas``, or None.
        size: total number of elements; split evenly across two programs.
        dtype_str: attribute name looked up on both ``torch`` and ``tl``.
        num_ctas: parametrized value from ``num_ctas_list``.
        device: device on which the tensors are allocated.
    """
    check_type_supported(dtype_str, device)
    if "float" in dtype_str and is_hip():
        pytest.skip("HIP does not support atomic cas with float types")

    @triton.jit
    def change_value(X, BLOCK_SIZE: tl.constexpr, sem: tl.constexpr, dtype: tl.constexpr):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        # t1 is the compare operand, t2 the replacement value.
        t1 = tl.full((BLOCK_SIZE, ), 0, dtype=dtype)
        t2 = tl.full((BLOCK_SIZE, ), 2, dtype=dtype)
        tl.atomic_cas(X + offsets, t1, t2, sem=sem)

    # Input: 0 at even indices, 1 at odd indices.
    torch_dtype = getattr(torch, dtype_str)
    X = torch.zeros((size, ), device=device, dtype=torch_dtype)
    X[1::2] = 1
    # Expected output: even indices become 2, odd indices stay 1.
    Y = X.clone()
    Y[0::2] = 2

    tl_dtype = getattr(tl, dtype_str)
    # Two programs, each covering half of the tensor.
    # NOTE(review): num_ctas is parametrized but not forwarded to the launch
    # below; confirm whether that is intentional.
    change_value[(2, )](X, BLOCK_SIZE=size // 2, sem=sem, dtype=tl_dtype)
    assert torch.equal(X, Y)
19311940
19321941
19331942@pytest .mark .interpreter
0 commit comments