Skip to content

Commit 5fd91bf

Browse files
committed
[SYCL][Matrix] Add W/A for several corner cases of AccessChain usage
These corner cases are: 1. AccessChain uses are optimized out of LLVM IR modules, leaving the call unused; 2. The AccessChain result is used in a GEP 0,0 instruction for bfloat16 (instead of being used immediately by a load or store). All of these issues are or will be fixed in our drivers, but since driver updates are released relatively infrequently, the W/A is added in the frontend for an immediate fix. Signed-off-by: Sidorov, Dmitry <[email protected]>
1 parent e94cfda commit 5fd91bf

File tree

4 files changed

+83
-4
lines changed

4 files changed

+83
-4
lines changed

llvm/lib/SYCLLowerIR/SYCLJointMatrixTransform.cpp

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,16 +22,45 @@ namespace {
2222
static constexpr char ACCESS_CHAIN[] = "_Z19__spirv_AccessChain";
2323
static constexpr char MATRIX_TYPE[] = "spirv.CooperativeMatrixKHR";
2424

25-
// This routine extracts spirv.CooperativeMatrixKHR target extension type
26-
// from sycl::joint_matrix class object if it's used in __spirv_AccessChain
27-
// function call. It's necessary because otherwise OpAccessChain indices would
28-
// be wrong.
25+
// This function finds all calls to __spirv_AccessChain function and transforms
26+
// its users and operands to make LLVM IR more SPIR-V friendly.
2927
bool transformAccessChain(Function *F) {
3028
bool ModuleChanged = false;
3129
for (auto I : F->users()) {
3230
auto *CI = dyn_cast<CallInst>(I);
3331
if (!CI)
3432
continue;
33+
34+
// This is a W/A for bfloat16 and tf32 types - they are represented in SYCL
35+
// as structures with int16/float storage. It means that in LLVM IR
36+
// the user of a CallInst to __spirv_AccessChain would not be a load/store
37+
// instruction, but a zero GEP. This zero GEP is no-op, but can confuse a
38+
// SPIR-V consumer, so let's remove it here.
39+
auto *Unique = CI->getUniqueUndroppableUser();
40+
if (auto *CastCand = dyn_cast_or_null<Instruction>(Unique)) {
41+
if (auto *GEP = dyn_cast<GetElementPtrInst>(CastCand)) {
42+
if (GEP->hasAllZeroIndices()) {
43+
GEP->replaceAllUsesWith(CI);
44+
GEP->dropAllReferences();
45+
GEP->eraseFromParent();
46+
}
47+
}
48+
}
49+
50+
// It can happen that the optimizer can remove duplicated or dead uses
51+
// of CallInst to __spirv_AccessChain function. But it can't remove
52+
// __spirv_AccessChain call itself as it's a call to an external function.
53+
// Let's clean up such calls.
54+
if (CI->getNumUses() == 0) {
55+
CI->dropAllReferences();
56+
CI->eraseFromParent();
57+
continue;
58+
}
59+
60+
// This routine extracts spirv.CooperativeMatrixKHR target extension type
61+
// from sycl::joint_matrix class object if it's used in __spirv_AccessChain
62+
// function call. It's necessary because otherwise OpAccessChain indices
63+
// would be wrong.
3564
Instruction *Ptr =
3665
dyn_cast<Instruction>(CI->getArgOperand(0)->stripPointerCasts());
3766
if (!Ptr || !isa<AllocaInst>(Ptr))
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
; The test checks that an unused call to __spirv_AccessChain is eliminated
2+
3+
; RUN: opt -passes=sycl-joint-matrix-transform < %s -S | FileCheck %s
4+
5+
; CHECK-NOT: call spir_func ptr addrspace(4) @_Z19__spirv_AccessChain
6+
7+
; ModuleID = 'test.bc'
8+
source_filename = "test.cpp"
9+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64-G1"
10+
target triple = "spir64-unknown-unknown"
11+
12+
%"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix" = type { target("spirv.CooperativeMatrixKHR", i8, 3, 16, 64, 0) }
13+
14+
define weak_odr dso_local spir_kernel void @test() {
15+
entry:
16+
%0 = alloca %"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix", align 8
17+
%1 = addrspacecast ptr %0 to ptr addrspace(4)
18+
%2 = call spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0_XT4_EXT1_EXT2_EXT3_EEEm(ptr addrspace(4) noundef %1, i64 noundef 0)
19+
ret void
20+
}
21+
22+
declare dso_local spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0_XT4_EXT1_EXT2_EXT3_EEEm(ptr addrspace(4) noundef, i64 noundef)

llvm/test/SYCLLowerIR/JointMatrixTransform/access_chain.ll

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ entry:
1919
%0 = alloca %"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix", align 8
2020
%1 = addrspacecast ptr %0 to ptr addrspace(4)
2121
%2 = call spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0_XT4_EXT1_EXT2_EXT3_EEEm(ptr addrspace(4) noundef %1, i64 noundef 0)
22+
%3 = load i8, ptr addrspace(4) %2
2223
ret void
2324
}
2425

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
; The test checks that the useless zero GEP used to get i16 from sycl::bfloat16 is removed
2+
3+
; RUN: opt -passes=sycl-joint-matrix-transform < %s -S | FileCheck %s
4+
5+
; CHECK: %[[#Alloca:]] = alloca target("spirv.CooperativeMatrixKHR", i16, 3, 16, 64, 0)
6+
; CHECK: %[[#Cast:]] = addrspacecast ptr %[[#Alloca]] to ptr addrspace(4)
7+
; CHECK: %[[#AC:]] = call spir_func ptr addrspace(4) @_Z19__spirv_AccessChain{{.*}}(ptr addrspace(4) noundef %[[#Cast]], i64 noundef 0)
8+
; CHECK: load i16, ptr addrspace(4) %[[#AC]]
9+
10+
; ModuleID = 'test.bc'
11+
source_filename = "test.cpp"
12+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64-G1"
13+
target triple = "spir64-unknown-unknown"
14+
15+
%"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix" = type { target("spirv.CooperativeMatrixKHR", i16, 3, 16, 64, 0) }
16+
17+
define weak_odr dso_local spir_kernel void @test() {
18+
entry:
19+
%0 = alloca %"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix", align 8
20+
%1 = addrspacecast ptr %0 to ptr addrspace(4)
21+
%2 = call spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0_XT4_EXT1_EXT2_EXT3_EEEm(ptr addrspace(4) noundef %1, i64 noundef 0)
22+
%3 = getelementptr inbounds { i16 }, ptr addrspace(4) %2, i64 0, i32 0
23+
%4 = load i16, ptr addrspace(4) %3
24+
ret void
25+
}
26+
27+
declare dso_local spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0_XT4_EXT1_EXT2_EXT3_EEEm(ptr addrspace(4) noundef, i64 noundef)

0 commit comments

Comments
 (0)