@@ -1580,6 +1580,7 @@ def kernel(X, Y, Z):
 @pytest.mark.parametrize(
     "op, dtype_x_str, mode, sem",
     itertools.chain.from_iterable([[
+        ('add', 'bfloat16', mode, sem),
         ('add', 'float16', mode, sem),
         ('add', 'uint32', mode, sem),
         ('add', 'int32', mode, sem),
@@ -1609,6 +1610,8 @@ def test_atomic_rmw(op, dtype_x_str, mode, sem, device):
         pytest.xfail("Only test atomic bfloat16/float16 ops on GPU")
     if "uint" in dtype_x_str and mode in ["min_neg", "all_neg"]:
         pytest.xfail("uint cannot be negative")
+    if is_xpu() and dtype_x_str == 'bfloat16':
+        pytest.skip("bfloat16 not yet supported for xpu")

     n_programs = 5

@@ -1623,12 +1626,14 @@ def kernel(X, Z):
     sem_arg = sem if sem is None else f'"{sem}"'
     kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.atomic_{op}(Z, x, sem={sem_arg})'})
     numpy_op = {'add': np.sum, 'max': np.max, 'min': np.min}[op]
-    max_neutral = float('-inf') if dtype_x_str in float_dtypes else np.iinfo(getattr(np, dtype_x_str)).min
-    min_neutral = float('inf') if dtype_x_str in float_dtypes else np.iinfo(getattr(np, dtype_x_str)).max
+    max_neutral = float('-inf') if dtype_x_str in float_dtypes_with_bfloat16 else np.iinfo(getattr(np, dtype_x_str)).min
+    min_neutral = float('inf') if dtype_x_str in float_dtypes_with_bfloat16 else np.iinfo(getattr(np, dtype_x_str)).max
     neutral = {'add': 0, 'max': max_neutral, 'min': min_neutral}[op]

     # triton result
     rs = RandomState(17)
+    dst_type = 'bfloat16' if (dtype_x_str == 'bfloat16') else None
+    dtype_x_str = 'float32' if (dtype_x_str == 'bfloat16') else dtype_x_str
     x = np.array([2**i for i in range(n_programs)], dtype=getattr(np, dtype_x_str))
     if mode == 'all_neg':
         x = -np.abs(x)
@@ -1640,12 +1645,17 @@ def kernel(X, Z):
     if mode == 'max_pos':
         idx = rs.randint(n_programs, size=(1, )).item()
         x[idx] = np.max(np.abs(x)) + 1
-    x_tri = to_triton(x, device=device)
+    x_tri = to_triton(x, device=device, dst_type=dst_type)

-    z_tri = to_triton(np.array([neutral], dtype=getattr(np, dtype_x_str)), device=device)
+    z_tri = to_triton(np.array([neutral], dtype=getattr(np, dtype_x_str)), device=device, dst_type=dst_type)
     h = kernel[(n_programs, )](x_tri, z_tri)
     # torch result
-    z_ref = numpy_op(x).astype(getattr(np, dtype_x_str))
+    if dst_type == 'bfloat16':
+        z_ref = numpy_op(x).astype(getattr(np, dtype_x_str))
+        # trunc mantissa for a fair comparison of accuracy
+        z_ref = (z_ref.view('uint32') & np.uint32(0xffff0000)).view('float32')
+    else:
+        z_ref = numpy_op(x).astype(getattr(np, dtype_x_str))
     # compare
     exact = op not in ['add']
     if exact:
@@ -1656,6 +1666,12 @@ def kernel(X, Z):
     if not is_cuda():
         return

+    # atom.add.bf16 is unsupported prior to Hopper so instead we generate an
+    # atom.cas add loop on Ampere and prior
+    if dst_type == 'bfloat16' and torch.cuda.get_device_capability()[0] < 9:
+        assert f"atom.{sem_str}.global.cas" in h.asm["ptx"]
+        return
+
     assert f"atom.global.gpu.{sem_str}" in h.asm["ptx"]


@@ -1680,10 +1696,12 @@ def kernel(X):
                           for shape in [(2, 2), (2, 8), (8, 2), (8, 8), (32, 32), (64, 64)]
                           for axis in [0, 1]
                           for num_ctas in num_ctas_list
-                          for dtype_x_str in ['float16', 'float32', 'uint64', 'int64', 'float64']
+                          for dtype_x_str in ['bfloat16', 'float16', 'float32', 'uint64', 'int64', 'float64']
                           for check_return_val in ([True, False] if is_hip() else [True])])
 def test_tensor_atomic_rmw(shape, axis, num_ctas, dtype_x_str, check_return_val, device):
     check_type_supported(dtype_x_str, device)
+    if is_xpu() and dtype_x_str == 'bfloat16':
+        pytest.skip("bfloat16 not yet supported for xpu")
     shape0, shape1 = shape
     # triton kernel

@@ -1694,14 +1712,14 @@ def kernel(Z, X, OLD, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.const
         off1 = tl.arange(0, SHAPE1)
         x = tl.load(X + off0[:, None] * SHAPE1 + off1[None, :])

-        if DTYPE == tl.float16:
+        if DTYPE == tl.float16 or DTYPE == tl.bfloat16:
             # sum can have bad numerics when accumulating in float16.
             # if we're dealing with float16, do the sum in float32.
             x = x.to(tl.float32)

         z = tl.sum(x, axis=AXIS)

-        if DTYPE == tl.float16:
+        if DTYPE == tl.float16 or DTYPE == tl.bfloat16:
             z = z.to(DTYPE)

         if AXIS == 1:
@@ -1717,7 +1735,7 @@ def kernel(Z, X, OLD, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.const
     x = numpy_random((shape0, shape1), dtype_str=dtype_x_str, rs=rs)
     z_shape = (shape0, ) if axis == 1 else (shape1, )
     z = numpy_random(z_shape, dtype_str=dtype_x_str, rs=rs)
-    old = np.zeros(z_shape, dtype=getattr(np, dtype_x_str))
+    old = np.zeros(z_shape, dtype=z.dtype)
     # reference results
     if x.dtype == np.float16:
         # do the sum in float32 to reduce numerical variation
@@ -1726,17 +1744,31 @@ def kernel(Z, X, OLD, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.const
         z_ref = z + np.sum(x, axis=axis, keepdims=False)
     old_ref = np.copy(z)
     # triton result
-    x_tri = to_triton(x, device=device)
-    z_tri = to_triton(z, device=device)
-    old_tri = to_triton(old, device=device)
+    x_tri = to_triton(x, device=device, dst_type=dtype_x_str)
+    z_tri = to_triton(z, device=device, dst_type=dtype_x_str)
+    old_tri = to_triton(old, device=device, dst_type=dtype_x_str)

     def torch_to_triton_dtype(t):
+        if t == torch.bfloat16:
+            return tl.bfloat16
         if t == torch.float16:
             return tl.float16
         return None

     kernel[(1, )](z_tri, x_tri, old_tri, axis, shape0, shape1, torch_to_triton_dtype(x_tri.dtype), check_return_val,
                   num_ctas=num_ctas)
+
+    if dtype_x_str == 'bfloat16':
+        # trunc mantissa for a fair comparison of accuracy
+        z_ref = (z_ref.view('uint32') & np.uint32(0xffff0000)).view('float32')
+        old_ref = (old_ref.view('uint32') & np.uint32(0xffff0000)).view('float32')
+        # mantissa trunc is not enough, bump up the relative tolerance as well
+        np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.5)
+        # check return vals, but use assert_allclose for bf16
+        if check_return_val:
+            np.testing.assert_allclose(old_ref, to_numpy(old_tri), rtol=0.5)
+        return
+
     np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)
     if check_return_val:
         np.testing.assert_equal(old_ref, to_numpy(old_tri))
@@ -1746,8 +1778,11 @@ def torch_to_triton_dtype(t):
 @pytest.mark.parametrize("size, num_ctas, dtype_x_str", [(size, num_ctas, dtype_x_str)
                                                          for size in [2, 4, 8, 32, 64, 128]
                                                          for num_ctas in num_ctas_list
-                                                         for dtype_x_str in ['float16', 'float32']])
+                                                         for dtype_x_str in ['bfloat16', 'float16', 'float32']])
 def test_tensor_atomic_add_non_exclusive_offset(size, num_ctas, dtype_x_str, device):
+    check_type_supported(dtype_x_str, device)
+    if is_xpu() and dtype_x_str == 'bfloat16':
+        pytest.skip("bfloat16 not yet supported for xpu")

     @triton.jit
     def kernel(X, val, NUM: tl.constexpr):
@@ -1757,8 +1792,9 @@ def kernel(X, val, NUM: tl.constexpr):
         tl.atomic_add(X + offset // 2, val)

     shape = (size // 2, size)
-    x = torch.zeros(shape, dtype=getattr(torch, dtype_x_str), device=device)
-    val = torch.randn((size**2), dtype=getattr(torch, dtype_x_str), device=device)
+    dtype = getattr(torch, dtype_x_str)
+    x = torch.zeros(shape, dtype=dtype, device=device)
+    val = torch.randn((size**2), dtype=dtype, device=device)
     kernel[(1, )](x, val, size, num_warps=1, num_ctas=num_ctas)
     ref = val[0::2] + val[1::2]
     torch.testing.assert_close(ref, x.reshape(math.prod(shape)))
@@ -1768,9 +1804,11 @@ def kernel(X, val, NUM: tl.constexpr):
 @pytest.mark.parametrize("size, num_ctas, dtype_x_str", [(size, num_ctas, dtype_x_str)
                                                          for size in [2, 4, 8, 32, 64, 128]
                                                          for num_ctas in num_ctas_list
-                                                         for dtype_x_str in ['float16', 'float32']])
+                                                         for dtype_x_str in ['bfloat16', 'float16', 'float32']])
 def test_tensor_atomic_add_shift_1(size, num_ctas, dtype_x_str, device):
     check_type_supported(dtype_x_str, device)
+    if is_xpu() and dtype_x_str == 'bfloat16':
+        pytest.skip("bfloat16 not yet supported for xpu")

     @triton.jit
     def kernel(X, val, NUM: tl.constexpr):
@@ -1801,12 +1839,15 @@ def kernel(X, val, NUM: tl.constexpr):
                           for idx_order in ['increase', 'decrease', 'random_no_duplication', 'random']
                           for mask_step in range(1, 5)
                           for num_ctas in num_ctas_list
-                          for dtype_x_str in ['float16', 'float32']])
+                          for dtype_x_str in ['bfloat16', 'float16', 'float32']])
 def test_tensor_atomic_add_access_patterns(shape, idx_order, mask_step, num_ctas, dtype_x_str, device):
     check_type_supported(dtype_x_str, device)
     if is_interpreter():
         pytest.xfail("not supported in the interpreter")

