@@ -1352,8 +1352,6 @@ def test_atomic_rmw(op, dtype_x_str, mode, sem, device):
         pytest.xfail("Only test atomic bfloat16/float16 ops on GPU")
     if "uint" in dtype_x_str and mode in ["min_neg", "all_neg"]:
         pytest.xfail("uint cannot be negative")
-    if is_xpu() and dtype_x_str == 'bfloat16':
-        pytest.skip("bfloat16 not yet supported for xpu")
 
     n_programs = 5
 
@@ -1442,8 +1440,6 @@ def kernel(X):
                           for check_return_val in ([True, False] if is_hip() else [True])])
 def test_tensor_atomic_rmw(shape, axis, num_ctas, dtype_x_str, check_return_val, device):
     check_type_supported(dtype_x_str, device)
-    if is_xpu() and dtype_x_str == 'bfloat16':
-        pytest.skip("bfloat16 not yet supported for xpu")
     shape0, shape1 = shape
     # triton kernel
 
@@ -1523,8 +1519,6 @@ def torch_to_triton_dtype(t):
                           for dtype_x_str in ['bfloat16', 'float16', 'float32']])
 def test_tensor_atomic_add_non_exclusive_offset(size, num_ctas, dtype_x_str, device):
     check_type_supported(dtype_x_str, device)
-    if is_xpu() and dtype_x_str == 'bfloat16':
-        pytest.skip("bfloat16 not yet supported for xpu")
 
     @triton.jit
     def kernel(X, val, NUM: tl.constexpr):
@@ -1549,8 +1543,6 @@ def kernel(X, val, NUM: tl.constexpr):
                           for dtype_x_str in ['bfloat16', 'float16', 'float32']])
 def test_tensor_atomic_add_shift_1(size, num_ctas, dtype_x_str, device):
     check_type_supported(dtype_x_str, device)
-    if is_xpu() and dtype_x_str == 'bfloat16':
-        pytest.skip("bfloat16 not yet supported for xpu")
 
     @triton.jit
     def kernel(X, val, NUM: tl.constexpr):
@@ -1587,9 +1579,6 @@ def test_tensor_atomic_add_access_patterns(shape, idx_order, mask_step, num_ctas
     if is_interpreter():
         pytest.xfail("not supported in the interpreter")
 
-    if is_xpu() and dtype_x_str == 'bfloat16':
-        pytest.skip("bfloat16 not yet supported for xpu")
-
     @triton.jit
     def kernel(in_ptr, idx_ptr, out_ptr, shape0, shape1, mask_step, XBLOCK: tl.constexpr):
         xoffset = tl.program_id(0) * XBLOCK