Skip to content

Commit 57dbccd

Browse files
committed
[SYCL][Matrix] Extend W/A for more corner cases of AccessChain usage
The new corner case is: AccessChain is used on arrays of Joint Matrices and in loops
1 parent 2824f61 commit 57dbccd

File tree

2 files changed

+166
-29
lines changed

2 files changed

+166
-29
lines changed

llvm/lib/SYCLLowerIR/SYCLJointMatrixTransform.cpp

Lines changed: 111 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,73 @@ namespace {
2222
static constexpr char ACCESS_CHAIN[] = "_Z19__spirv_AccessChain";
2323
static constexpr char MATRIX_TYPE[] = "spirv.CooperativeMatrixKHR";
2424

25+
// Peels every enclosing array layer off Ty and returns the first element type
// that is not itself an array. A non-array Ty is returned unchanged.
Type *getInnermostType(Type *Ty) {
  for (;;) {
    auto *AsArray = dyn_cast<ArrayType>(Ty);
    if (!AsArray)
      return Ty;
    // Descend one nesting level and look again.
    Ty = AsArray->getElementType();
  }
}
30+
31+
// Rebuilds Ty with the same array nesting and the same element counts, but
// with NewInnermostTy substituted for the innermost (non-array) element type.
// When Ty is not an array at all, the result is simply NewInnermostTy.
Type *replaceInnermostType(Type *Ty, Type *NewInnermostTy) {
  auto *AsArray = dyn_cast<ArrayType>(Ty);
  if (!AsArray)
    return NewInnermostTy;

  // Rewrite the element type first, then wrap it back into an array of the
  // original length.
  Type *RebuiltElemTy =
      replaceInnermostType(AsArray->getElementType(), NewInnermostTy);
  return ArrayType::get(RebuiltElemTy, AsArray->getNumElements());
}
38+
39+
// This function is a copy of llvm::stripPointerCastsAndOffsets,
40+
// simplified and modified to strip non-zero GEP indices as well and also
41+
// find nearest GEP instruction.
42+
Value *stripPointerCastsAndOffsets(Value *V, bool StopOnGEP = false) {
43+
if (!V->getType()->isPointerTy())
44+
return V;
45+
46+
// Even though we don't look through PHI nodes, we could be called on an
47+
// instruction in an unreachable block, which may be on a cycle.
48+
SmallPtrSet<Value *, 4> Visited;
49+
50+
Visited.insert(V);
51+
do {
52+
if (auto *GEP = dyn_cast<GEPOperator>(V)) {
53+
if (StopOnGEP && isa<GetElementPtrInst>(GEP))
54+
return V;
55+
V = GEP->getPointerOperand();
56+
} else if (Operator::getOpcode(V) == Instruction::BitCast) {
57+
Value *NewV = cast<Operator>(V)->getOperand(0);
58+
if (!NewV->getType()->isPointerTy())
59+
return V;
60+
V = NewV;
61+
} else if (Operator::getOpcode(V) == Instruction::AddrSpaceCast) {
62+
V = cast<Operator>(V)->getOperand(0);
63+
} else {
64+
if (auto *Call = dyn_cast<CallBase>(V)) {
65+
if (Value *RV = Call->getReturnedArgOperand()) {
66+
V = RV;
67+
continue;
68+
}
69+
}
70+
return V;
71+
}
72+
assert(V->getType()->isPointerTy() && "Unexpected operand type!");
73+
} while (Visited.insert(V).second);
74+
75+
return V;
76+
}
77+
78+
// Returns the target extension type held as the first element of
// WrapperMatrixTy when that element is a TargetExtType named MATRIX_TYPE.
// Returns nullptr when WrapperMatrixTy is null, when it has no elements
// (opaque or empty struct), or when its first element is anything else.
TargetExtType *extractMatrixType(StructType *WrapperMatrixTy) {
  // getElementType(0) is only valid for a struct with at least one element,
  // so reject element-less structs before indexing into them.
  if (!WrapperMatrixTy || WrapperMatrixTy->getNumElements() == 0)
    return nullptr;

  auto *MatrixTy = dyn_cast<TargetExtType>(WrapperMatrixTy->getElementType(0));
  if (!MatrixTy)
    return nullptr;

  StringRef Name = MatrixTy->getName();
  if (Name != MATRIX_TYPE)
    return nullptr;
  return MatrixTy;
}
91+
2592
// This function finds all calls to __spirv_AccessChain function and transforms
2693
// its users and operands to make LLVM IR more SPIR-V friendly.
2794
bool transformAccessChain(Function *F) {
@@ -60,34 +127,59 @@ bool transformAccessChain(Function *F) {
60127
// from sycl::joint_matrix class object if it's used in __spirv_AccessChain
61128
// function call. It's necessary because otherwise OpAccessChain indices
62129
// would be wrong.
63-
Instruction *Ptr =
64-
dyn_cast<Instruction>(CI->getArgOperand(0)->stripPointerCasts());
130+
Instruction *Ptr = dyn_cast<Instruction>(
131+
stripPointerCastsAndOffsets(CI->getArgOperand(0)));
65132
if (!Ptr || !isa<AllocaInst>(Ptr))
66133
continue;
67-
StructType *WrapperMatrixTy =
68-
dyn_cast<StructType>(cast<AllocaInst>(Ptr)->getAllocatedType());
69-
if (!WrapperMatrixTy)
70-
continue;
71-
TargetExtType *MatrixTy =
72-
dyn_cast<TargetExtType>(WrapperMatrixTy->getElementType(0));
73-
if (!MatrixTy)
134+
135+
Type *AllocaTy = cast<AllocaInst>(Ptr)->getAllocatedType();
136+
// It may happen that sycl::joint_matrix class object is wrapped into
137+
// nested arrays. We need to find the innermost type to extract the matrix type.
138+
if (StructType *WrapperMatrixTy =
139+
dyn_cast<StructType>(getInnermostType(AllocaTy))) {
140+
TargetExtType *MatrixTy = extractMatrixType(WrapperMatrixTy);
141+
if (!MatrixTy)
142+
continue;
143+
144+
AllocaInst *Alloca = nullptr;
145+
{
146+
IRBuilder Builder(CI);
147+
IRBuilderBase::InsertPointGuard IG(Builder);
148+
Builder.SetInsertPointPastAllocas(CI->getFunction());
149+
Alloca = Builder.CreateAlloca(replaceInnermostType(AllocaTy, MatrixTy));
150+
Alloca->takeName(Ptr);
151+
}
152+
Ptr->replaceAllUsesWith(Alloca);
153+
Ptr->dropAllReferences();
154+
Ptr->eraseFromParent();
155+
ModuleChanged = true;
156+
}
157+
158+
// In case spirv.CooperativeMatrixKHR is used in arrays, we also need to
159+
// insert GEP to get pointer to target extension type and use it instead of
160+
// pointer to sycl::joint_matrix class object when it is passed to
161+
// __spirv_AccessChain
162+
// First we check if the argument came from a GEP instruction
163+
GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(
164+
stripPointerCastsAndOffsets(CI->getArgOperand(0), true));
165+
if (!GEP)
74166
continue;
75-
StringRef Name = MatrixTy->getName();
76-
if (Name != MATRIX_TYPE)
167+
168+
// Check if GEP return type is a pointer to sycl::joint_matrix class object
169+
StructType *WrapperMatrixTy = dyn_cast<StructType>(GEP->getResultElementType());
170+
if (!extractMatrixType(WrapperMatrixTy))
77171
continue;
78172

79-
AllocaInst *Alloca = nullptr;
173+
// Insert GEP right before the __spirv_AccessChain call
80174
{
81175
IRBuilder Builder(CI);
82-
IRBuilderBase::InsertPointGuard IG(Builder);
83-
Builder.SetInsertPointPastAllocas(CI->getFunction());
84-
Alloca = Builder.CreateAlloca(MatrixTy);
176+
Value *NewGEP = Builder.CreateInBoundsGEP(WrapperMatrixTy,
177+
CI->getArgOperand(0), {Builder.getInt64(0), Builder.getInt32(0)});
178+
CI->setArgOperand(0, NewGEP);
179+
ModuleChanged = true;
85180
}
86-
Ptr->replaceAllUsesWith(Alloca);
87-
Ptr->dropAllReferences();
88-
Ptr->eraseFromParent();
89-
ModuleChanged = true;
90181
}
182+
91183
return ModuleChanged;
92184
}
93185
} // namespace