+    if is_xpu() and dtype_x_str == 'bfloat16':
+        pytest.skip("bfloat16 not yet supported for xpu")
+
     @triton.jit
     def kernel(in_ptr, idx_ptr, out_ptr, shape0, shape1, mask_step, XBLOCK: tl.constexpr):
         xoffset = tl.program_id(0) * XBLOCK
@@ -1829,8 +1870,9 @@ def kernel(in_ptr, idx_ptr, out_ptr, shape0, shape1, mask_step, XBLOCK: tl.const
     if idx_order == 'random':
         idx = torch.randint(0, shape1, size=(shape0, shape1), device=device)

-    val = torch.randn((shape0, shape1), dtype=getattr(torch, dtype_x_str), device=device)
-    dst = torch.randn((shape0, shape1), dtype=getattr(torch, dtype_x_str), device=device)
+    dtype = getattr(torch, dtype_x_str)
+    val = torch.randn((shape0, shape1), dtype=dtype, device=device)
+    dst = torch.randn((shape0, shape1), dtype=dtype, device=device)

     dst_ref = dst.clone()

@@ -1842,6 +1884,11 @@ def kernel(in_ptr, idx_ptr, out_ptr, shape0, shape1, mask_step, XBLOCK: tl.const
             cnt += 1

     kernel[(1, )](val, idx, dst, shape0, shape1, mask_step, 64, num_ctas=num_ctas)
+
+    if dtype_x_str == 'bfloat16':
+        torch.testing.assert_close(dst_ref, dst, rtol=0.1, atol=0.1)
+        return
+
     np.testing.assert_allclose(to_numpy(dst_ref), to_numpy(dst), atol=1e-2)


@@ -3248,6 +3295,8 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, add_ov
         pytest.skip("Skipping because tensor shape is smaller than M(f)maLayout instr_shape")
     if reduce_op == "sum" and dtype_str == "float16" and M * N > 1024:
         pytest.xfail("Skipping sum reduction on float16 due to accuracy issues")
+    if isinstance(src_layout, LinearLayout) and THREADS_PER_WARP != (1 << len(src_layout.lane)):
+        pytest.xfail(f"Skipping. This LinearLayout assumes {1 << len(src_layout.lane)} threads per warp")

     if isinstance(src_layout, MmaLayout) and src_layout.version == 3:
         src_layout.instr_shape[2] = 16 if dtype_str == "float16" else 8
@@ -7646,11 +7695,11 @@ def inject_layout(ir, src: torch.Tensor, axis, indices: torch.Tensor, src_layout
     pat += str(axis)
     pat += r" : i32, efficient_layout} : \(tensor\<"
     pat += src_spec
-    pat += r", (#[a-z]+[0-9]*)\>, tensor\<"
+    pat += r", (#[a-z]+[0-9]+)\>, tensor\<"
     pat += indices_spec
-    pat += r", (#[a-z]+[0-9]*)\>\) -> tensor\<"
+    pat += r", (#[a-z]+[0-9]+)\>\) -> tensor\<"
     pat += output_spec
-    pat += r", (#[a-z]+[0-9]*)\>"
+    pat += r", (#[a-z]+[0-9]+)\>"

     repl = r"""
     %src = ttg.convert_layout \2 : tensor<""" + src_spec + r""", \4> -> tensor<""" + src_spec + r""", #src_layout>