Skip to content

Commit 823125b

Browse files
authored
[DXIL generation] Merge GepUse last to avoid crash in EmitGetNodeRecordPtrAndUpdateUsers (microsoft#6314)
In EmitGetNodeRecordPtrAndUpdateUsers, the type will be mutated. And the GEP user of the RecordPtr will be merged at same time. This make things complex because the GEP index need to be updated since type is mutated. To make things easier, merge the GepUse after mutate type. Fixes microsoft#6223
1 parent df588be commit 823125b

File tree

3 files changed

+87
-6
lines changed

3 files changed

+87
-6
lines changed

lib/HLSL/HLLowerUDT.cpp

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ hlsl::TranslateInitForLoweredUDT(Constant *Init, Type *NewTy,
178178
return Init;
179179
}
180180

181-
void hlsl::ReplaceUsesForLoweredUDT(Value *V, Value *NewV) {
181+
static void ReplaceUsesForLoweredUDTImpl(Value *V, Value *NewV) {
182182
Type *Ty = V->getType();
183183
Type *NewTy = NewV->getType();
184184

@@ -255,23 +255,22 @@ void hlsl::ReplaceUsesForLoweredUDT(Value *V, Value *NewV) {
255255
IRBuilder<> Builder(GEP);
256256
SmallVector<Value *, 4> idxList(GEP->idx_begin(), GEP->idx_end());
257257
Value *NewGEP = Builder.CreateGEP(NewV, idxList);
258-
ReplaceUsesForLoweredUDT(GEP, NewGEP);
259-
dxilutil::MergeGepUse(NewGEP);
258+
ReplaceUsesForLoweredUDTImpl(GEP, NewGEP);
260259
GEP->eraseFromParent();
261260

262261
} else if (GEPOperator *GEP = dyn_cast<GEPOperator>(user)) {
263262
// Has to be constant GEP, NewV better be constant
264263
SmallVector<Value *, 4> idxList(GEP->idx_begin(), GEP->idx_end());
265264
Constant *NewGEP = ConstantExpr::getGetElementPtr(
266265
nullptr, cast<Constant>(NewV), idxList, true);
267-
ReplaceUsesForLoweredUDT(GEP, NewGEP);
266+
ReplaceUsesForLoweredUDTImpl(GEP, NewGEP);
268267

269268
} else if (AddrSpaceCastInst *AC = dyn_cast<AddrSpaceCastInst>(user)) {
270269
// Address space cast
271270
IRBuilder<> Builder(AC);
272271
Value *NewAC = Builder.CreateAddrSpaceCast(
273272
NewV, PointerType::get(Ty, AC->getType()->getPointerAddressSpace()));
274-
ReplaceUsesForLoweredUDT(user, NewAC);
273+
ReplaceUsesForLoweredUDTImpl(user, NewAC);
275274
AC->eraseFromParent();
276275
} else if (BitCastInst *BC = dyn_cast<BitCastInst>(user)) {
277276
IRBuilder<> Builder(BC);
@@ -295,7 +294,7 @@ void hlsl::ReplaceUsesForLoweredUDT(Value *V, Value *NewV) {
295294
Constant *NewAC = ConstantExpr::getAddrSpaceCast(
296295
cast<Constant>(NewV),
297296
PointerType::get(Ty, CE->getType()->getPointerAddressSpace()));
298-
ReplaceUsesForLoweredUDT(user, NewAC);
297+
ReplaceUsesForLoweredUDTImpl(user, NewAC);
299298
} else if (CE->getOpcode() == Instruction::BitCast) {
300299
if (CE->getType()->getPointerElementType() == NewTy) {
301300
// if alreday bitcast to new type, just replace the bitcast
@@ -475,3 +474,9 @@ void hlsl::ReplaceUsesForLoweredUDT(Value *V, Value *NewV) {
475474
CV->removeDeadConstantUsers();
476475
}
477476
}
477+
478+
void hlsl::ReplaceUsesForLoweredUDT(Value *V, Value *NewV) {
479+
ReplaceUsesForLoweredUDTImpl(V, NewV);
480+
// Merge GepUse later to avoid mutate type and merge gep use at same time.
481+
dxilutil::MergeGepUse(NewV);
482+
}
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// RUN: %dxc -Tlib_6_8 %s | FileCheck %s
2+
3+
// Make sure generate correct metadata for Entry.
4+
5+
// CHECK: !{void ()* @Entry, !"Entry", null, null, ![[ENTRY:[0-9]+]]}
6+
// CHECK: ![[ENTRY]] = !{i32 8, i32 15, i32 13, i32 1, i32 14, i1 true, i32 15, ![[NodeId:[0-9]+]], i32 16, i32 -1, i32 18, ![[NodeDispatchGrid:[0-9]+]], i32 20, ![[NodeInputs:[0-9]+]], i32 4, ![[NumThreads:[0-9]+]], i32 5, ![[AutoBindingSpace:[0-9]+]]}
7+
// CHECK: ![[NodeId]] = !{!"Entry", i32 0}
8+
// CHECK: ![[NodeDispatchGrid]] = !{i32 1, i32 1, i32 1}
9+
// CHECK: ![[NodeInputs]] = !{![[Input0:[0-9]+]]}
10+
// CHECK: ![[Input0]] = !{i32 1, i32 97, i32 2, ![[NodeRecordType:[0-9]+]]}
11+
// CHECK: ![[NodeRecordType]] = !{i32 0, i32 68}
12+
// CHECK: ![[NumThreads]] = !{i32 32, i32 1, i32 1}
13+
// CHECK: ![[AutoBindingSpace]] = !{i32 0}
14+
15+
static const int maxPoints = 8;
16+
17+
struct EntryRecord {
18+
float2 points[maxPoints];
19+
int pointCoint;
20+
};
21+
22+
[shader("node")]
23+
[NodeIsProgramEntry]
24+
[NodeLaunch("broadcasting")]
25+
[NodeDispatchGrid(1, 1, 1)]
26+
[NumThreads(32, 1, 1)]
27+
void Entry(
28+
uint gtid : SV_GroupThreadId,
29+
DispatchNodeInputRecord<EntryRecord> inputData
30+
)
31+
{
32+
EntryRecord input = inputData.Get();
33+
34+
[[unroll]]
35+
for (int i = 0; i < 8; ++i) {
36+
float2 p = input.points[i];
37+
}
38+
}
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// RUN: %dxc -Tlib_6_8 %s | FileCheck %s
2+
3+
// Make sure generate correct metadata for Entry.
4+
5+
// CHECK: !{void ()* @Entry, !"Entry", null, null, ![[ENTRY:[0-9]+]]}
6+
// CHECK: ![[ENTRY]] = !{i32 8, i32 15, i32 13, i32 1, i32 14, i1 true, i32 15, ![[NodeId:[0-9]+]], i32 16, i32 -1, i32 18, ![[NodeDispatchGrid:[0-9]+]], i32 20, ![[NodeInputs:[0-9]+]], i32 4, ![[NumThreads:[0-9]+]], i32 5, ![[AutoBindingSpace:[0-9]+]]}
7+
// CHECK: ![[NodeId]] = !{!"Entry", i32 0}
8+
// CHECK: ![[NodeDispatchGrid]] = !{i32 1, i32 1, i32 1}
9+
// CHECK: ![[NodeInputs]] = !{![[Input0:[0-9]+]]}
10+
// CHECK: ![[Input0]] = !{i32 1, i32 97, i32 2, ![[NodeRecordType:[0-9]+]]}
11+
// CHECK: ![[NodeRecordType]] = !{i32 0, i32 68}
12+
// CHECK: ![[NumThreads]] = !{i32 32, i32 1, i32 1}
13+
// CHECK: ![[AutoBindingSpace]] = !{i32 0}
14+
15+
static const int maxPoints = 8;
16+
17+
struct EntryRecord {
18+
float2 points[maxPoints];
19+
int pointCoint;
20+
};
21+
22+
[shader("node")]
23+
[NodeIsProgramEntry]
24+
[NodeLaunch("broadcasting")]
25+
[NodeDispatchGrid(1, 1, 1)]
26+
[NumThreads(32, 1, 1)]
27+
void Entry(
28+
uint gtid : SV_GroupThreadId,
29+
DispatchNodeInputRecord<EntryRecord> inputData
30+
)
31+
{
32+
EntryRecord input = inputData.Get();
33+
34+
if (gtid < input.pointCoint) {
35+
// reading input.points[0] works
36+
float2 p = input.points[gtid];
37+
}
38+
}

0 commit comments

Comments
 (0)