Skip to content

Commit d16677b

Browse files
Change SPV_KHR_bfloat16 workaround patch to work for different drivers (#5585)
Check the driver version at runtime instead of build time. Fixes #5574 --------- Signed-off-by: Whitney Tsang <[email protected]>
1 parent 414f7b7 commit d16677b

File tree

4 files changed

+47
-41
lines changed

4 files changed

+47
-41
lines changed

setup.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -489,15 +489,6 @@ def build_extension(self, ext):
489489
cmake_args.append("-DLLVM_EXTERNAL_LIT=" + lit_dir)
490490
cmake_args.extend(thirdparty_cmake_args)
491491

492-
result = subprocess.run(["bash", "./scripts/capture-hw-details.sh"], stdout=subprocess.PIPE,
493-
stderr=subprocess.PIPE, check=True, text=True, env=os.environ.copy())
494-
agama_version = None
495-
for line in result.stdout.splitlines():
496-
if line.startswith("AGAMA_VERSION="):
497-
agama_version = line.split("=", 1)[1].strip()
498-
break
499-
cmake_args.append(f"-DAGAMA_VERSION={agama_version}")
500-
501492
# configuration
502493
cfg = get_build_type()
503494
build_args = ["--config", cfg]

third_party/intel/backend/compiler.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,8 @@ def make_llir(cls, src, metadata, options):
402402
@classmethod
403403
@track
404404
def make_spv(cls, src, metadata, options):
405+
driver_version = metadata["target"].arch.get("driver_version")
406+
os.environ["INTEL_XPU_BACKEND_DRIVER_VERSION"] = driver_version
405407
spirv, name = intel.translate_to_spirv(src)
406408
metadata["name"] = name
407409
metadata.setdefault("build_flags", "")

third_party/intel/cmake/3122.patch

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,35 @@
11
diff --git a/lib/SPIRV/SPIRVWriter.cpp b/lib/SPIRV/SPIRVWriter.cpp
2-
index a124ba48c..3f46b5685 100644
2+
index ec4ec41f..b481609f 100644
33
--- a/lib/SPIRV/SPIRVWriter.cpp
44
+++ b/lib/SPIRV/SPIRVWriter.cpp
5-
@@ -397,6 +397,7 @@ SPIRVType *LLVMToSPIRVBase::transType(Type *T) {
6-
}
5+
@@ -401,13 +401,23 @@ SPIRVType *LLVMToSPIRVBase::transType(Type *T) {
76
}
87

9-
+#if 0
108
if (T->isBFloatTy()) {
11-
BM->getErrorLog().checkError(
12-
BM->isAllowedToUseExtension(ExtensionID::SPV_KHR_bfloat16),
13-
@@ -406,6 +407,7 @@ SPIRVType *LLVMToSPIRVBase::transType(Type *T) {
14-
"requires this extension");
15-
return mapType(T, BM->addFloatType(16, FPEncodingBFloat16KHR));
9+
- BM->getErrorLog().checkError(
10+
- BM->isAllowedToUseExtension(ExtensionID::SPV_KHR_bfloat16),
11+
- SPIRVEC_RequiresExtension,
12+
- "SPV_KHR_bfloat16\n"
13+
- "NOTE: LLVM module contains bfloat type, translation of which "
14+
- "requires this extension");
15+
- return mapType(T, BM->addFloatType(16, FPEncodingBFloat16KHR));
16+
+ // Workaround for LTS2 driver.
17+
+ const char *driverVersion = std::getenv("INTEL_XPU_BACKEND_DRIVER_VERSION");
18+
+ if (driverVersion) {
19+
+ int v0 = 0, v1 = 0, v2 = 0, v3 = 0;
20+
+ sscanf(driverVersion, "%d.%d.%d+%d", &v0, &v1, &v2, &v3);
21+
+ std::tuple<int, int, int, int> ver = {v0, v1, v2, v3};
22+
+ std::tuple<int, int, int, int> minVer = {1, 6, 35096, 9};
23+
+ if (ver >= minVer) {
24+
+ BM->getErrorLog().checkError(
25+
+ BM->isAllowedToUseExtension(ExtensionID::SPV_KHR_bfloat16),
26+
+ SPIRVEC_RequiresExtension,
27+
+ "SPV_KHR_bfloat16\n"
28+
+ "NOTE: LLVM module contains bfloat type, translation of which "
29+
+ "requires this extension");
30+
+ return mapType(T, BM->addFloatType(16, FPEncodingBFloat16KHR));
31+
+ }
32+
+ }
1633
}
17-
+#endif
1834

1935
if (T->isFloatingPointTy())
20-
return mapType(T, BM->addFloatType(T->getPrimitiveSizeInBits()));

third_party/intel/cmake/FindSPIRVToLLVMTranslator.cmake

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -27,29 +27,27 @@ if (NOT SPIRVToLLVMTranslator_FOUND)
2727
FetchContent_MakeAvailable(spirv-llvm-translator)
2828

2929
# FIXME: Don't apply patch when LTS driver is updated.
30-
if(DEFINED AGAMA_VERSION AND AGAMA_VERSION STREQUAL "1146")
30+
execute_process(
31+
COMMAND git apply --check ${CMAKE_CURRENT_LIST_DIR}/3122.patch
32+
WORKING_DIRECTORY ${spirv-llvm-translator_SOURCE_DIR}
33+
ERROR_QUIET
34+
RESULT_VARIABLE PATCH_RESULT
35+
)
36+
if(PATCH_RESULT EQUAL 0)
3137
execute_process(
32-
COMMAND git apply --check ${CMAKE_CURRENT_LIST_DIR}/3122.patch
33-
WORKING_DIRECTORY ${spirv-llvm-translator_SOURCE_DIR}
34-
ERROR_QUIET
35-
RESULT_VARIABLE PATCH_RESULT
38+
COMMAND git apply ${CMAKE_CURRENT_LIST_DIR}/3122.patch
39+
WORKING_DIRECTORY ${spirv-llvm-translator_SOURCE_DIR}
40+
RESULT_VARIABLE PATCH_RESULT
3641
)
37-
if(PATCH_RESULT EQUAL 0)
38-
execute_process(
39-
COMMAND git apply ${CMAKE_CURRENT_LIST_DIR}/3122.patch
40-
WORKING_DIRECTORY ${spirv-llvm-translator_SOURCE_DIR}
41-
RESULT_VARIABLE PATCH_RESULT
42-
)
43-
else()
44-
execute_process( # Check if the patch is already applied
45-
COMMAND git apply --reverse --check ${CMAKE_CURRENT_LIST_DIR}/3122.patch
46-
WORKING_DIRECTORY ${spirv-llvm-translator_SOURCE_DIR}
47-
RESULT_VARIABLE PATCH_RESULT
48-
)
49-
endif()
50-
if(NOT PATCH_RESULT EQUAL 0)
51-
message(FATAL_ERROR "Failed to apply 3122.patch to SPIRV-LLVM-Translator")
52-
endif()
42+
else()
43+
execute_process( # Check if the patch is already applied
44+
COMMAND git apply --reverse --check ${CMAKE_CURRENT_LIST_DIR}/3122.patch
45+
WORKING_DIRECTORY ${spirv-llvm-translator_SOURCE_DIR}
46+
RESULT_VARIABLE PATCH_RESULT
47+
)
48+
endif()
49+
if(NOT PATCH_RESULT EQUAL 0)
50+
message(FATAL_ERROR "Failed to apply 3122.patch to SPIRV-LLVM-Translator")
5351
endif()
5452

5553
# FIXME: Don't apply patch when Agama driver is updated to incorporate with the SPV_INTEL_bfloat16_arithmetic extension.

0 commit comments

Comments
 (0)