Skip to content

Commit 481090a

Browse files
authored
[bug] Fix crash when lowering multi-dimension groupshared variable (microsoft#5895)
This commit fixes a crash in the compiler when lowering a groupshared variable with a multi-dimensional array type. The root cause of the bug was that we had a nested gep expression that could not be merged into a single gep because of an intervening addrspacecast. The `MultiDimArrayToOneDimArray` pass flattens the multi-dimension global variables to a single dimension. It relies on the `MergeGepUse` function to flatten any nested geps into a single gep that fully dereferences a scalar element. The fix is to modify the `MergeGepUse` function to look through addrspacecast instructions when trying to merge geps. We can now merge geps like gep(addrspacecast(gep(p0, gep_args0) to p1*), gep_args1) into addrspacecast(gep(p0, gep_args0+gep_args1) to p1*). We also added a call to `removeDeadConstantUsers` before flattening multi-dimension globals because we can have some dead constants hanging around after merging geps and these constants should be ignored by the flattening pass.
1 parent 0a5396c commit 481090a

File tree

5 files changed

+347
-17
lines changed

5 files changed

+347
-17
lines changed

include/llvm/IR/Operator.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -501,6 +501,26 @@ class BitCastOperator
501501
}
502502
};
503503

504+
// HLSL CHANGE: Add this helper class from upstream.
// Operator-level view of an addrspacecast (opcode Instruction::AddrSpaceCast).
// Exposes the pointer being cast and the address spaces on either side of
// the cast.
505+
class AddrSpaceCastOperator
506+
: public ConcreteOperator<Operator, Instruction::AddrSpaceCast> {
507+
friend class AddrSpaceCastInst;
508+
friend class ConstantExpr;
509+
510+
public:
// Returns operand 0, i.e. the pointer value that is being cast.
511+
Value *getPointerOperand() { return getOperand(0); }
512+
// Const overload of getPointerOperand().
513+
const Value *getPointerOperand() const { return getOperand(0); }
514+
// Address space of the pointer operand's type (the space being cast from).
515+
unsigned getSrcAddressSpace() const {
516+
return getPointerOperand()->getType()->getPointerAddressSpace();
517+
}
518+
// Address space of this value's own pointer type (the space being cast to).
519+
unsigned getDestAddressSpace() const {
520+
return getType()->getPointerAddressSpace();
521+
}
522+
};
523+
504524
} // End llvm namespace
505525

506526
#endif

lib/DXIL/DxilUtilDbgInfoAndMisc.cpp

Lines changed: 71 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,18 @@ using namespace hlsl;
3737

