@@ -1352,6 +1352,8 @@ def test_atomic_rmw(op, dtype_x_str, mode, sem, device):
         pytest.xfail("Only test atomic bfloat16/float16 ops on GPU")
     if "uint" in dtype_x_str and mode in ["min_neg", "all_neg"]:
         pytest.xfail("uint cannot be negative")
+    if is_xpu() and dtype_x_str == 'bfloat16':
+        pytest.skip("bfloat16 not yet supported for xpu")
 
     n_programs = 5
 
@@ -1440,6 +1442,8 @@ def kernel(X):
                          for check_return_val in ([True, False] if is_hip() else [True])])
 def test_tensor_atomic_rmw(shape, axis, num_ctas, dtype_x_str, check_return_val, device):
     check_type_supported(dtype_x_str, device)
+    if is_xpu() and dtype_x_str == 'bfloat16':
+        pytest.skip("bfloat16 not yet supported for xpu")
     shape0, shape1 = shape
     # triton kernel
 
@@ -1519,6 +1523,8 @@ def torch_to_triton_dtype(t):
                          for dtype_x_str in ['bfloat16', 'float16', 'float32']])
 def test_tensor_atomic_add_non_exclusive_offset(size, num_ctas, dtype_x_str, device):
     check_type_supported(dtype_x_str, device)
+    if is_xpu() and dtype_x_str == 'bfloat16':
+        pytest.skip("bfloat16 not yet supported for xpu")
 
     @triton.jit
     def kernel(X, val, NUM: tl.constexpr):
@@ -1543,6 +1549,8 @@ def kernel(X, val, NUM: tl.constexpr):
                          for dtype_x_str in ['bfloat16', 'float16', 'float32']])
 def test_tensor_atomic_add_shift_1(size, num_ctas, dtype_x_str, device):
     check_type_supported(dtype_x_str, device)
+    if is_xpu() and dtype_x_str == 'bfloat16':
+        pytest.skip("bfloat16 not yet supported for xpu")
 
     @triton.jit
     def kernel(X, val, NUM: tl.constexpr):
@@ -1579,6 +1587,9 @@ def test_tensor_atomic_add_access_patterns(shape, idx_order, mask_step, num_ctas
     if is_interpreter():
         pytest.xfail("not supported in the interpreter")
 
+    if is_xpu() and dtype_x_str == 'bfloat16':
+        pytest.skip("bfloat16 not yet supported for xpu")
+
     @triton.jit
     def kernel(in_ptr, idx_ptr, out_ptr, shape0, shape1, mask_step, XBLOCK: tl.constexpr):
         xoffset = tl.program_id(0) * XBLOCK
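
For reference, every hunk inserts the same guard. Below is a minimal standalone sketch of that skip pattern, assuming a hypothetical is_xpu() helper backed by torch.xpu; the actual test suite defines its own backend checks, so names and detection logic here are illustrative only.

import pytest
import torch


def is_xpu():
    # Hypothetical stand-in for the test suite's backend check:
    # true when PyTorch exposes an available Intel XPU device.
    return hasattr(torch, "xpu") and torch.xpu.is_available()


@pytest.mark.parametrize("dtype_x_str", ["bfloat16", "float16", "float32"])
def test_skip_pattern(dtype_x_str):
    # Same guard the commit adds to each atomic test:
    # bfloat16 atomics are skipped on XPU until support is available.
    if is_xpu() and dtype_x_str == "bfloat16":
        pytest.skip("bfloat16 not yet supported for xpu")
    # ...the remainder of the test would run for supported dtypes...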