@@ -1564,7 +1564,6 @@ def kernel(X, Y, Z):
 @pytest.mark.parametrize(
     "op, dtype_x_str, mode, sem",
     itertools.chain.from_iterable([[
-        ('add', 'bfloat16', mode, sem),
         ('add', 'float16', mode, sem),
         ('add', 'uint32', mode, sem),
         ('add', 'int32', mode, sem),
@@ -1590,8 +1589,8 @@ def kernel(X, Y, Z):
 def test_atomic_rmw(op, dtype_x_str, mode, sem, device):
     check_type_supported(dtype_x_str, device)
     if is_interpreter():
-        if dtype_x_str == 'float16' or dtype_x_str == 'bfloat16':
-            pytest.xfail("Only test atomic bfloat16/float16 ops on GPU")
+        if dtype_x_str == 'float16':
+            pytest.xfail("Only test atomic float16 ops on GPU")

     n_programs = 5

@@ -1606,14 +1605,12 @@ def kernel(X, Z):
     sem_arg = sem if sem is None else f'"{sem}"'
     kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.atomic_{op}(Z, x, sem={sem_arg})'})
     numpy_op = {'add': np.sum, 'max': np.max, 'min': np.min}[op]
-    max_neutral = float('-inf') if dtype_x_str in float_dtypes_with_bfloat16 else np.iinfo(getattr(np, dtype_x_str)).min
-    min_neutral = float('inf') if dtype_x_str in float_dtypes_with_bfloat16 else np.iinfo(getattr(np, dtype_x_str)).max
+    max_neutral = float('-inf') if dtype_x_str in float_dtypes else np.iinfo(getattr(np, dtype_x_str)).min
+    min_neutral = float('inf') if dtype_x_str in float_dtypes else np.iinfo(getattr(np, dtype_x_str)).max
     neutral = {'add': 0, 'max': max_neutral, 'min': min_neutral}[op]

     # triton result
     rs = RandomState(17)
-    dst_type = 'bfloat16' if (dtype_x_str == 'bfloat16') else None
-    dtype_x_str = 'float32' if (dtype_x_str == 'bfloat16') else dtype_x_str
     x = np.array([2**i for i in range(n_programs)], dtype=getattr(np, dtype_x_str))
     if mode == 'all_neg':
         x = -np.abs(x)
@@ -1625,17 +1622,12 @@ def kernel(X, Z):
     if mode == 'max_pos':
         idx = rs.randint(n_programs, size=(1, )).item()
         x[idx] = np.max(np.abs(x)) + 1
-    x_tri = to_triton(x, device=device, dst_type=dst_type)
+    x_tri = to_triton(x, device=device)

-    z_tri = to_triton(np.array([neutral], dtype=getattr(np, dtype_x_str)), device=device, dst_type=dst_type)
+    z_tri = to_triton(np.array([neutral], dtype=getattr(np, dtype_x_str)), device=device)
     h = kernel[(n_programs, )](x_tri, z_tri)
     # torch result
-    if dst_type == 'bfloat16':
-        z_ref = numpy_op(x).astype(getattr(np, dtype_x_str))
-        # trunc mantissa for a fair comparison of accuracy
-        z_ref = (z_ref.view('uint32') & np.uint32(0xffff0000)).view('float32')
-    else:
-        z_ref = numpy_op(x).astype(getattr(np, dtype_x_str))
+    z_ref = numpy_op(x).astype(getattr(np, dtype_x_str))
     # compare
     exact = op not in ['add']
     if exact:
@@ -1646,12 +1638,6 @@ def kernel(X, Z):
     if not is_cuda():
         return

-    # atom.add.bf16 is unsupported prior to Hopper so instead we generate an
-    # atom.cas add loop on Ampere and prior
-    if dst_type == 'bfloat16' and torch.cuda.get_device_capability()[0] < 9:
-        assert f"atom.{sem_str}.global.cas" in h.asm["ptx"]
-        return
-
     assert f"atom.global.gpu.{sem_str}" in h.asm["ptx"]


@@ -1676,7 +1662,7 @@ def kernel(X):
                          for shape in [(2, 2), (2, 8), (8, 2), (8, 8), (32, 32), (64, 64)]
                          for axis in [0, 1]
                          for num_ctas in num_ctas_list
-                         for dtype_x_str in ['bfloat16', 'float16', 'float32', 'uint64', 'int64', 'float64']
+                         for dtype_x_str in ['float16', 'float32', 'uint64', 'int64', 'float64']
                          for check_return_val in ([True, False] if is_hip() else [True])])
 def test_tensor_atomic_rmw(shape, axis, num_ctas, dtype_x_str, check_return_val, device):
     check_type_supported(dtype_x_str, device)
@@ -1690,14 +1676,14 @@ def kernel(Z, X, OLD, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.const
         off1 = tl.arange(0, SHAPE1)
         x = tl.load(X + off0[:, None] * SHAPE1 + off1[None, :])

-        if DTYPE == tl.float16 or DTYPE == tl.bfloat16:
+        if DTYPE == tl.float16:
             # sum can have bad numerics when accumulating in float16.
             # if we're dealing with float16, do the sum in float32.
             x = x.to(tl.float32)

         z = tl.sum(x, axis=AXIS)

-        if DTYPE == tl.float16 or DTYPE == tl.bfloat16:
+        if DTYPE == tl.float16:
             z = z.to(DTYPE)

         if AXIS == 1:
@@ -1713,7 +1699,7 @@ def kernel(Z, X, OLD, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.const
     x = numpy_random((shape0, shape1), dtype_str=dtype_x_str, rs=rs)
     z_shape = (shape0, ) if axis == 1 else (shape1, )
     z = numpy_random(z_shape, dtype_str=dtype_x_str, rs=rs)
-    old = np.zeros(z_shape, dtype=z.dtype)
+    old = np.zeros(z_shape, dtype=getattr(np, dtype_x_str))
     # reference results
     if x.dtype == np.float16:
         # do the sum in float32 to reduce numerical variation
@@ -1722,31 +1708,17 @@ def kernel(Z, X, OLD, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.const
         z_ref = z + np.sum(x, axis=axis, keepdims=False)
     old_ref = np.copy(z)
     # triton result
-    x_tri = to_triton(x, device=device, dst_type=dtype_x_str)
-    z_tri = to_triton(z, device=device, dst_type=dtype_x_str)
-    old_tri = to_triton(old, device=device, dst_type=dtype_x_str)
+    x_tri = to_triton(x, device=device)
+    z_tri = to_triton(z, device=device)
+    old_tri = to_triton(old, device=device)

     def torch_to_triton_dtype(t):
-        if t == torch.bfloat16:
-            return tl.bfloat16
         if t == torch.float16:
             return tl.float16
         return None

     kernel[(1, )](z_tri, x_tri, old_tri, axis, shape0, shape1, torch_to_triton_dtype(x_tri.dtype), check_return_val,
                   num_ctas=num_ctas)
-
-    if dtype_x_str == 'bfloat16':
-        # trunc mantissa for a fair comparison of accuracy
-        z_ref = (z_ref.view('uint32') & np.uint32(0xffff0000)).view('float32')
-        old_ref = (old_ref.view('uint32') & np.uint32(0xffff0000)).view('float32')
-        # mantissa trunc is not enough, bump up the relative tolerance as well
-        np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.5)
-        # check return vals, but use assert_allclose for bf16
-        if check_return_val:
-            np.testing.assert_allclose(old_ref, to_numpy(old_tri), rtol=0.5)
-        return
-
     np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)
     if check_return_val:
         np.testing.assert_equal(old_ref, to_numpy(old_tri))
@@ -1756,9 +1728,8 @@ def torch_to_triton_dtype(t):
 @pytest.mark.parametrize("size, num_ctas, dtype_x_str", [(size, num_ctas, dtype_x_str)
                                                          for size in [2, 4, 8, 32, 64, 128]
                                                          for num_ctas in num_ctas_list
-                                                         for dtype_x_str in ['bfloat16', 'float16', 'float32']])
+                                                         for dtype_x_str in ['float16', 'float32']])
 def test_tensor_atomic_add_non_exclusive_offset(size, num_ctas, dtype_x_str, device):
-    check_type_supported(dtype_x_str, device)

     @triton.jit
     def kernel(X, val, NUM: tl.constexpr):
@@ -1768,9 +1739,8 @@ def kernel(X, val, NUM: tl.constexpr):
         tl.atomic_add(X + offset // 2, val)

     shape = (size // 2, size)
-    dtype = getattr(torch, dtype_x_str)
-    x = torch.zeros(shape, dtype=dtype, device=device)
-    val = torch.randn((size**2), dtype=dtype, device=device)
+    x = torch.zeros(shape, dtype=getattr(torch, dtype_x_str), device=device)
+    val = torch.randn((size**2), dtype=getattr(torch, dtype_x_str), device=device)
     kernel[(1, )](x, val, size, num_warps=1, num_ctas=num_ctas)
     ref = val[0::2] + val[1::2]
     torch.testing.assert_close(ref, x.reshape(math.prod(shape)))
@@ -1783,7 +1753,7 @@ def kernel(X, val, NUM: tl.constexpr):
                          for idx_order in ['increase', 'decrease', 'random_no_duplication', 'random']
                          for mask_step in range(1, 5)
                          for num_ctas in num_ctas_list
-                         for dtype_x_str in ['bfloat16', 'float16', 'float32']])
+                         for dtype_x_str in ['float16', 'float32']])
 def test_tensor_atomic_add_access_patterns(shape, idx_order, mask_step, num_ctas, dtype_x_str, device):
     check_type_supported(dtype_x_str, device)
     if is_interpreter():