3838
namespace {
3939

40-
Value *MergeGEP(GEPOperator *SrcGEP, GEPOperator *GEP) {
40+
// Attempt to merge the two GEPs into a single GEP.
41+
//
42+
// If `AsCast` is non-null the merged GEP will be wrapped
43+
// in an addrspacecast before replacing users. This allows
44+
// merging GEPs of the form
45+
//
46+
// gep(addrspacecast(gep(p0, gep_args0) to p1*), gep_args1)
47+
// into
48+
// addrspacecast(gep(p0, gep_args0+gep_args1) to p1*)
49+
//
50+
Value *MergeGEP(GEPOperator *SrcGEP, GEPOperator *GEP,
51+
AddrSpaceCastOperator *AsCast) {
4152
IRBuilder<> Builder(GEP->getContext());
4253
StringRef Name = "";
4354
if (Instruction *I = dyn_cast<Instruction>(GEP)) {
@@ -75,7 +86,7 @@ Value *MergeGEP(GEPOperator *SrcGEP, GEPOperator *GEP) {
7586
}
7687

7788
// Update the GEP in place if possible.
78-
if (SrcGEP->getNumOperands() == 2) {
89+
if (SrcGEP->getNumOperands() == 2 && !AsCast) {
7990
GEP->setOperand(0, SrcGEP->getOperand(0));
8091
GEP->setOperand(1, Sum);
8192
return GEP;
@@ -94,12 +105,64 @@ Value *MergeGEP(GEPOperator *SrcGEP, GEPOperator *GEP) {
94105
DXASSERT(!Indices.empty(), "must merge");
95106
Value *newGEP =
96107
Builder.CreateInBoundsGEP(nullptr, SrcGEP->getOperand(0), Indices, Name);
108+
109+
// Wrap the new gep in an addrspacecast if needed.
110+
if (AsCast)
111+
newGEP = Builder.CreateAddrSpaceCast(
112+
newGEP, PointerType::get(GEP->getType()->getPointerElementType(),
113+
AsCast->getDestAddressSpace()));
97114
GEP->replaceAllUsesWith(newGEP);
98115
if (Instruction *I = dyn_cast<Instruction>(GEP))
99116
I->eraseFromParent();
100117
return newGEP;
101118
}
102119

120+
// Examine the gep and try to merge it when the input pointer is
121+
// itself a gep. We handle two forms here:
122+
//
123+
// gep(gep(p))
124+
// gep(addrspacecast(gep(p)))
125+
//
126+
// If the gep was merged successfully then return the updated value, otherwise
127+
// return nullptr.
128+
//
129+
// When the gep is successfully merged we will delete the gep and also try to
130+
// delete the nested gep and addrspacecast.
//
// Only the instruction forms of the nested gep and addrspacecast are erased
// here; constant-expression forms are left in place.
131+
static Value *TryMegeWithNestedGEP(GEPOperator *GEP) {
132+
// Sentinel value to return when we fail to merge.
133+
Value *FailedToMerge = nullptr;
134+
135+
Value *Ptr = GEP->getPointerOperand();
136+
GEPOperator *prevGEP = dyn_cast<GEPOperator>(Ptr);
// Stays null for the direct gep(gep(p)) form; set only when the nested gep
// is found behind an addrspacecast.
137+
AddrSpaceCastOperator *AsCast = nullptr;
138+
139+
// If there is no directly nested gep try looking through an addrspacecast to
140+
// find one.
141+
if (!prevGEP) {
142+
AsCast = dyn_cast<AddrSpaceCastOperator>(Ptr);
143+
if (AsCast)
144+
prevGEP = dyn_cast<GEPOperator>(AsCast->getPointerOperand());
145+
}
146+
147+
// Not a nested gep expression.
148+
if (!prevGEP)
149+
return FailedToMerge;
150+
151+
// Try merging the two geps.
// NOTE: MergeGEP may erase GEP when it is an instruction, so GEP must not be
// touched after this call; only the returned value is used from here on.
152+
Value *newGEP = MergeGEP(prevGEP, GEP, AsCast);
153+
if (!newGEP)
154+
return FailedToMerge;
155+
156+
// Delete the nested gep and addrspacecast if no more users.
// The cast is checked before prevGEP because the cast is itself a user of
// prevGEP.
157+
if (AsCast && AsCast->user_empty() && isa<AddrSpaceCastInst>(AsCast))
158+
cast<AddrSpaceCastInst>(AsCast)->eraseFromParent();
159+
160+
if (prevGEP->user_empty() && isa<GetElementPtrInst>(prevGEP))
161+
cast<GetElementPtrInst>(prevGEP)->eraseFromParent();
162+
163+
return newGEP;
164+
}
165+
103166
} // namespace
104167

105168
namespace hlsl {
@@ -130,23 +193,14 @@ bool MergeGepUse(Value *V) {
130193
// merge any GEP users of the untranslated bitcast
131194
addUsersToWorklist(V);
132195
}
196+
} else if (isa<AddrSpaceCastOperator>(V)) {
197+
addUsersToWorklist(V);
133198
} else if (GEPOperator *GEP = dyn_cast<GEPOperator>(V)) {
134-
if (GEPOperator *prevGEP =
135-
dyn_cast<GEPOperator>(GEP->getPointerOperand())) {
136-
// merge the 2 GEPs, returns nullptr if couldn't merge
137-
if (Value *newGEP = MergeGEP(prevGEP, GEP)) {
138-
changed = true;
139-
worklist.push_back(newGEP);
140-
// delete prevGEP if no more users
141-
if (prevGEP->user_empty() && isa<GetElementPtrInst>(prevGEP)) {
142-
cast<GetElementPtrInst>(prevGEP)->eraseFromParent();
143-
}
144-
} else {
145-
addUsersToWorklist(GEP);
146-
}
199+
if (Value *newGEP = TryMegeWithNestedGEP(GEP)) {
200+
changed = true;
201+
worklist.push_back(newGEP);
147202
} else {
148-
// nothing to merge yet, add GEP users
149-
addUsersToWorklist(V);
203+
addUsersToWorklist(GEP);
150204
}
151205
}
152206
}

lib/Transforms/Scalar/LowerTypePasses.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,7 @@ bool LowerTypePass::runOnModule(Module &M) {
161161
HLModule::UpdateGlobalVariableDebugInfo(GV, Finder, NewGV);
162162
}
163163
// Replace users.
164+
GV->removeDeadConstantUsers();
164165
lowerUseWithNewValue(GV, NewGV);
165166
// Remove GV.
166167
GV->removeDeadConstantUsers();
Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
; RUN: opt -S -multi-dim-one-dim %s | FileCheck %s
2+
;
3+
; Tests for the pass that changes multi-dimension global variable accesses into
4+
; a flattened one-dimensional access. The tests focus on the case where the geps
5+
; need to be merged but are separated by an addrspacecast operation. This was
6+
; causing the pass to fail because it could not merge the gep through the
7+
; addrspace cast.
8+
9+
; Naming convention: gep0_addrspacecast_gep1
10+
11+
target datalayout = "e-m:e-p:32:32-i1:32-i8:32-i16:32-i32:32-i64:64-f16:32-f32:32-f64:64-n8:16:32:64"
12+
target triple = "dxil-ms-dx"
13+
14+
@ArrayOfArray = addrspace(3) global [256 x [9 x float]] undef, align 4
15+
@ArrayOfArrayOfArray = addrspace(3) global [256 x [9 x [3 x float]]] undef, align 4
16+
17+
; Test that we can merge the geps when all parts are instructions.
18+
; CHECK-LABEL: @merge_gep_instr_instr_instr
19+
; CHECK: load float, float* addrspacecast (float addrspace(3)* getelementptr inbounds ([2304 x float], [2304 x float] addrspace(3)* @ArrayOfArray.1dim, i32 0, i32 1) to float*)
20+
define void @merge_gep_instr_instr_instr() {
21+
entry:
22+
%gep0 = getelementptr inbounds [256 x [9 x float]], [256 x [9 x float]] addrspace(3)* @ArrayOfArray, i32 0, i32 0
23+
%asc = addrspacecast [9 x float] addrspace(3)* %gep0 to [9 x float]*
24+
%gep1 = getelementptr inbounds [9 x float], [9 x float]* %asc, i32 0, i32 1
25+
%load = load float, float* %gep1
26+
ret void
27+
}
28+
29+
; Test that we can merge the geps when the inner gep is a constant.
30+
; CHECK-LABEL: @merge_gep_instr_instr_const
31+
; CHECK: load float, float* addrspacecast (float addrspace(3)* getelementptr inbounds ([2304 x float], [2304 x float] addrspace(3)* @ArrayOfArray.1dim, i32 0, i32 1) to float*)
32+
define void @merge_gep_instr_instr_const() {
33+
entry:
34+
%asc = addrspacecast [9 x float] addrspace(3)* getelementptr inbounds ([256 x [9 x float]], [256 x [9 x float]] addrspace(3)* @ArrayOfArray, i32 0, i32 0) to [9 x float]*
35+
%gep1 = getelementptr inbounds [9 x float], [9 x float]* %asc, i32 0, i32 1
36+
%load = load float, float* %gep1
37+
ret void
38+
}
39+
40+
; Test that we can merge the geps when the addrspacecast and inner gep are constants.
41+
; CHECK-LABEL: @merge_gep_instr_const_const
42+
; CHECK: load float, float* addrspacecast (float addrspace(3)* getelementptr inbounds ([2304 x float], [2304 x float] addrspace(3)* @ArrayOfArray.1dim, i32 0, i32 1) to float*)
43+
define void @merge_gep_instr_const_const() {
44+
entry:
45+
%gep1 = getelementptr inbounds [9 x float], [9 x float]* addrspacecast ([9 x float] addrspace(3)* getelementptr inbounds ([256 x [9 x float]], [256 x [9 x float]] addrspace(3)* @ArrayOfArray, i32 0, i32 0) to [9 x float]*), i32 0, i32 1
46+
%load = load float, float* %gep1
47+
ret void
48+
}
49+
50+
; Test that we can merge the geps when all parts are constants.
51+
; CHECK-LABEL: @merge_gep_const_const
52+
; CHECK: load float, float* addrspacecast (float addrspace(3)* getelementptr inbounds ([2304 x float], [2304 x float] addrspace(3)* @ArrayOfArray.1dim, i32 0, i32 1) to float*)
53+
define void @merge_gep_const_const_const() {
54+
entry:
55+
%load = load float, float* getelementptr inbounds ([9 x float], [9 x float]* addrspacecast ([9 x float] addrspace(3)* getelementptr inbounds ([256 x [9 x float]], [256 x [9 x float]] addrspace(3)* @ArrayOfArray, i32 0, i32 0) to [9 x float]*), i32 0, i32 1)
56+
ret void
57+
}
58+
59+
; Test that we compute the correct index when the outer array has
60+
; a non-zero constant index.
61+
; CHECK-LABEL: @merge_gep_const_outer_array_index
62+
; CHECK: load float, float* addrspacecast (float addrspace(3)* getelementptr inbounds ([2304 x float], [2304 x float] addrspace(3)* @ArrayOfArray.1dim, i32 0, i32 66) to float*)
63+
define void @merge_gep_const_outer_array_index() {
64+
entry:
65+
%gep0 = getelementptr inbounds [256 x [9 x float]], [256 x [9 x float]] addrspace(3)* @ArrayOfArray, i32 0, i32 7
66+
%asc = addrspacecast [9 x float] addrspace(3)* %gep0 to [9 x float]*
67+
%gep1 = getelementptr inbounds [9 x float], [9 x float]* %asc, i32 0, i32 3
68+
%load = load float, float* %gep1
69+
ret void
70+
}
71+
72+
; Test that we compute the correct index when the outer array has
73+
; a non-constant index.
74+
; CHECK-LABEL: @merge_gep_dynamic_outer_array_index
75+
; CHECK: %0 = mul i32 %idx, 9
76+
; CHECK: %1 = add i32 3, %0
77+
; CHECK: %2 = getelementptr [2304 x float], [2304 x float] addrspace(3)* @ArrayOfArray.1dim, i32 0, i32 %1
78+
; CHECK: %3 = addrspacecast float addrspace(3)* %2 to float*
79+
; CHECK: load float, float* %3
80+
define void @merge_gep_dynamic_outer_array_index(i32 %idx) {
81+
entry:
82+
%gep0 = getelementptr inbounds [256 x [9 x float]], [256 x [9 x float]] addrspace(3)* @ArrayOfArray, i32 0, i32 %idx
83+
%asc = addrspacecast [9 x float] addrspace(3)* %gep0 to [9 x float]*
84+
%gep1 = getelementptr inbounds [9 x float], [9 x float]* %asc, i32 0, i32 3
85+
%load = load float, float* %gep1
86+
ret void
87+
}
88+
89+
; Test that we compute the correct index when both arrays have
90+
; a non-constant index.
91+
; CHECK-LABEL: @merge_gep_dynamic_array_index
92+
; CHECK: %0 = mul i32 %idx0, 9
93+
; CHECK: %1 = add i32 %idx1, %0
94+
; CHECK: %2 = getelementptr [2304 x float], [2304 x float] addrspace(3)* @ArrayOfArray.1dim, i32 0, i32 %1
95+
; CHECK: %3 = addrspacecast float addrspace(3)* %2 to float*
96+
; CHECK: load float, float* %3
97+
define void @merge_gep_dynamic_array_index(i32 %idx0, i32 %idx1) {
98+
entry:
99+
%gep0 = getelementptr inbounds [256 x [9 x float]], [256 x [9 x float]] addrspace(3)* @ArrayOfArray, i32 0, i32 %idx0
100+
%asc = addrspacecast [9 x float] addrspace(3)* %gep0 to [9 x float]*
101+
%gep1 = getelementptr inbounds [9 x float], [9 x float]* %asc, i32 0, i32 %idx1
102+
%load = load float, float* %gep1
103+
ret void
104+
}
105+
106+
; Test that we compute the correct index when there are multiple
107+
; geps after the addrspacecast. This also exercises the case
108+
; where one of the outer geps ends in an array which hits
109+
; an early return in MergeGEP.
110+
; CHECK-LABEL: @merge_gep_multi_level_end_in_sequential_with_addrspace
111+
; CHECK: %0 = mul i32 %idx0, 9
112+
; CHECK: %1 = add i32 %idx1, %0
113+
; CHECK: %2 = getelementptr [2304 x float], [2304 x float] addrspace(3)* @ArrayOfArray.1dim, i32 0, i32 %1
114+
; CHECK: %3 = addrspacecast float addrspace(3)* %2 to float*
115+
; CHECK: load float, float* %3
116+
define void @merge_gep_multi_level_end_in_sequential_with_addrspace(i32 %idx0, i32 %idx1) {
117+
entry:
118+
%gep0 = getelementptr inbounds [256 x [9 x float]], [256 x [9 x float]] addrspace(3)* @ArrayOfArray, i32 0
119+
%asc = addrspacecast [256 x [9 x float]] addrspace(3)* %gep0 to [256 x [9 x float]]*
120+
%gep1 = getelementptr inbounds [256 x [9 x float]], [256 x [9 x float]]* %asc, i32 0, i32 %idx0
121+
%gep2 = getelementptr inbounds [9 x float], [9 x float]* %gep1, i32 0, i32 %idx1
122+
%load = load float, float* %gep2
123+
ret void
124+
}
125+
126+
; Test that we compute the correct index when there are three levels of geps.
127+
; This also exercises the case where one of the outer geps ends in an
128+
; array which hits an early return in MergeGEP.
129+
; CHECK-LABEL: @merge_gep_multi_level_end_in_sequential
130+
; CHECK: %0 = mul i32 %idx0, 9
131+
; CHECK: %1 = add i32 %idx1, %0
132+
; CHECK: %2 = getelementptr [2304 x float], [2304 x float] addrspace(3)* @ArrayOfArray.1dim, i32 0, i32 %1
133+
; CHECK: load float, float addrspace(3)* %2
134+
define void @merge_gep_multi_level_end_in_sequential(i32 %idx0, i32 %idx1) {
135+
entry:
136+
%gep0 = getelementptr inbounds [256 x [9 x float]], [256 x [9 x float]] addrspace(3)* @ArrayOfArray, i32 0
137+
%gep1 = getelementptr inbounds [256 x [9 x float]], [256 x [9 x float]] addrspace(3)* %gep0, i32 0, i32 %idx0
138+
%gep2 = getelementptr inbounds [9 x float], [9 x float] addrspace(3)* %gep1, i32 0, i32 %idx1
139+
%load = load float, float addrspace(3)* %gep2
140+
ret void
141+
}
142+
143+
; Test that we compute the correct index when the global has 3 levels of
144+
; nested arrays and an addrspacecast.
145+
; CHECK-LABEL: @merge_gep_multi_level_with_addrspace
146+
; CHECK: %0 = mul i32 %idx0, 9
147+
; CHECK: %1 = add i32 %idx1, %0
148+
; CHECK: %2 = mul i32 %1, 3
149+
; CHECK: %3 = add i32 %idx2, %2
150+
; CHECK: %4 = getelementptr [6912 x float], [6912 x float] addrspace(3)* @ArrayOfArrayOfArray.1dim, i32 0, i32 %3
151+
; CHECK: %5 = addrspacecast float addrspace(3)* %4 to float*
152+
; CHECK: load float, float* %5
153+
define void @merge_gep_multi_level_with_addrspace(i32 %idx0, i32 %idx1, i32 %idx2) {
154+
entry:
155+
%gep0 = getelementptr inbounds [256 x [9 x [3 x float]]], [256 x [9 x [3 x float]]] addrspace(3)* @ArrayOfArrayOfArray, i32 0, i32 %idx0
156+
%asc = addrspacecast [9 x [3 x float]] addrspace(3)* %gep0 to [9 x [3 x float]]*
157+
%gep1 = getelementptr inbounds [9 x [3 x float]], [9 x [3 x float]]* %asc, i32 0, i32 %idx1
158+
%gep2 = getelementptr inbounds [3 x float], [3 x float]* %gep1, i32 0, i32 %idx2
159+
%load = load float, float* %gep2
160+
ret void
161+
}
162+
163+
; Test that we compute the correct index when the global has 3 levels of
164+
; nested arrays.
165+
; CHECK-LABEL: @merge_gep_multi_level
166+
; CHECK: %0 = mul i32 %idx0, 9
167+
; CHECK: %1 = add i32 %idx1, %0
168+
; CHECK: %2 = mul i32 %1, 3
169+
; CHECK: %3 = add i32 %idx2, %2
170+
; CHECK: %4 = getelementptr [6912 x float], [6912 x float] addrspace(3)* @ArrayOfArrayOfArray.1dim, i32 0, i32 %3
171+
; CHECK: load float, float addrspace(3)* %4
172+
define void @merge_gep_multi_level(i32 %idx0, i32 %idx1, i32 %idx2) {
173+
entry:
174+
%gep0 = getelementptr inbounds [256 x [9 x [3 x float]]], [256 x [9 x [3 x float]]] addrspace(3)* @ArrayOfArrayOfArray, i32 0, i32 %idx0
175+
%gep1 = getelementptr inbounds [9 x [3 x float]], [9 x [3 x float]] addrspace(3)* %gep0, i32 0, i32 %idx1
176+
%gep2 = getelementptr inbounds [3 x float], [3 x float] addrspace(3)* %gep1, i32 0, i32 %idx2
177+
%load = load float, float addrspace(3)* %gep2
178+
ret void
179+
}
180+
181+
; Test that we compute the correct index when the addrspacecast includes both a
182+
; change in address space and a change in the underlying type. I did not see
183+
; this pattern in IR generated from hlsl, but we can handle this case so I am
184+
; adding a test for it anyway.
185+
; CHECK-LABEL: addrspace_cast_new_type
186+
; CHECK: %0 = mul i32 %idx0, 9
187+
; CHECK: %1 = add i32 %idx1, %0
188+
; CHECK: %2 = getelementptr [2304 x float], [2304 x float] addrspace(3)* @ArrayOfArray.1dim, i32 0, i32 %1
189+
; CHECK: %3 = addrspacecast float addrspace(3)* %2 to i32*
190+
; CHECK: load i32, i32* %3
191+
define void @addrspace_cast_new_type(i32 %idx0, i32 %idx1) {
192+
entry:
193+
%gep0 = getelementptr inbounds [256 x [9 x float]], [256 x [9 x float]] addrspace(3)* @ArrayOfArray, i32 0, i32 %idx0
194+
%asc = addrspacecast [9 x float] addrspace(3)* %gep0 to [3 x i32]*
195+
%gep1 = getelementptr inbounds [3 x i32], [3 x i32]* %asc, i32 0, i32 %idx1
196+
%load = load i32, i32* %gep1
197+
ret void
198+
}

0 commit comments

Comments
 (0)