Skip to content

Commit 444bccd

Browse files
authored
Enable bf16x3, bf16x6 for dot input precisions (#5403)
Fixes #5364. Enables two new precision types (bf16x3, bf16x6) for dot inputs.
1 parent e89a6f7 commit 444bccd

File tree

2 files changed

+1
-3
lines changed

2 files changed

+1
-3
lines changed

python/test/unit/language/test_core.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3272,8 +3272,6 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dty
32723272
pytest.xfail(f"input_precision {input_precision} is not supported in the interpreter")
32733273
else:
32743274
if is_xpu():
3275-
if input_precision in ("bf16x3", "bf16x6"):
3276-
pytest.skip(f"input_precision {input_precision} is not supported")
32773275
if (M < 8 or N < 16 or (K < 16 and in_dtype == 'float16') or (K < 8 and in_dtype == 'float32')):
32783276
pytest.xfail("XPU: small dots are not supported")
32793277
elif not is_hip() and K < 16:

third_party/intel/backend/compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ class XPUOptions:
3030
supported_fp8_dtypes: Tuple[str] = ("fp8e5", "fp8e4nv", "fp8e4b15")
3131
deprecated_fp8_dot_operand_dtypes: Tuple[str] = ()
3232
default_dot_input_precision: str = "tf32"
33-
allowed_dot_input_precisions: Tuple[str] = ("tf32", "tf32x3", "ieee")
33+
allowed_dot_input_precisions: Tuple[str] = ("tf32", "tf32x3", "ieee", 'bf16x3', 'bf16x6')
3434
allow_fp8e4nv: bool = False
3535
allow_fp8e4b15: bool = True
3636
grf_mode: tuple = ('small', 'large', 'auto', 'default')

0 commit comments

Comments
 (0)