diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index 5f758381f6..24c19ab42c 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -1352,8 +1352,6 @@ def test_atomic_rmw(op, dtype_x_str, mode, sem, device):
         pytest.xfail("Only test atomic bfloat16/float16 ops on GPU")
     if "uint" in dtype_x_str and mode in ["min_neg", "all_neg"]:
         pytest.xfail("uint cannot be negative")
-    if is_xpu() and dtype_x_str == 'bfloat16':
-        pytest.skip("bfloat16 not yet supported for xpu")
 
     n_programs = 5
 
@@ -1442,8 +1440,6 @@ def kernel(X):
                           for check_return_val in ([True, False] if is_hip() else [True])])
 def test_tensor_atomic_rmw(shape, axis, num_ctas, dtype_x_str, check_return_val, device):
     check_type_supported(dtype_x_str, device)
-    if is_xpu() and dtype_x_str == 'bfloat16':
-        pytest.skip("bfloat16 not yet supported for xpu")
 
     shape0, shape1 = shape
     # triton kernel
@@ -1523,8 +1519,6 @@ def torch_to_triton_dtype(t):
                          for dtype_x_str in ['bfloat16', 'float16', 'float32']])
 def test_tensor_atomic_add_non_exclusive_offset(size, num_ctas, dtype_x_str, device):
     check_type_supported(dtype_x_str, device)
-    if is_xpu() and dtype_x_str == 'bfloat16':
-        pytest.skip("bfloat16 not yet supported for xpu")
 
     @triton.jit
     def kernel(X, val, NUM: tl.constexpr):
@@ -1549,8 +1543,6 @@ def kernel(X, val, NUM: tl.constexpr):
                          for dtype_x_str in ['bfloat16', 'float16', 'float32']])
 def test_tensor_atomic_add_shift_1(size, num_ctas, dtype_x_str, device):
     check_type_supported(dtype_x_str, device)
-    if is_xpu() and dtype_x_str == 'bfloat16':
-        pytest.skip("bfloat16 not yet supported for xpu")
 
     @triton.jit
     def kernel(X, val, NUM: tl.constexpr):
@@ -1587,9 +1579,6 @@ def test_tensor_atomic_add_access_patterns(shape, idx_order, mask_step, num_ctas
     if is_interpreter():
         pytest.xfail("not supported in the interpreter")
 
-    if is_xpu() and dtype_x_str == 'bfloat16':
-        pytest.skip("bfloat16 not yet supported for xpu")
-
     @triton.jit
     def kernel(in_ptr, idx_ptr, out_ptr, shape0, shape1, mask_step, XBLOCK: tl.constexpr):
         xoffset = tl.program_id(0) * XBLOCK
diff --git a/python/test/unit/language/test_tensor_descriptor.py b/python/test/unit/language/test_tensor_descriptor.py
index 70448f05ca..40a5d71996 100644
--- a/python/test/unit/language/test_tensor_descriptor.py
+++ b/python/test/unit/language/test_tensor_descriptor.py
@@ -1566,9 +1566,6 @@ def test_tensor_descriptor_reduce(kind, descriptor, dtype_str, num_ctas, M_BLOCK
         pytest.xfail("Multi-CTA not supported")
     if is_hip_cdna3() and (kind, dtype_str, M_BLOCK, N_BLOCK) in REDUCE_SKIP_HIP_CDNA3:
         pytest.skip("Broken on rocm")
-    if is_xpu():
-        if (kind, dtype_str) in [("add", "bfloat16")]:
-            pytest.skip("FIXME: issue #3914")
 
     @triton.jit(debug=True)
     def kernel(out_desc, out_ptr, a_ptr, M, N, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr, kind: tl.constexpr):
diff --git a/third_party/intel/cmake/FindSPIRVToLLVMTranslator.cmake b/third_party/intel/cmake/FindSPIRVToLLVMTranslator.cmake
index c920862539..7053842cad 100644
--- a/third_party/intel/cmake/FindSPIRVToLLVMTranslator.cmake
+++ b/third_party/intel/cmake/FindSPIRVToLLVMTranslator.cmake
@@ -26,30 +26,6 @@ if (NOT SPIRVToLLVMTranslator_FOUND)
 
   FetchContent_MakeAvailable(spirv-llvm-translator)
 
-  # FIXME: Don't apply patch when Agama driver is updated.
-  execute_process(
-    COMMAND git apply --check ${CMAKE_CURRENT_LIST_DIR}/3122.patch
-    WORKING_DIRECTORY ${spirv-llvm-translator_SOURCE_DIR}
-    ERROR_QUIET
-    RESULT_VARIABLE PATCH_RESULT
-  )
-  if(PATCH_RESULT EQUAL 0)
-    execute_process(
-      COMMAND git apply ${CMAKE_CURRENT_LIST_DIR}/3122.patch
-      WORKING_DIRECTORY ${spirv-llvm-translator_SOURCE_DIR}
-      RESULT_VARIABLE PATCH_RESULT
-    )
-  else()
-    execute_process( # Check if the patch is already applied
-      COMMAND git apply --reverse --check ${CMAKE_CURRENT_LIST_DIR}/3122.patch
-      WORKING_DIRECTORY ${spirv-llvm-translator_SOURCE_DIR}
-      RESULT_VARIABLE PATCH_RESULT
-    )
-  endif()
-  if(NOT PATCH_RESULT EQUAL 0)
-    message(FATAL_ERROR "Failed to apply 3122.patch to SPIRV-LLVM-Translator")
-  endif()
-
   # FIXME: Don't apply patch when Agama driver is updated to incorporate with the SPV_INTEL_bfloat16_arithmetic extension.
   execute_process(
     COMMAND git apply --check ${CMAKE_CURRENT_LIST_DIR}/3388.patch
diff --git a/third_party/intel/lib/Target/SPIRV/SPIRVTranslation.cpp b/third_party/intel/lib/Target/SPIRV/SPIRVTranslation.cpp
index 8698fc0b42..7d58801e9c 100644
--- a/third_party/intel/lib/Target/SPIRV/SPIRVTranslation.cpp
+++ b/third_party/intel/lib/Target/SPIRV/SPIRVTranslation.cpp
@@ -107,7 +107,7 @@ class SmallVectorBuffer : public std::streambuf {
 
 static SPIRV::TranslatorOpts getSPIRVOpts() {
   SPIRV::TranslatorOpts SPIRVOpts{SPIRV::VersionNumber::SPIRV_1_4};
-  static constexpr std::array AllowedExtensions{
+  static constexpr std::array AllowedExtensions{
       SPIRV::ExtensionID::SPV_EXT_shader_atomic_float_add,
       SPIRV::ExtensionID::SPV_INTEL_2d_block_io,
       SPIRV::ExtensionID::SPV_INTEL_arbitrary_precision_integers,
@@ -124,6 +124,7 @@ static SPIRV::TranslatorOpts getSPIRVOpts() {
       SPIRV::ExtensionID::SPV_INTEL_tensor_float32_conversion,
       SPIRV::ExtensionID::SPV_INTEL_unstructured_loop_controls,
       SPIRV::ExtensionID::SPV_INTEL_vector_compute,
+      SPIRV::ExtensionID::SPV_KHR_bfloat16,
      SPIRV::ExtensionID::SPV_KHR_bit_instructions,
      SPIRV::ExtensionID::SPV_KHR_non_semantic_info};
 
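The skips removed above all involve bfloat16 atomic RMW on XPU, which this change enables by allowing SPV_KHR_bfloat16 in the SPIR-V translator. Below is a minimal sketch, not part of the diff, of the kind of bfloat16 tl.atomic_add these tests exercise; the kernel name, grid, and the "xpu" device string are illustrative assumptions.

    import torch
    import triton
    import triton.language as tl


    @triton.jit
    def atomic_add_bf16_kernel(x_ptr, val, NUM: tl.constexpr):
        # All NUM lanes target element 0, so the adds must be performed atomically.
        offsets = tl.zeros([NUM], dtype=tl.int32)
        vals = tl.full([NUM], val, tl.bfloat16)
        tl.atomic_add(x_ptr + offsets, vals)


    # Illustrative usage on an XPU device with this change applied (assumption):
    x = torch.zeros(1, dtype=torch.bfloat16, device="xpu")
    atomic_add_bf16_kernel[(1, )](x, 1.0, NUM=128)
    assert x.item() == 128.0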