llvm/test/SYCLLowerIR/JointMatrixTransform/access_chain.ll

Lines changed: 55 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,24 +3,69 @@
33

44
; RUN: opt -passes=sycl-joint-matrix-transform < %s -S | FileCheck %s
55

6-
; CHECK: %[[#Alloca:]] = alloca target("spirv.CooperativeMatrixKHR", i8, 3, 16, 64, 0)
7-
; CHECK: %[[#Cast:]] = addrspacecast ptr %[[#Alloca]] to ptr addrspace(4)
8-
; CHECK: call spir_func ptr addrspace(4) @_Z19__spirv_AccessChain{{.*}}(ptr addrspace(4) noundef %[[#Cast]], i64 noundef 0)
9-
106
; ModuleID = 'test.bc'
117
source_filename = "test.cpp"
128
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64-G1"
139
target triple = "spir64-unknown-unknown"
1410

15-
%"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix" = type { target("spirv.CooperativeMatrixKHR", i8, 3, 16, 64, 0) }
11+
%"struct.sycl::joint_matrix" = type { target("spirv.CooperativeMatrixKHR", i8, 3, 16, 64, 0) }
12+
%"struct.sycl::_V1::long" = type { i64 }
13+
14+
define weak_odr dso_local spir_kernel void @test(i64 %ind) {
15+
; CHECK-LABEL: define weak_odr dso_local spir_kernel void @test(
16+
; CHECK-SAME: i64 [[IND:%.*]]) {
17+
18+
; non-matrix alloca not touched
19+
; CHECK: [[NOT_MATR:%.*]] = alloca [2 x [4 x %"struct.sycl::_V1::long"]]
20+
; both matrix-related allocas updated to use target extension types
21+
; CHECK-NEXT: [[MATR:%.*]] = alloca target("spirv.CooperativeMatrixKHR", i8, 3, 16, 64, 0)
22+
; CHECK-NEXT: [[MATR_ARR:%.*]] = alloca [2 x [4 x target("spirv.CooperativeMatrixKHR", i8, 3, 16, 64, 0)]]
23+
24+
; CHECK-NEXT: [[ASCAST:%.*]] = addrspacecast ptr [[MATR]] to ptr addrspace(4)
25+
; no gep inserted, since not needed
26+
; CHECK-NEXT: call spir_func ptr addrspace(4) @_Z19__spirv_AccessChain{{.*}}(ptr addrspace(4) noundef [[ASCAST]], i64 noundef 0)
27+
28+
; CHECK: [[GEP:%.*]] = getelementptr inbounds [2 x [4 x %"struct.sycl::joint_matrix"]], ptr [[MATR_ARR]], i64 0, i64 [[IND]], i64 [[IND]]
29+
; CHECK-NEXT: [[ASCAST_1:%.*]] = addrspacecast ptr [[GEP]] to ptr addrspace(4)
30+
; CHECK-NEXT: [[ASCAST_2:%.*]] = addrspacecast ptr [[GEP]] to ptr addrspace(4)
31+
; gep is inserted for each of the accesschain calls to extract target extension type
32+
; CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds %"struct.sycl::joint_matrix", ptr addrspace(4) [[ASCAST_1]], i64 0, i32 0
33+
; CHECK-NEXT: call spir_func ptr addrspace(4) @_Z19__spirv_AccessChain{{.*}}(ptr addrspace(4) noundef [[TMP2]], i64 noundef 0)
34+
; CHECK: [[TMP5:%.*]] = getelementptr inbounds %"struct.sycl::joint_matrix", ptr addrspace(4) [[ASCAST_2]], i64 0, i32 0
35+
; CHECK-NEXT: call spir_func ptr addrspace(4) @_Z19__spirv_AccessChain{{.*}}(ptr addrspace(4) noundef [[TMP5]], i64 noundef 0)
36+
37+
; negative test - not touching non-matrix code
38+
; CHECK: [[GEP_1:%.*]] = getelementptr inbounds [2 x [4 x %"struct.sycl::_V1::long"]], ptr [[NOT_MATR]], i64 0, i64 [[IND]], i64 [[IND]]
39+
; CHECK-NEXT: [[ASCAST_3:%.*]] = addrspacecast ptr [[GEP_1]] to ptr addrspace(4)
40+
; CHECK-NEXT: call spir_func ptr addrspace(4) @_Z19__spirv_AccessChain{{.*}}(ptr addrspace(4) noundef [[ASCAST_3]], i64 noundef 0)
1641

17-
define weak_odr dso_local spir_kernel void @test() {
1842
entry:
19-
%0 = alloca %"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix", align 8
20-
%1 = addrspacecast ptr %0 to ptr addrspace(4)
21-
%2 = call spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0_XT4_EXT1_EXT2_EXT3_EEEm(ptr addrspace(4) noundef %1, i64 noundef 0)
43+
; allocas
44+
%matr = alloca %"struct.sycl::joint_matrix", align 8
45+
%matr.arr = alloca [2 x [4 x %"struct.sycl::joint_matrix"]], align 8
46+
%not.matr = alloca [2 x [4 x %"struct.sycl::_V1::long"]], align 8
47+
48+
; simple case
49+
%ascast = addrspacecast ptr %matr to ptr addrspace(4)
50+
%0 = call spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0(ptr addrspace(4) noundef %ascast, i64 noundef 0)
51+
%1 = load i8, ptr addrspace(4) %0
52+
53+
; gep with non-zero indices and multiple access chains per 1 alloca
54+
%gep = getelementptr inbounds [2 x [4 x %"struct.sycl::joint_matrix"]], ptr %matr.arr, i64 0, i64 %ind, i64 %ind
55+
%ascast.1 = addrspacecast ptr %gep to ptr addrspace(4)
56+
%ascast.2 = addrspacecast ptr %gep to ptr addrspace(4)
57+
%2 = call spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0(ptr addrspace(4) noundef %ascast.1, i64 noundef 0)
2258
%3 = load i8, ptr addrspace(4) %2
59+
%4 = call spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0(ptr addrspace(4) noundef %ascast.2, i64 noundef 0)
60+
%5 = load i8, ptr addrspace(4) %4
61+
62+
; negative test - not touching non-matrix code
63+
%gep.1 = getelementptr inbounds [2 x [4 x %"struct.sycl::_V1::long"]], ptr %not.matr, i64 0, i64 %ind, i64 %ind
64+
%ascast.3 = addrspacecast ptr %gep.1 to ptr addrspace(4)
65+
%6 = call spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0(ptr addrspace(4) noundef %ascast.3, i64 noundef 0)
66+
%7 = load i8, ptr addrspace(4) %6
67+
2368
ret void
2469
}
2570

26-
declare dso_local spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0_XT4_EXT1_EXT2_EXT3_EEEm(ptr addrspace(4) noundef, i64 noundef)
71+
declare dso_local spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0(ptr addrspace(4) noundef, i64 noundef)

0 commit comments

Comments
 (0)