Skip to content

Commit 91cac59

Browse files
Enable bfloat16 unit tests (#5163)
Fixes #3914 --------- Signed-off-by: Whitney Tsang <[email protected]>
1 parent 37bc4ad commit 91cac59

File tree

6 files changed

+40
-35
lines changed

6 files changed

+40
-35
lines changed

python/test/unit/language/test_core.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1352,8 +1352,6 @@ def test_atomic_rmw(op, dtype_x_str, mode, sem, device):
13521352
pytest.xfail("Only test atomic bfloat16/float16 ops on GPU")
13531353
if "uint" in dtype_x_str and mode in ["min_neg", "all_neg"]:
13541354
pytest.xfail("uint cannot be negative")
1355-
if is_xpu() and dtype_x_str == 'bfloat16':
1356-
pytest.skip("bfloat16 not yet supported for xpu")
13571355

13581356
n_programs = 5
13591357

@@ -1442,8 +1440,6 @@ def kernel(X):
14421440
for check_return_val in ([True, False] if is_hip() else [True])])
14431441
def test_tensor_atomic_rmw(shape, axis, num_ctas, dtype_x_str, check_return_val, device):
14441442
check_type_supported(dtype_x_str, device)
1445-
if is_xpu() and dtype_x_str == 'bfloat16':
1446-
pytest.skip("bfloat16 not yet supported for xpu")
14471443
shape0, shape1 = shape
14481444
# triton kernel
14491445

@@ -1523,8 +1519,6 @@ def torch_to_triton_dtype(t):
15231519
for dtype_x_str in ['bfloat16', 'float16', 'float32']])
15241520
def test_tensor_atomic_add_non_exclusive_offset(size, num_ctas, dtype_x_str, device):
15251521
check_type_supported(dtype_x_str, device)
1526-
if is_xpu() and dtype_x_str == 'bfloat16':
1527-
pytest.skip("bfloat16 not yet supported for xpu")
15281522

15291523
@triton.jit
15301524
def kernel(X, val, NUM: tl.constexpr):
@@ -1549,8 +1543,6 @@ def kernel(X, val, NUM: tl.constexpr):
15491543
for dtype_x_str in ['bfloat16', 'float16', 'float32']])
15501544
def test_tensor_atomic_add_shift_1(size, num_ctas, dtype_x_str, device):
15511545
check_type_supported(dtype_x_str, device)
1552-
if is_xpu() and dtype_x_str == 'bfloat16':
1553-
pytest.skip("bfloat16 not yet supported for xpu")
15541546

