python/test/unit/language/test_core.py (0 additions, 11 deletions)

@@ -1352,8 +1352,6 @@ def test_atomic_rmw(op, dtype_x_str, mode, sem, device):
         pytest.xfail("Only test atomic bfloat16/float16 ops on GPU")
     if "uint" in dtype_x_str and mode in ["min_neg", "all_neg"]:
         pytest.xfail("uint cannot be negative")
-    if is_xpu() and dtype_x_str == 'bfloat16':
-        pytest.skip("bfloat16 not yet supported for xpu")

     n_programs = 5

@@ -1442,8 +1440,6 @@ def kernel(X):
      for check_return_val in ([True, False] if is_hip() else [True])])
 def test_tensor_atomic_rmw(shape, axis, num_ctas, dtype_x_str, check_return_val, device):
     check_type_supported(dtype_x_str, device)
-    if is_xpu() and dtype_x_str == 'bfloat16':
-        pytest.skip("bfloat16 not yet supported for xpu")
     shape0, shape1 = shape
     # triton kernel

@@ -1523,8 +1519,6 @@ def torch_to_triton_dtype(t):
      for dtype_x_str in ['bfloat16', 'float16', 'float32']])
 def test_tensor_atomic_add_non_exclusive_offset(size, num_ctas, dtype_x_str, device):
     check_type_supported(dtype_x_str, device)
-    if is_xpu() and dtype_x_str == 'bfloat16':
-        pytest.skip("bfloat16 not yet supported for xpu")

     @triton.jit
     def kernel(X, val, NUM: tl.constexpr):

@@ -1549,8 +1543,6 @@ def kernel(X, val, NUM: tl.constexpr):
      for dtype_x_str in ['bfloat16', 'float16', 'float32']])
 def test_tensor_atomic_add_shift_1(size, num_ctas, dtype_x_str, device):
     check_type_supported(dtype_x_str, device)
-    if is_xpu() and dtype_x_str == 'bfloat16':
-        pytest.skip("bfloat16 not yet supported for xpu")

     @triton.jit
     def kernel(X, val, NUM: tl.constexpr):

@@ -1587,9 +1579,6 @@ def test_tensor_atomic_add_access_patterns(shape, idx_order, mask_step, num_ctas
     if is_interpreter():
         pytest.xfail("not supported in the interpreter")

-    if is_xpu() and dtype_x_str == 'bfloat16':
-        pytest.skip("bfloat16 not yet supported for xpu")
-
     @triton.jit
     def kernel(in_ptr, idx_ptr, out_ptr, shape0, shape1, mask_step, XBLOCK: tl.constexpr):
         xoffset = tl.program_id(0) * XBLOCK
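All of the skips removed above guarded bfloat16 atomic RMW operations on the XPU backend. For orientation, here is a minimal sketch (not taken from the test suite) of the kind of bfloat16 tl.atomic_add these tests now exercise; the kernel name, shapes, and driver function are illustrative, and it assumes a Triton build with a working XPU (or other GPU) backend:

    # Illustrative only: a bfloat16 tl.atomic_add similar to what the
    # re-enabled tests cover; device/shape choices are assumptions.
    import torch
    import triton
    import triton.language as tl


    @triton.jit
    def atomic_add_kernel(X, Y, N: tl.constexpr):
        # Atomically accumulate N bfloat16 values from Y into X.
        offsets = tl.arange(0, N)
        vals = tl.load(Y + offsets)
        tl.atomic_add(X + offsets, vals)


    def run(device="xpu", n=128):
        x = torch.zeros(n, dtype=torch.bfloat16, device=device)
        y = torch.rand(n, dtype=torch.bfloat16, device=device)
        atomic_add_kernel[(1, )](x, y, N=n)
        torch.testing.assert_close(x, y)

On XPU, a case like this previously hit the pytest.skip paths deleted in this diff.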
python/test/unit/language/test_tensor_descriptor.py (0 additions, 3 deletions)

@@ -1566,9 +1566,6 @@ def test_tensor_descriptor_reduce(kind, descriptor, dtype_str, num_ctas, M_BLOCK
         pytest.xfail("Multi-CTA not supported")
     if is_hip_cdna3() and (kind, dtype_str, M_BLOCK, N_BLOCK) in REDUCE_SKIP_HIP_CDNA3:
         pytest.skip("Broken on rocm")
-    if is_xpu():
-        if (kind, dtype_str) in [("add", "bfloat16")]:
-            pytest.skip("FIXME: issue #3914")

     @triton.jit(debug=True)
     def kernel(out_desc, out_ptr, a_ptr, M, N, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr, kind: tl.constexpr):
third_party/intel/cmake/FindSPIRVToLLVMTranslator.cmake (0 additions, 24 deletions)

@@ -26,30 +26,6 @@ if (NOT SPIRVToLLVMTranslator_FOUND)

   FetchContent_MakeAvailable(spirv-llvm-translator)

-  # FIXME: Don't apply patch when Agama driver is updated.
-  execute_process(
-    COMMAND git apply --check ${CMAKE_CURRENT_LIST_DIR}/3122.patch
-    WORKING_DIRECTORY ${spirv-llvm-translator_SOURCE_DIR}
-    ERROR_QUIET
-    RESULT_VARIABLE PATCH_RESULT
-  )
-  if(PATCH_RESULT EQUAL 0)
-    execute_process(
-      COMMAND git apply ${CMAKE_CURRENT_LIST_DIR}/3122.patch
-      WORKING_DIRECTORY ${spirv-llvm-translator_SOURCE_DIR}
-      RESULT_VARIABLE PATCH_RESULT
-    )
-  else()
-    execute_process( # Check if the patch is already applied
-      COMMAND git apply --reverse --check ${CMAKE_CURRENT_LIST_DIR}/3122.patch
-      WORKING_DIRECTORY ${spirv-llvm-translator_SOURCE_DIR}
-      RESULT_VARIABLE PATCH_RESULT
-    )
-  endif()
-  if(NOT PATCH_RESULT EQUAL 0)
-    message(FATAL_ERROR "Failed to apply 3122.patch to SPIRV-LLVM-Translator")
-  endif()
-
   # FIXME: Don't apply patch when Agama driver is updated to incorporate with the SPV_INTEL_bfloat16_arithmetic extension.
   execute_process(
     COMMAND git apply --check ${CMAKE_CURRENT_LIST_DIR}/3388.patch

Contributor review comment on the removed 3122.patch block: "Removing old hack now that we have a new driver."
third_party/intel/lib/Target/SPIRV/SPIRVTranslation.cpp (2 additions, 1 deletion)

@@ -107,7 +107,7 @@ class SmallVectorBuffer : public std::streambuf {

 static SPIRV::TranslatorOpts getSPIRVOpts() {
   SPIRV::TranslatorOpts SPIRVOpts{SPIRV::VersionNumber::SPIRV_1_4};
-  static constexpr std::array<SPIRV::ExtensionID, 18> AllowedExtensions{
+  static constexpr std::array<SPIRV::ExtensionID, 19> AllowedExtensions{
       SPIRV::ExtensionID::SPV_EXT_shader_atomic_float_add,
       SPIRV::ExtensionID::SPV_INTEL_2d_block_io,
       SPIRV::ExtensionID::SPV_INTEL_arbitrary_precision_integers,

@@ -124,6 +124,7 @@ static SPIRV::TranslatorOpts getSPIRVOpts() {
       SPIRV::ExtensionID::SPV_INTEL_tensor_float32_conversion,
       SPIRV::ExtensionID::SPV_INTEL_unstructured_loop_controls,
       SPIRV::ExtensionID::SPV_INTEL_vector_compute,
+      SPIRV::ExtensionID::SPV_KHR_bfloat16,
       SPIRV::ExtensionID::SPV_KHR_bit_instructions,
       SPIRV::ExtensionID::SPV_KHR_non_semantic_info};