@@ -1580,6 +1580,7 @@ def kernel(X, Y, Z):
 @pytest.mark.parametrize(
     "op, dtype_x_str, mode, sem",
     itertools.chain.from_iterable([[
+        ('add', 'bfloat16', mode, sem),
         ('add', 'float16', mode, sem),
         ('add', 'uint32', mode, sem),
         ('add', 'int32', mode, sem),
@@ -1609,6 +1610,8 @@ def test_atomic_rmw(op, dtype_x_str, mode, sem, device):
         pytest.xfail("Only test atomic bfloat16/float16 ops on GPU")
     if "uint" in dtype_x_str and mode in ["min_neg", "all_neg"]:
         pytest.xfail("uint cannot be negative")
+    if is_xpu() and dtype_x_str == 'bfloat16':
+        pytest.skip("bfloat16 not yet supported for xpu")
 
     n_programs = 5
 
@@ -1623,12 +1626,14 @@ def kernel(X, Z):
     sem_arg = sem if sem is None else f'"{sem}"'
     kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.atomic_{op}(Z, x, sem={sem_arg})'})
     numpy_op = {'add': np.sum, 'max': np.max, 'min': np.min}[op]
-    max_neutral = float('-inf') if dtype_x_str in float_dtypes else np.iinfo(getattr(np, dtype_x_str)).min
-    min_neutral = float('inf') if dtype_x_str in float_dtypes else np.iinfo(getattr(np, dtype_x_str)).max
+    max_neutral = float('-inf') if dtype_x_str in float_dtypes_with_bfloat16 else np.iinfo(getattr(np, dtype_x_str)).min
+    min_neutral = float('inf') if dtype_x_str in float_dtypes_with_bfloat16 else np.iinfo(getattr(np, dtype_x_str)).max
     neutral = {'add': 0, 'max': max_neutral, 'min': min_neutral}[op]
 
     # triton result
     rs = RandomState(17)
+    dst_type = 'bfloat16' if (dtype_x_str == 'bfloat16') else None
+    dtype_x_str = 'float32' if (dtype_x_str == 'bfloat16') else dtype_x_str
     x = np.array([2**i for i in range(n_programs)], dtype=getattr(np, dtype_x_str))
     if mode == 'all_neg':
         x = -np.abs(x)
@@ -1640,12 +1645,17 @@ def kernel(X, Z):
     if mode == 'max_pos':
         idx = rs.randint(n_programs, size=(1, )).item()
         x[idx] = np.max(np.abs(x)) + 1
-    x_tri = to_triton(x, device=device)
+    x_tri = to_triton(x, device=device, dst_type=dst_type)
 
-    z_tri = to_triton(np.array([neutral], dtype=getattr(np, dtype_x_str)), device=device)
+    z_tri = to_triton(np.array([neutral], dtype=getattr(np, dtype_x_str)), device=device, dst_type=dst_type)
     h = kernel[(n_programs, )](x_tri, z_tri)
     # torch result
-    z_ref = numpy_op(x).astype(getattr(np, dtype_x_str))
+    if dst_type == 'bfloat16':
+        z_ref = numpy_op(x).astype(getattr(np, dtype_x_str))
+        # trunc mantissa for a fair comparison of accuracy
+        z_ref = (z_ref.view('uint32') & np.uint32(0xffff0000)).view('float32')
+    else:
+        z_ref = numpy_op(x).astype(getattr(np, dtype_x_str))
     # compare
     exact = op not in ['add']
     if exact:
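The `view('uint32') & 0xffff0000` trick above relies on bfloat16 being the top 16 bits of a float32. A minimal NumPy sketch of that truncation, outside the test (the helper name is illustrative, not part of the diff):

```python
import numpy as np

def truncate_to_bfloat16(a):
    # bfloat16 keeps the float32 sign, all 8 exponent bits and the top 7
    # mantissa bits, so zeroing the low 16 bits of the float32 bit pattern
    # gives the value a truncating bf16 conversion would produce.
    return (a.astype(np.float32).view('uint32') & np.uint32(0xffff0000)).view('float32')

print(truncate_to_bfloat16(np.array([3.14159265])))  # [3.140625]
```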
@@ -1656,6 +1666,12 @@ def kernel(X, Z):
     if not is_cuda():
         return
 
+    # atom.add.bf16 is unsupported prior to Hopper so instead we generate an
+    # atom.cas add loop on Ampere and prior
+    if dst_type == 'bfloat16' and torch.cuda.get_device_capability()[0] < 9:
+        assert f"atom.{sem_str}.global.cas" in h.asm["ptx"]
+        return
+
     assert f"atom.global.gpu.{sem_str}" in h.asm["ptx"]
 
 
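On the pre-Hopper path the PTX check expects a compare-and-swap retry loop rather than a native `atom.add.bf16`. A rough Python sketch of that emulation strategy, with a toy one-cell "memory" standing in for the hardware CAS (all names here are illustrative assumptions, not Triton or CUDA APIs):

```python
import numpy as np

memory = {"z": np.uint16(0)}  # one bf16 cell, stored as raw bits

def cas_u16(key, expected, desired):
    # Toy compare-and-swap: returns the value it observed.
    observed = memory[key]
    if observed == expected:
        memory[key] = desired
    return observed

def f32_to_bf16_bits(x):
    # Truncate to the top 16 bits (real hardware rounds, but the shape is the same).
    return np.uint16(np.float32(x).view(np.uint32) >> np.uint32(16))

def bf16_bits_to_f32(b):
    return (np.uint32(b) << np.uint32(16)).view(np.float32)

def atomic_add_bf16(key, inc):
    # Retry loop in the spirit of the atom.cas fallback: read, add in fp32,
    # convert back to bf16 bits, CAS; repeat until no other thread raced us.
    while True:
        old = memory[key]
        new = f32_to_bf16_bits(bf16_bits_to_f32(old) + np.float32(inc))
        if cas_u16(key, old, new) == old:
            return bf16_bits_to_f32(old)  # atomics return the previous value

for v in [1.0, 2.0, 4.0]:
    atomic_add_bf16("z", v)
print(bf16_bits_to_f32(memory["z"]))  # 7.0
```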
@@ -1680,10 +1696,12 @@ def kernel(X):
                          for shape in [(2, 2), (2, 8), (8, 2), (8, 8), (32, 32), (64, 64)]
                          for axis in [0, 1]
                          for num_ctas in num_ctas_list
-                         for dtype_x_str in ['float16', 'float32', 'uint64', 'int64', 'float64']
+                         for dtype_x_str in ['bfloat16', 'float16', 'float32', 'uint64', 'int64', 'float64']
                          for check_return_val in ([True, False] if is_hip() else [True])])
 def test_tensor_atomic_rmw(shape, axis, num_ctas, dtype_x_str, check_return_val, device):
     check_type_supported(dtype_x_str, device)
+    if is_xpu() and dtype_x_str == 'bfloat16':
+        pytest.skip("bfloat16 not yet supported for xpu")
     shape0, shape1 = shape
     # triton kernel
 
@@ -1694,14 +1712,14 @@ def kernel(Z, X, OLD, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.const
         off1 = tl.arange(0, SHAPE1)
         x = tl.load(X + off0[:, None] * SHAPE1 + off1[None, :])
 
-        if DTYPE == tl.float16:
+        if DTYPE == tl.float16 or DTYPE == tl.bfloat16:
             # sum can have bad numerics when accumulating in float16.
             # if we're dealing with float16, do the sum in float32.
             x = x.to(tl.float32)
 
         z = tl.sum(x, axis=AXIS)
 
-        if DTYPE == tl.float16:
+        if DTYPE == tl.float16 or DTYPE == tl.bfloat16:
             z = z.to(DTYPE)
 
         if AXIS == 1:
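The float16/bfloat16 branch above exists because accumulating a long sum at 16-bit precision drifts badly. A quick NumPy illustration of why the kernel upcasts before `tl.sum` and casts back once at the end:

```python
import numpy as np

x = np.ones(4096, dtype=np.float16)

acc16 = np.float16(0)
for v in x:
    acc16 = np.float16(acc16 + v)  # stalls at 2048, where the fp16 spacing is 2

acc32 = np.float16(x.astype(np.float32).sum())  # accumulate in fp32, cast once

print(acc16, acc32)  # 2048.0 4096.0
```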
@@ -1717,7 +1735,7 @@ def kernel(Z, X, OLD, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.const
     x = numpy_random((shape0, shape1), dtype_str=dtype_x_str, rs=rs)
     z_shape = (shape0, ) if axis == 1 else (shape1, )
     z = numpy_random(z_shape, dtype_str=dtype_x_str, rs=rs)
-    old = np.zeros(z_shape, dtype=getattr(np, dtype_x_str))
+    old = np.zeros(z_shape, dtype=z.dtype)
     # reference results
     if x.dtype == np.float16:
         # do the sum in float32 to reduce numerical variation
@@ -1726,17 +1744,31 @@ def kernel(Z, X, OLD, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.const
     z_ref = z + np.sum(x, axis=axis, keepdims=False)
     old_ref = np.copy(z)
     # triton result
-    x_tri = to_triton(x, device=device)
-    z_tri = to_triton(z, device=device)
-    old_tri = to_triton(old, device=device)
+    x_tri = to_triton(x, device=device, dst_type=dtype_x_str)
+    z_tri = to_triton(z, device=device, dst_type=dtype_x_str)
+    old_tri = to_triton(old, device=device, dst_type=dtype_x_str)
 
     def torch_to_triton_dtype(t):
+        if t == torch.bfloat16:
+            return tl.bfloat16
         if t == torch.float16:
             return tl.float16
         return None
 
     kernel[(1, )](z_tri, x_tri, old_tri, axis, shape0, shape1, torch_to_triton_dtype(x_tri.dtype), check_return_val,
                   num_ctas=num_ctas)
+
+    if dtype_x_str == 'bfloat16':
+        # trunc mantissa for a fair comparison of accuracy
+        z_ref = (z_ref.view('uint32') & np.uint32(0xffff0000)).view('float32')
+        old_ref = (old_ref.view('uint32') & np.uint32(0xffff0000)).view('float32')
+        # mantissa trunc is not enough, bump up the relative tolerance as well
+        np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.5)
+        # check return vals, but use assert_allclose for bf16
+        if check_return_val:
+            np.testing.assert_allclose(old_ref, to_numpy(old_tri), rtol=0.5)
+        return
+
     np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)
     if check_return_val:
         np.testing.assert_equal(old_ref, to_numpy(old_tri))
@@ -1746,8 +1778,11 @@ def torch_to_triton_dtype(t):
 @pytest.mark.parametrize("size, num_ctas, dtype_x_str", [(size, num_ctas, dtype_x_str)
                                                          for size in [2, 4, 8, 32, 64, 128]
                                                          for num_ctas in num_ctas_list
-                                                         for dtype_x_str in ['float16', 'float32']])
+                                                         for dtype_x_str in ['bfloat16', 'float16', 'float32']])
 def test_tensor_atomic_add_non_exclusive_offset(size, num_ctas, dtype_x_str, device):
+    check_type_supported(dtype_x_str, device)
+    if is_xpu() and dtype_x_str == 'bfloat16':
+        pytest.skip("bfloat16 not yet supported for xpu")
 
     @triton.jit
     def kernel(X, val, NUM: tl.constexpr):
@@ -1757,8 +1792,9 @@ def kernel(X, val, NUM: tl.constexpr):
         tl.atomic_add(X + offset // 2, val)
 
     shape = (size // 2, size)
-    x = torch.zeros(shape, dtype=getattr(torch, dtype_x_str), device=device)
-    val = torch.randn((size**2), dtype=getattr(torch, dtype_x_str), device=device)
+    dtype = getattr(torch, dtype_x_str)
+    x = torch.zeros(shape, dtype=dtype, device=device)
+    val = torch.randn((size**2), dtype=dtype, device=device)
     kernel[(1, )](x, val, size, num_warps=1, num_ctas=num_ctas)
     ref = val[0::2] + val[1::2]
     torch.testing.assert_close(ref, x.reshape(math.prod(shape)))
@@ -1768,9 +1804,11 @@ def kernel(X, val, NUM: tl.constexpr):
 @pytest.mark.parametrize("size, num_ctas, dtype_x_str", [(size, num_ctas, dtype_x_str)
                                                          for size in [2, 4, 8, 32, 64, 128]
                                                          for num_ctas in num_ctas_list
-                                                         for dtype_x_str in ['float16', 'float32']])
+                                                         for dtype_x_str in ['bfloat16', 'float16', 'float32']])
 def test_tensor_atomic_add_shift_1(size, num_ctas, dtype_x_str, device):
     check_type_supported(dtype_x_str, device)
+    if is_xpu() and dtype_x_str == 'bfloat16':
+        pytest.skip("bfloat16 not yet supported for xpu")
 
     @triton.jit
     def kernel(X, val, NUM: tl.constexpr):
@@ -1801,12 +1839,15 @@ def kernel(X, val, NUM: tl.constexpr):
                          for idx_order in ['increase', 'decrease', 'random_no_duplication', 'random']
                          for mask_step in range(1, 5)
                          for num_ctas in num_ctas_list
-                         for dtype_x_str in ['float16', 'float32']])
+                         for dtype_x_str in ['bfloat16', 'float16', 'float32']])
 def test_tensor_atomic_add_access_patterns(shape, idx_order, mask_step, num_ctas, dtype_x_str, device):
     check_type_supported(dtype_x_str, device)
     if is_interpreter():
         pytest.xfail("not supported in the interpreter")
 
+    if is_xpu() and dtype_x_str == 'bfloat16':
+        pytest.skip("bfloat16 not yet supported for xpu")
+
     @triton.jit
     def kernel(in_ptr, idx_ptr, out_ptr, shape0, shape1, mask_step, XBLOCK: tl.constexpr):
         xoffset = tl.program_id(0) * XBLOCK
@@ -1829,8 +1870,9 @@ def kernel(in_ptr, idx_ptr, out_ptr, shape0, shape1, mask_step, XBLOCK: tl.const
     if idx_order == 'random':
         idx = torch.randint(0, shape1, size=(shape0, shape1), device=device)
 
-    val = torch.randn((shape0, shape1), dtype=getattr(torch, dtype_x_str), device=device)
-    dst = torch.randn((shape0, shape1), dtype=getattr(torch, dtype_x_str), device=device)
+    dtype = getattr(torch, dtype_x_str)
+    val = torch.randn((shape0, shape1), dtype=dtype, device=device)
+    dst = torch.randn((shape0, shape1), dtype=dtype, device=device)
 
     dst_ref = dst.clone()
 
@@ -1842,6 +1884,11 @@ def kernel(in_ptr, idx_ptr, out_ptr, shape0, shape1, mask_step, XBLOCK: tl.const
             cnt += 1
 
     kernel[(1, )](val, idx, dst, shape0, shape1, mask_step, 64, num_ctas=num_ctas)
+
+    if dtype_x_str == 'bfloat16':
+        torch.testing.assert_close(dst_ref, dst, rtol=0.1, atol=0.1)
+        return
+
     np.testing.assert_allclose(to_numpy(dst_ref), to_numpy(dst), atol=1e-2)
 
 
@@ -3248,6 +3295,8 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, add_ov
         pytest.skip("Skipping because tensor shape is smaller than M(f)maLayout instr_shape")
     if reduce_op == "sum" and dtype_str == "float16" and M * N > 1024:
         pytest.xfail("Skipping sum reduction on float16 due to accuracy issues")
+    if isinstance(src_layout, LinearLayout) and THREADS_PER_WARP != (1 << len(src_layout.lane)):
+        pytest.xfail(f"Skipping. This LinearLayout assumes {1 << len(src_layout.lane)} threads per warp")
 
     if isinstance(src_layout, MmaLayout) and src_layout.version == 3:
         src_layout.instr_shape[2] = 16 if dtype_str == "float16" else 8
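The new xfail encodes the fact that a linear layout fixes its warp width: each lane basis vector supplies one bit of the lane index, so a layout with `len(src_layout.lane)` bases only fits targets with `2**len(src_layout.lane)` threads per warp. A tiny illustration (the basis values below are made up, not taken from the test):

```python
# Hypothetical lane bases of a layout written for 32-wide warps: 5 bases -> 2**5 lanes.
lane_bases = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]]
threads_per_warp = 1 << len(lane_bases)
print(threads_per_warp)  # 32 -- on a 64-lane target the check above would xfail
```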
@@ -7646,11 +7695,11 @@ def inject_layout(ir, src: torch.Tensor, axis, indices: torch.Tensor, src_layout
     pat += str(axis)
     pat += r" : i32, efficient_layout} : \(tensor\<"
     pat += src_spec
-    pat += r", (#[a-z]+[0-9]*)\>, tensor\<"
+    pat += r", (#[a-z]+[0-9]+)\>, tensor\<"
     pat += indices_spec
-    pat += r", (#[a-z]+[0-9]*)\>\) -> tensor\<"
+    pat += r", (#[a-z]+[0-9]+)\>\) -> tensor\<"
     pat += output_spec
-    pat += r", (#[a-z]+[0-9]*)\>"
+    pat += r", (#[a-z]+[0-9]+)\>"
 
     repl = r"""
     %src = ttg.convert_layout \2 : tensor<""" + src_spec + r""", \4> -> tensor<""" + src_spec + r""", #src_layout>
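Tightening the pattern from `[0-9]*` to `[0-9]+` means the captured layout attribute must carry a numeric suffix. A quick check of what the new pattern accepts (the tensor strings are illustrative):

```python
import re

pat = r", (#[a-z]+[0-9]+)\>"
print(re.search(pat, "tensor<128x64xf16, #blocked1>").group(1))  # #blocked1
print(re.search(pat, "tensor<128x64xf16, #blocked>"))            # None: no digit suffix
```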