15551547
@triton.jit
15561548
def kernel(X, val, NUM: tl.constexpr):
@@ -1587,9 +1579,6 @@ def test_tensor_atomic_add_access_patterns(shape, idx_order, mask_step, num_ctas
15871579
if is_interpreter():
15881580
pytest.xfail("not supported in the interpreter")
15891581

1590-
if is_xpu() and dtype_x_str == 'bfloat16':
1591-
pytest.skip("bfloat16 not yet supported for xpu")
1592-
15931582
@triton.jit
15941583
def kernel(in_ptr, idx_ptr, out_ptr, shape0, shape1, mask_step, XBLOCK: tl.constexpr):
15951584
xoffset = tl.program_id(0) * XBLOCK

python/test/unit/language/test_tensor_descriptor.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1566,9 +1566,6 @@ def test_tensor_descriptor_reduce(kind, descriptor, dtype_str, num_ctas, M_BLOCK
15661566
pytest.xfail("Multi-CTA not supported")
15671567
if is_hip_cdna3() and (kind, dtype_str, M_BLOCK, N_BLOCK) in REDUCE_SKIP_HIP_CDNA3:
15681568
pytest.skip("Broken on rocm")
1569-
if is_xpu():
1570-
if (kind, dtype_str) in [("add", "bfloat16")]:
1571-
pytest.skip("FIXME: issue #3914")
15721569

15731570
@triton.jit(debug=True)
15741571
def kernel(out_desc, out_ptr, a_ptr, M, N, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr, kind: tl.constexpr):

scripts/skiplist/lts/language.txt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,10 @@
11
# https://github.com/intel/intel-xpu-backend-for-triton/issues/4665
22
python/test/unit/language/test_core.py::test_dot3d[8-1-32-32-32-32-32-float64-float64]
33
python/test/unit/language/test_core.py::test_dot3d[4-1-64-64-64-32-32-float64-float64]
4+
# Below bfloat16 tests require IGC 1188 or above
5+
python/test/unit/language/test_core.py::test_atomic_rmw[r".*bfloat16.*"]@regexp
6+
python/test/unit/language/test_core.py::test_tensor_atomic_rmw[r".*bfloat16.*"]@regexp
7+
python/test/unit/language/test_core.py::test_tensor_atomic_add_non_exclusive_offset[r".*bfloat16.*"]@regexp
8+
python/test/unit/language/test_core.py::test_tensor_atomic_add_shift_1[r".*bfloat16.*"]@regexp
9+
python/test/unit/language/test_core.py::test_tensor_atomic_add_access_patterns[r".*bfloat16.*"]@regexp
10+
python/test/unit/language/test_tensor_descriptor.py::test_tensor_descriptor_reduce[r".*bfloat16.*"]@regexp

setup.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -481,6 +481,15 @@ def build_extension(self, ext):
481481
cmake_args.append("-DLLVM_EXTERNAL_LIT=" + lit_dir)
482482
cmake_args.extend(thirdparty_cmake_args)
483483

484+
result = subprocess.run(["bash", "./scripts/capture-hw-details.sh"], stdout=subprocess.PIPE,
485+
stderr=subprocess.PIPE, check=True, text=True, env=os.environ.copy())
486+
agama_version = None
487+
for line in result.stdout.splitlines():
488+
if line.startswith("AGAMA_VERSION="):
489+
agama_version = line.split("=", 1)[1].strip()
490+
break
491+
cmake_args.append(f"-DAGAMA_VERSION={agama_version}")
492+
484493
# configuration
485494
cfg = get_build_type()
486495
build_args = ["--config", cfg]

third_party/intel/cmake/FindSPIRVToLLVMTranslator.cmake

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -26,28 +26,30 @@ if (NOT SPIRVToLLVMTranslator_FOUND)
2626

2727
FetchContent_MakeAvailable(spirv-llvm-translator)
2828

29-
# FIXME: Don't apply patch when Agama driver is updated.
30-
execute_process(
31-
COMMAND git apply --check ${CMAKE_CURRENT_LIST_DIR}/3122.patch
32-
WORKING_DIRECTORY ${spirv-llvm-translator_SOURCE_DIR}
33-
ERROR_QUIET
34-
RESULT_VARIABLE PATCH_RESULT
35-
)
36-
if(PATCH_RESULT EQUAL 0)
29+
# FIXME: Don't apply patch when LTS driver is updated.
30+
if(DEFINED AGAMA_VERSION AND AGAMA_VERSION STREQUAL "1146")
3731
execute_process(
38-
COMMAND git apply ${CMAKE_CURRENT_LIST_DIR}/3122.patch
39-
WORKING_DIRECTORY ${spirv-llvm-translator_SOURCE_DIR}
40-
RESULT_VARIABLE PATCH_RESULT
32+
COMMAND git apply --check ${CMAKE_CURRENT_LIST_DIR}/3122.patch
33+
WORKING_DIRECTORY ${spirv-llvm-translator_SOURCE_DIR}
34+
ERROR_QUIET
35+
RESULT_VARIABLE PATCH_RESULT
4136
)
42-
else()
43-
execute_process( # Check if the patch is already applied
44-
COMMAND git apply --reverse --check ${CMAKE_CURRENT_LIST_DIR}/3122.patch
45-
WORKING_DIRECTORY ${spirv-llvm-translator_SOURCE_DIR}
46-
RESULT_VARIABLE PATCH_RESULT
47-
)
48-
endif()
49-
if(NOT PATCH_RESULT EQUAL 0)
50-
message(FATAL_ERROR "Failed to apply 3122.patch to SPIRV-LLVM-Translator")
37+
if(PATCH_RESULT EQUAL 0)
38+
execute_process(
39+
COMMAND git apply ${CMAKE_CURRENT_LIST_DIR}/3122.patch
40+
WORKING_DIRECTORY ${spirv-llvm-translator_SOURCE_DIR}
41+
RESULT_VARIABLE PATCH_RESULT
42+
)
43+
else()
44+
execute_process( # Check if the patch is already applied
45+
COMMAND git apply --reverse --check ${CMAKE_CURRENT_LIST_DIR}/3122.patch
46+
WORKING_DIRECTORY ${spirv-llvm-translator_SOURCE_DIR}
47+
RESULT_VARIABLE PATCH_RESULT
48+
)
49+
endif()
50+
if(NOT PATCH_RESULT EQUAL 0)
51+
message(FATAL_ERROR "Failed to apply 3122.patch to SPIRV-LLVM-Translator")
52+
endif()
5153
endif()
5254

5355
# FIXME: Don't apply patch when Agama driver is updated to incorporate with the SPV_INTEL_bfloat16_arithmetic extension.

third_party/intel/lib/Target/SPIRV/SPIRVTranslation.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ class SmallVectorBuffer : public std::streambuf {
107107

108108
static SPIRV::TranslatorOpts getSPIRVOpts() {
109109
SPIRV::TranslatorOpts SPIRVOpts{SPIRV::VersionNumber::SPIRV_1_4};
110-
static constexpr std::array<SPIRV::ExtensionID, 18> AllowedExtensions{
110+
static constexpr std::array<SPIRV::ExtensionID, 19> AllowedExtensions{
111111
SPIRV::ExtensionID::SPV_EXT_shader_atomic_float_add,
112112
SPIRV::ExtensionID::SPV_INTEL_2d_block_io,
113113
SPIRV::ExtensionID::SPV_INTEL_arbitrary_precision_integers,
@@ -124,6 +124,7 @@ static SPIRV::TranslatorOpts getSPIRVOpts() {
124124
SPIRV::ExtensionID::SPV_INTEL_tensor_float32_conversion,
125125
SPIRV::ExtensionID::SPV_INTEL_unstructured_loop_controls,
126126
SPIRV::ExtensionID::SPV_INTEL_vector_compute,
127+
SPIRV::ExtensionID::SPV_KHR_bfloat16,
127128
SPIRV::ExtensionID::SPV_KHR_bit_instructions,
128129
SPIRV::ExtensionID::SPV_KHR_non_semantic_info};
129130

0 commit comments

Comments (0)