@@ -1811,9 +1781,8 @@ def kernel(in_ptr, idx_ptr, out_ptr, shape0, shape1, mask_step, XBLOCK: tl.const
     if idx_order == 'random':
         idx = torch.randint(0, shape1, size=(shape0, shape1), device=device)

-    dtype = getattr(torch, dtype_x_str)
-    val = torch.randn((shape0, shape1), dtype=dtype, device=device)
-    dst = torch.randn((shape0, shape1), dtype=dtype, device=device)
+    val = torch.randn((shape0, shape1), dtype=getattr(torch, dtype_x_str), device=device)
+    dst = torch.randn((shape0, shape1), dtype=getattr(torch, dtype_x_str), device=device)

     dst_ref = dst.clone()

@@ -1825,11 +1794,6 @@ def kernel(in_ptr, idx_ptr, out_ptr, shape0, shape1, mask_step, XBLOCK: tl.const
             cnt += 1

     kernel[(1, )](val, idx, dst, shape0, shape1, mask_step, 64, num_ctas=num_ctas)
-
-    if dtype_x_str == 'bfloat16':
-        torch.testing.assert_close(dst_ref, dst, rtol=0.1, atol=0.1)
-        return
-
     np.testing.assert_allclose(to_numpy(dst_ref), to_numpy(dst), atol=1e-2)

