@@ -1610,6 +1610,8 @@ def test_atomic_rmw(op, dtype_x_str, mode, sem, device):
         pytest.xfail("Only test atomic bfloat16/float16 ops on GPU")
     if "uint" in dtype_x_str and mode in ["min_neg", "all_neg"]:
         pytest.xfail("uint cannot be negative")
+    if is_xpu() and dtype_x_str == 'bfloat16':
+        pytest.skip("bfloat16 not yet supported for xpu")
 
     n_programs = 5
 
@@ -1698,6 +1700,8 @@ def kernel(X):
                                                     for check_return_val in ([True, False] if is_hip() else [True])])
 def test_tensor_atomic_rmw(shape, axis, num_ctas, dtype_x_str, check_return_val, device):
     check_type_supported(dtype_x_str, device)
+    if is_xpu() and dtype_x_str == 'bfloat16':
+        pytest.skip("bfloat16 not yet supported for xpu")
     shape0, shape1 = shape
     # triton kernel
 
@@ -1777,6 +1781,8 @@ def torch_to_triton_dtype(t):
                              for dtype_x_str in ['bfloat16', 'float16', 'float32']])
 def test_tensor_atomic_add_non_exclusive_offset(size, num_ctas, dtype_x_str, device):
     check_type_supported(dtype_x_str, device)
+    if is_xpu() and dtype_x_str == 'bfloat16':
+        pytest.skip("bfloat16 not yet supported for xpu")
 
     @triton.jit
     def kernel(X, val, NUM: tl.constexpr):
@@ -1798,9 +1804,11 @@ def kernel(X, val, NUM: tl.constexpr):
 @pytest.mark.parametrize("size, num_ctas, dtype_x_str", [(size, num_ctas, dtype_x_str)
                          for size in [2, 4, 8, 32, 64, 128]
                          for num_ctas in num_ctas_list
-                         for dtype_x_str in ['float16', 'float32']])
+                         for dtype_x_str in ['bfloat16', 'float16', 'float32']])
 def test_tensor_atomic_add_shift_1(size, num_ctas, dtype_x_str, device):
     check_type_supported(dtype_x_str, device)
+    if is_xpu() and dtype_x_str == 'bfloat16':
+        pytest.skip("bfloat16 not yet supported for xpu")
 
     @triton.jit
     def kernel(X, val, NUM: tl.constexpr):
@@ -1837,6 +1845,9 @@ def test_tensor_atomic_add_access_patterns(shape, idx_order, mask_step, num_ctas
     if is_interpreter():
         pytest.xfail("not supported in the interpreter")
 
+    if is_xpu() and dtype_x_str == 'bfloat16':
+        pytest.skip("bfloat16 not yet supported for xpu")
+
     @triton.jit
     def kernel(in_ptr, idx_ptr, out_ptr, shape0, shape1, mask_step, XBLOCK: tl.constexpr):
         xoffset = tl.program_id(0) * XBLOCK
@@ -3284,6 +3295,8 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, add_ov
         pytest.skip("Skipping because tensor shape is smaller than M(f)maLayout instr_shape")
     if reduce_op == "sum" and dtype_str == "float16" and M * N > 1024:
         pytest.xfail("Skipping sum reduction on float16 due to accuracy issues")
+    if isinstance(src_layout, LinearLayout) and THREADS_PER_WARP != (1 << len(src_layout.lane)):
+        pytest.xfail(f"Skipping. This LinearLayout assumes {1 << len(src_layout.lane)} threads per warp")
 
     if isinstance(src_layout, MmaLayout) and src_layout.version == 3:
         src_layout.instr_shape[2] = 16 if dtype_str == "float16" else 8
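For context on the guard added above: a LinearLayout carries one `lane` basis vector per lane-id bit, so it addresses `2 ** len(lane)` (i.e. `1 << len(lane)`) threads per warp, and the test only makes sense when that matches the hardware's `THREADS_PER_WARP`. A minimal sketch of that arithmetic, using a hypothetical stand-in object rather than Triton's actual `LinearLayout` class:

```python
# Hypothetical stand-in for the parametrized src_layout; only the `lane`
# attribute matters for the skip condition above.
class FakeLinearLayout:

    def __init__(self, lane):
        self.lane = lane  # one basis vector per lane-id bit


# Five lane bases -> the layout spans 2**5 = 32 threads per warp.
layout = FakeLinearLayout(lane=[[0, 1], [0, 2], [0, 4], [0, 8], [0, 16]])
assumed_warp_size = 1 << len(layout.lane)

THREADS_PER_WARP = 64  # assumed hardware warp/wavefront size, for illustration only
if THREADS_PER_WARP != assumed_warp_size:
    print(f"would xfail: layout assumes {assumed_warp_size} threads per warp")
```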
@@ -7682,11 +7695,11 @@ def inject_layout(ir, src: torch.Tensor, axis, indices: torch.Tensor, src_layout
     pat += str(axis)
     pat += r" : i32, efficient_layout} : \(tensor\<"
     pat += src_spec
-    pat += r", (#[a-z]+[0-9]*)\>, tensor\<"
+    pat += r", (#[a-z]+[0-9]+)\>, tensor\<"
     pat += indices_spec
-    pat += r", (#[a-z]+[0-9]*)\>\) -> tensor\<"
+    pat += r", (#[a-z]+[0-9]+)\>\) -> tensor\<"
     pat += output_spec
-    pat += r", (#[a-z]+[0-9]*)\>"
+    pat += r", (#[a-z]+[0-9]+)\>"
 
     repl = r"""
     %src = ttg.convert_layout \2 : tensor<""" + src_spec + r""", \4> -> tensor<""" + src_spec + r""", #src_layout>
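The regex tightening above only affects the numeric suffix: `[0-9]+` requires at least one digit after the layout name, so an alias like `#blocked` no longer matches while `#blocked1` still does. A small illustration of the difference (the sample attribute names are made up, not taken from the test's IR):

```python
import re

old_pat = re.compile(r"(#[a-z]+[0-9]*)\>")  # digits optional
new_pat = re.compile(r"(#[a-z]+[0-9]+)\>")  # at least one digit required

for attr in ["#blocked>", "#blocked1>", "#mma0>"]:
    print(attr, bool(old_pat.match(attr)), bool(new_pat.match(attr)))
# "#blocked>" matches only the old pattern; the digit-suffixed names match both.
```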