
Commit 225cdbd

Revert "Enable bfloat16 unit tests (#5163)" (#5589)
This reverts commit 91cac59.
1 parent 761fcb0 commit 225cdbd

6 files changed: +35 -40 lines

python/test/unit/language/test_core.py

Lines changed: 11 additions & 0 deletions
@@ -1352,6 +1352,8 @@ def test_atomic_rmw(op, dtype_x_str, mode, sem, device):
         pytest.xfail("Only test atomic bfloat16/float16 ops on GPU")
     if "uint" in dtype_x_str and mode in ["min_neg", "all_neg"]:
         pytest.xfail("uint cannot be negative")
+    if is_xpu() and dtype_x_str == 'bfloat16':
+        pytest.skip("bfloat16 not yet supported for xpu")
 
     n_programs = 5
 
@@ -1440,6 +1442,8 @@ def kernel(X):
                           for check_return_val in ([True, False] if is_hip() else [True])])
 def test_tensor_atomic_rmw(shape, axis, num_ctas, dtype_x_str, check_return_val, device):
     check_type_supported(dtype_x_str, device)
+    if is_xpu() and dtype_x_str == 'bfloat16':
+        pytest.skip("bfloat16 not yet supported for xpu")
     shape0, shape1 = shape
     # triton kernel
 
@@ -1519,6 +1523,8 @@ def torch_to_triton_dtype(t):
                           for dtype_x_str in ['bfloat16', 'float16', 'float32']])
 def test_tensor_atomic_add_non_exclusive_offset(size, num_ctas, dtype_x_str, device):
     check_type_supported(dtype_x_str, device)
+    if is_xpu() and dtype_x_str == 'bfloat16':
+        pytest.skip("bfloat16 not yet supported for xpu")
 
     @triton.jit
     def kernel(X, val, NUM: tl.constexpr):
@@ -1543,6 +1549,8 @@ def kernel(X, val, NUM: tl.constexpr):
                           for dtype_x_str in ['bfloat16', 'float16', 'float32']])
 def test_tensor_atomic_add_shift_1(size, num_ctas, dtype_x_str, device):
     check_type_supported(dtype_x_str, device)
+    if is_xpu() and dtype_x_str == 'bfloat16':
+        pytest.skip("bfloat16 not yet supported for xpu")
 
     @triton.jit
     def kernel(X, val, NUM: tl.constexpr):
@@ -1579,6 +1587,9 @@ def test_tensor_atomic_add_access_patterns(shape, idx_order, mask_step, num_ctas
     if is_interpreter():
         pytest.xfail("not supported in the interpreter")
 
+    if is_xpu() and dtype_x_str == 'bfloat16':
+        pytest.skip("bfloat16 not yet supported for xpu")
+
     @triton.jit
     def kernel(in_ptr, idx_ptr, out_ptr, shape0, shape1, mask_step, XBLOCK: tl.constexpr):
         xoffset = tl.program_id(0) * XBLOCK
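
The same two-line guard is added to each of the five tests above. As a minimal sketch (not part of this commit), the check could be factored into a small helper that reuses the is_xpu() predicate already available in test_core.py:

import pytest

def skip_bf16_on_xpu(dtype_x_str):
    # Hypothetical helper, shown only to illustrate the repeated guard;
    # is_xpu() is the backend predicate already defined for these tests.
    if is_xpu() and dtype_x_str == 'bfloat16':
        pytest.skip("bfloat16 not yet supported for xpu")

Each test would then call skip_bf16_on_xpu(dtype_x_str) right after its check_type_supported call, which is where the inline guards are inserted here.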

python/test/unit/language/test_tensor_descriptor.py

Lines changed: 3 additions & 0 deletions
@@ -1567,6 +1567,9 @@ def test_tensor_descriptor_reduce(kind, descriptor, dtype_str, num_ctas, M_BLOCK
         pytest.xfail("Multi-CTA not supported")
     if is_hip_cdna3() and (kind, dtype_str, M_BLOCK, N_BLOCK) in REDUCE_SKIP_HIP_CDNA3:
         pytest.skip("Broken on rocm")
+    if is_xpu():
+        if (kind, dtype_str) in [("add", "bfloat16")]:
+            pytest.skip("FIXME: issue #3914")
 
     @triton.jit(debug=True)
     def kernel(out_desc, out_ptr, a_ptr, M, N, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr, kind: tl.constexpr):

scripts/skiplist/lts/language.txt

Lines changed: 0 additions & 7 deletions
@@ -1,10 +1,3 @@
 # https://github.com/intel/intel-xpu-backend-for-triton/issues/4665
 python/test/unit/language/test_core.py::test_dot3d[8-1-32-32-32-32-32-float64-float64]
 python/test/unit/language/test_core.py::test_dot3d[4-1-64-64-64-32-32-float64-float64]
-# Below bfloat16 tests require IGC 1188 or above
-python/test/unit/language/test_core.py::test_atomic_rmw[r".*bfloat16.*"]@regexp
-python/test/unit/language/test_core.py::test_tensor_atomic_rmw[r".*bfloat16.*"]@regexp
-python/test/unit/language/test_core.py::test_tensor_atomic_add_non_exclusive_offset[r".*bfloat16.*"]@regexp
-python/test/unit/language/test_core.py::test_tensor_atomic_add_shift_1[r".*bfloat16.*"]@regexp
-python/test/unit/language/test_core.py::test_tensor_atomic_add_access_patterns[r".*bfloat16.*"]@regexp
-python/test/unit/language/test_tensor_descriptor.py::test_tensor_descriptor_reduce[r".*bfloat16.*"]@regexp
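
The removed entries use the skiplist's @regexp form, where the bracketed r"..." part is treated as a pattern over a test's parametrization rather than as a literal id. The sketch below only illustrates how such an entry could be matched against a pytest node id; the repository's actual skiplist parser is not part of this diff and may behave differently.

import re

def matches_skip_entry(entry: str, node_id: str) -> bool:
    # Illustrative matcher for skiplist lines of the form
    #   path::test_name[r"<pattern>"]@regexp
    # Plain lines are compared to the node id exactly.
    if not entry.endswith("@regexp"):
        return entry == node_id
    base = entry[:-len("@regexp")]
    prefix, _, params = base.partition("[")
    pattern = params.rstrip("]").removeprefix('r"').rstrip('"')
    return node_id.startswith(prefix) and re.search(pattern, node_id) is not None

# A bfloat16 parametrization of test_atomic_rmw would have matched one of the
# entries deleted above (the parameter string here is made up for illustration).
assert matches_skip_entry(
    'python/test/unit/language/test_core.py::test_atomic_rmw[r".*bfloat16.*"]@regexp',
    'python/test/unit/language/test_core.py::test_atomic_rmw[add-bfloat16-all_neg-acq_rel]')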

setup.py

Lines changed: 0 additions & 9 deletions
@@ -489,15 +489,6 @@ def build_extension(self, ext):
             cmake_args.append("-DLLVM_EXTERNAL_LIT=" + lit_dir)
         cmake_args.extend(thirdparty_cmake_args)
 
-        result = subprocess.run(["bash", "./scripts/capture-hw-details.sh"], stdout=subprocess.PIPE,
-                                stderr=subprocess.PIPE, check=True, text=True, env=os.environ.copy())
-        agama_version = None
-        for line in result.stdout.splitlines():
-            if line.startswith("AGAMA_VERSION="):
-                agama_version = line.split("=", 1)[1].strip()
-                break
-        cmake_args.append(f"-DAGAMA_VERSION={agama_version}")
-
         # configuration
         cfg = get_build_type()
         build_args = ["--config", cfg]

third_party/intel/cmake/FindSPIRVToLLVMTranslator.cmake

Lines changed: 20 additions & 22 deletions
@@ -26,30 +26,28 @@ if (NOT SPIRVToLLVMTranslator_FOUND)
 
   FetchContent_MakeAvailable(spirv-llvm-translator)
 
-  # FIXME: Don't apply patch when LTS driver is updated.
-  if(DEFINED AGAMA_VERSION AND AGAMA_VERSION STREQUAL "1146")
+  # FIXME: Don't apply patch when Agama driver is updated.
+  execute_process(
+    COMMAND git apply --check ${CMAKE_CURRENT_LIST_DIR}/3122.patch
+    WORKING_DIRECTORY ${spirv-llvm-translator_SOURCE_DIR}
+    ERROR_QUIET
+    RESULT_VARIABLE PATCH_RESULT
+  )
+  if(PATCH_RESULT EQUAL 0)
     execute_process(
-      COMMAND git apply --check ${CMAKE_CURRENT_LIST_DIR}/3122.patch
-      WORKING_DIRECTORY ${spirv-llvm-translator_SOURCE_DIR}
-      ERROR_QUIET
-      RESULT_VARIABLE PATCH_RESULT
+      COMMAND git apply ${CMAKE_CURRENT_LIST_DIR}/3122.patch
+      WORKING_DIRECTORY ${spirv-llvm-translator_SOURCE_DIR}
+      RESULT_VARIABLE PATCH_RESULT
     )
-    if(PATCH_RESULT EQUAL 0)
-      execute_process(
-        COMMAND git apply ${CMAKE_CURRENT_LIST_DIR}/3122.patch
-        WORKING_DIRECTORY ${spirv-llvm-translator_SOURCE_DIR}
-        RESULT_VARIABLE PATCH_RESULT
-      )
-    else()
-      execute_process( # Check if the patch is already applied
-        COMMAND git apply --reverse --check ${CMAKE_CURRENT_LIST_DIR}/3122.patch
-        WORKING_DIRECTORY ${spirv-llvm-translator_SOURCE_DIR}
-        RESULT_VARIABLE PATCH_RESULT
-      )
-    endif()
-    if(NOT PATCH_RESULT EQUAL 0)
-      message(FATAL_ERROR "Failed to apply 3122.patch to SPIRV-LLVM-Translator")
-    endif()
+  else()
+    execute_process( # Check if the patch is already applied
+      COMMAND git apply --reverse --check ${CMAKE_CURRENT_LIST_DIR}/3122.patch
+      WORKING_DIRECTORY ${spirv-llvm-translator_SOURCE_DIR}
+      RESULT_VARIABLE PATCH_RESULT
+    )
+  endif()
+  if(NOT PATCH_RESULT EQUAL 0)
+    message(FATAL_ERROR "Failed to apply 3122.patch to SPIRV-LLVM-Translator")
   endif()
 
   # FIXME: Don't apply patch when Agama driver is updated to incorporate with the SPV_INTEL_bfloat16_arithmetic extension.
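
After the change, patch handling no longer depends on an AGAMA_VERSION value: the build always probes with git apply --check, applies the patch when the probe succeeds, otherwise accepts the tree only if git apply --reverse --check shows the patch is already present, and fails the configure step in every other case. The same flow, sketched in Python purely for illustration (the build itself does this in CMake, as shown above):

import subprocess

def apply_patch_idempotently(patch: str, src_dir: str) -> None:
    # Mirror of the check/apply/reverse-check flow from
    # FindSPIRVToLLVMTranslator.cmake, written with subprocess for illustration.
    def git(*args: str) -> int:
        return subprocess.run(["git", *args], cwd=src_dir).returncode

    if git("apply", "--check", patch) == 0:
        result = git("apply", patch)                          # clean tree: apply the patch
    else:
        result = git("apply", "--reverse", "--check", patch)  # otherwise: already applied?
    if result != 0:
        raise RuntimeError(f"Failed to apply {patch}")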

third_party/intel/lib/Target/SPIRV/SPIRVTranslation.cpp

Lines changed: 1 addition & 2 deletions
@@ -107,7 +107,7 @@ class SmallVectorBuffer : public std::streambuf {
 
 static SPIRV::TranslatorOpts getSPIRVOpts() {
   SPIRV::TranslatorOpts SPIRVOpts{SPIRV::VersionNumber::SPIRV_1_4};
-  static constexpr std::array<SPIRV::ExtensionID, 19> AllowedExtensions{
+  static constexpr std::array<SPIRV::ExtensionID, 18> AllowedExtensions{
       SPIRV::ExtensionID::SPV_EXT_shader_atomic_float_add,
       SPIRV::ExtensionID::SPV_INTEL_2d_block_io,
       SPIRV::ExtensionID::SPV_INTEL_arbitrary_precision_integers,
@@ -124,7 +124,6 @@ static SPIRV::TranslatorOpts getSPIRVOpts() {
       SPIRV::ExtensionID::SPV_INTEL_tensor_float32_conversion,
       SPIRV::ExtensionID::SPV_INTEL_unstructured_loop_controls,
       SPIRV::ExtensionID::SPV_INTEL_vector_compute,
-      SPIRV::ExtensionID::SPV_KHR_bfloat16,
       SPIRV::ExtensionID::SPV_KHR_bit_instructions,
       SPIRV::ExtensionID::SPV_KHR_non_semantic_info};
 