Skip to content

Commit a293aa6

Browse files
committed
Make DXILLegalizePass not dependent on ToRemove for change status
1 parent f73c390 commit a293aa6

File tree

1 file changed

+65
-59
lines changed

1 file changed

+65
-59
lines changed

llvm/lib/Target/DirectX/DXILLegalizePass.cpp

Lines changed: 65 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -24,20 +24,21 @@
2424

2525
using namespace llvm;
2626

27-
static void legalizeFreeze(Instruction &I,
27+
static bool legalizeFreeze(Instruction &I,
2828
SmallVectorImpl<Instruction *> &ToRemove,
29-
DenseMap<Value *, Value *>, bool &) {
29+
DenseMap<Value *, Value *>) {
3030
auto *FI = dyn_cast<FreezeInst>(&I);
3131
if (!FI)
32-
return;
32+
return false;
3333

3434
FI->replaceAllUsesWith(FI->getOperand(0));
3535
ToRemove.push_back(FI);
36+
return true;
3637
}
3738

38-
static void fixI8UseChain(Instruction &I,
39+
static bool fixI8UseChain(Instruction &I,
3940
SmallVectorImpl<Instruction *> &ToRemove,
40-
DenseMap<Value *, Value *> &ReplacedValues, bool &) {
41+
DenseMap<Value *, Value *> &ReplacedValues) {
4142

4243
auto ProcessOperands = [&](SmallVector<Value *> &NewOperands) {
4344
Type *InstrType = IntegerType::get(I.getContext(), 32);
@@ -74,19 +75,19 @@ static void fixI8UseChain(Instruction &I,
7475
if (Trunc->getDestTy()->isIntegerTy(8)) {
7576
ReplacedValues[Trunc] = Trunc->getOperand(0);
7677
ToRemove.push_back(Trunc);
77-
return;
78+
return true;
7879
}
7980
}
8081

8182
if (auto *Store = dyn_cast<StoreInst>(&I)) {
8283
if (!Store->getValueOperand()->getType()->isIntegerTy(8))
83-
return;
84+
return false;
8485
SmallVector<Value *> NewOperands;
8586
ProcessOperands(NewOperands);
8687
Value *NewStore = Builder.CreateStore(NewOperands[0], NewOperands[1]);
8788
ReplacedValues[Store] = NewStore;
8889
ToRemove.push_back(Store);
89-
return;
90+
return true;
9091
}
9192

9293
if (auto *Load = dyn_cast<LoadInst>(&I);
@@ -104,17 +105,17 @@ static void fixI8UseChain(Instruction &I,
104105
LoadInst *NewLoad = Builder.CreateLoad(ElementType, NewOperands[0]);
105106
ReplacedValues[Load] = NewLoad;
106107
ToRemove.push_back(Load);
107-
return;
108+
return true;
108109
}
109110

110111
if (auto *Load = dyn_cast<LoadInst>(&I);
111112
Load && isa<ConstantExpr>(Load->getPointerOperand())) {
112113
auto *CE = dyn_cast<ConstantExpr>(Load->getPointerOperand());
113114
if (!(CE->getOpcode() == Instruction::GetElementPtr))
114-
return;
115+
return false;
115116
auto *GEP = dyn_cast<GEPOperator>(CE);
116117
if (!GEP->getSourceElementType()->isIntegerTy(8))
117-
return;
118+
return false;
118119

119120
Type *ElementType = Load->getType();
120121
ConstantInt *Offset = dyn_cast<ConstantInt>(GEP->getOperand(1));
@@ -143,12 +144,12 @@ static void fixI8UseChain(Instruction &I,
143144
ReplacedValues[Load] = NewLoad;
144145
Load->replaceAllUsesWith(NewLoad);
145146
ToRemove.push_back(Load);
146-
return;
147+
return true;
147148
}
148149

149150
if (auto *BO = dyn_cast<BinaryOperator>(&I)) {
150151
if (!I.getType()->isIntegerTy(8))
151-
return;
152+
return false;
152153
SmallVector<Value *> NewOperands;
153154
ProcessOperands(NewOperands);
154155
Value *NewInst =
@@ -162,43 +163,43 @@ static void fixI8UseChain(Instruction &I,
162163
}
163164
ReplacedValues[BO] = NewInst;
164165
ToRemove.push_back(BO);
165-
return;
166+
return true;
166167
}
167168

168169
if (auto *Sel = dyn_cast<SelectInst>(&I)) {
169170
if (!I.getType()->isIntegerTy(8))
170-
return;
171+
return false;
171172
SmallVector<Value *> NewOperands;
172173
ProcessOperands(NewOperands);
173174
Value *NewInst = Builder.CreateSelect(Sel->getCondition(), NewOperands[1],
174175
NewOperands[2]);
175176
ReplacedValues[Sel] = NewInst;
176177
ToRemove.push_back(Sel);
177-
return;
178+
return true;
178179
}
179180

180181
if (auto *Cmp = dyn_cast<CmpInst>(&I)) {
181182
if (!Cmp->getOperand(0)->getType()->isIntegerTy(8))
182-
return;
183+
return false;
183184
SmallVector<Value *> NewOperands;
184185
ProcessOperands(NewOperands);
185186
Value *NewInst =
186187
Builder.CreateCmp(Cmp->getPredicate(), NewOperands[0], NewOperands[1]);
187188
Cmp->replaceAllUsesWith(NewInst);
188189
ReplacedValues[Cmp] = NewInst;
189190
ToRemove.push_back(Cmp);
190-
return;
191+
return true;
191192
}
192193

193194
if (auto *Cast = dyn_cast<CastInst>(&I)) {
194195
if (!Cast->getSrcTy()->isIntegerTy(8))
195-
return;
196+
return false;
196197

197198
ToRemove.push_back(Cast);
198199
auto *Replacement = ReplacedValues[Cast->getOperand(0)];
199200
if (Cast->getType() == Replacement->getType()) {
200201
Cast->replaceAllUsesWith(Replacement);
201-
return;
202+
return true;
202203
}
203204

204205
Value *AdjustedCast = nullptr;
@@ -213,7 +214,7 @@ static void fixI8UseChain(Instruction &I,
213214
if (auto *GEP = dyn_cast<GetElementPtrInst>(&I)) {
214215
if (!GEP->getType()->isPointerTy() ||
215216
!GEP->getSourceElementType()->isIntegerTy(8))
216-
return;
217+
return false;
217218

218219
Value *BasePtr = GEP->getPointerOperand();
219220
if (ReplacedValues.count(BasePtr))
@@ -248,16 +249,17 @@ static void fixI8UseChain(Instruction &I,
248249
ReplacedValues[GEP] = NewGEP;
249250
GEP->replaceAllUsesWith(NewGEP);
250251
ToRemove.push_back(GEP);
252+
return true;
251253
}
254+
return false;
252255
}
253256

254-
static void upcastI8AllocasAndUses(Instruction &I,
257+
static bool upcastI8AllocasAndUses(Instruction &I,
255258
SmallVectorImpl<Instruction *> &ToRemove,
256-
DenseMap<Value *, Value *> &ReplacedValues,
257-
bool &) {
259+
DenseMap<Value *, Value *> &ReplacedValues) {
258260
auto *AI = dyn_cast<AllocaInst>(&I);
259261
if (!AI || !AI->getAllocatedType()->isIntegerTy(8))
260-
return;
262+
return false;
261263

262264
Type *SmallestType = nullptr;
263265

@@ -292,19 +294,20 @@ static void upcastI8AllocasAndUses(Instruction &I,
292294
}
293295

294296
if (!SmallestType)
295-
return; // no valid casts found
297+
return false; // no valid casts found
296298

297299
// Replace alloca
298300
IRBuilder<> Builder(AI);
299301
auto *NewAlloca = Builder.CreateAlloca(SmallestType);
300302
ReplacedValues[AI] = NewAlloca;
301303
ToRemove.push_back(AI);
304+
return true;
302305
}
303306

304-
static void
307+
static bool
305308
downcastI64toI32InsertExtractElements(Instruction &I,
306309
SmallVectorImpl<Instruction *> &ToRemove,
307-
DenseMap<Value *, Value *> &, bool &) {
310+
DenseMap<Value *, Value *> &) {
308311

309312
if (auto *Extract = dyn_cast<ExtractElementInst>(&I)) {
310313
Value *Idx = Extract->getIndexOperand();
@@ -319,6 +322,7 @@ downcastI64toI32InsertExtractElements(Instruction &I,
319322

320323
Extract->replaceAllUsesWith(NewExtract);
321324
ToRemove.push_back(Extract);
325+
return true;
322326
}
323327
}
324328

@@ -336,8 +340,10 @@ downcastI64toI32InsertExtractElements(Instruction &I,
336340

337341
Insert->replaceAllUsesWith(Insert32Index);
338342
ToRemove.push_back(Insert);
343+
return true;
339344
}
340345
}
346+
return false;
341347
}
342348

343349
static void emitMemcpyExpansion(IRBuilder<> &Builder, Value *Dst, Value *Src,
@@ -454,17 +460,17 @@ static void emitMemsetExpansion(IRBuilder<> &Builder, Value *Dst, Value *Val,
454460
// Expands the instruction `I` into corresponding loads and stores if it is a
455461
// memcpy call. In that case, the call instruction is added to the `ToRemove`
456462
// vector. `ReplacedValues` is unused.
457-
static void legalizeMemCpy(Instruction &I,
463+
static bool legalizeMemCpy(Instruction &I,
458464
SmallVectorImpl<Instruction *> &ToRemove,
459-
DenseMap<Value *, Value *> &ReplacedValues, bool &) {
465+
DenseMap<Value *, Value *> &ReplacedValues) {
460466

461467
CallInst *CI = dyn_cast<CallInst>(&I);
462468
if (!CI)
463-
return;
469+
return false;
464470

465471
Intrinsic::ID ID = CI->getIntrinsicID();
466472
if (ID != Intrinsic::memcpy)
467-
return;
473+
return false;
468474

469475
IRBuilder<> Builder(&I);
470476
Value *Dst = CI->getArgOperand(0);
@@ -477,19 +483,20 @@ static void legalizeMemCpy(Instruction &I,
477483
assert(IsVolatile->getZExtValue() == 0 && "Expected IsVolatile to be false");
478484
emitMemcpyExpansion(Builder, Dst, Src, Length);
479485
ToRemove.push_back(CI);
486+
return true;
480487
}
481488

482-
static void legalizeMemSet(Instruction &I,
489+
static bool legalizeMemSet(Instruction &I,
483490
SmallVectorImpl<Instruction *> &ToRemove,
484-
DenseMap<Value *, Value *> &ReplacedValues, bool &) {
491+
DenseMap<Value *, Value *> &ReplacedValues) {
485492

486493
CallInst *CI = dyn_cast<CallInst>(&I);
487494
if (!CI)
488-
return;
495+
return false;
489496

490497
Intrinsic::ID ID = CI->getIntrinsicID();
491498
if (ID != Intrinsic::memset)
492-
return;
499+
return false;
493500

494501
IRBuilder<> Builder(&I);
495502
Value *Dst = CI->getArgOperand(0);
@@ -498,39 +505,41 @@ static void legalizeMemSet(Instruction &I,
498505
assert(Size && "Expected Size to be a ConstantInt");
499506
emitMemsetExpansion(Builder, Dst, Val, Size, ReplacedValues);
500507
ToRemove.push_back(CI);
508+
return true;
501509
}
502510

503-
static void updateFnegToFsub(Instruction &I,
511+
static bool updateFnegToFsub(Instruction &I,
504512
SmallVectorImpl<Instruction *> &ToRemove,
505-
DenseMap<Value *, Value *> &, bool &) {
513+
DenseMap<Value *, Value *> &) {
506514
const Intrinsic::ID ID = I.getOpcode();
507515
if (ID != Instruction::FNeg)
508-
return;
516+
return false;
509517

510518
IRBuilder<> Builder(&I);
511519
Value *In = I.getOperand(0);
512520
Value *Zero = ConstantFP::get(In->getType(), -0.0);
513521
I.replaceAllUsesWith(Builder.CreateFSub(Zero, In));
514522
ToRemove.push_back(&I);
523+
return true;
515524
}
516525

517-
static void
526+
static bool
518527
legalizeGetHighLowi64Bytes(Instruction &I,
519528
SmallVectorImpl<Instruction *> &ToRemove,
520-
DenseMap<Value *, Value *> &ReplacedValues, bool &) {
529+
DenseMap<Value *, Value *> &ReplacedValues) {
521530
if (auto *BitCast = dyn_cast<BitCastInst>(&I)) {
522531
if (BitCast->getDestTy() ==
523532
FixedVectorType::get(Type::getInt32Ty(I.getContext()), 2) &&
524533
BitCast->getSrcTy()->isIntegerTy(64)) {
525534
ToRemove.push_back(BitCast);
526535
ReplacedValues[BitCast] = BitCast->getOperand(0);
527-
return;
536+
return true;
528537
}
529538
}
530539

531540
if (auto *Extract = dyn_cast<ExtractElementInst>(&I)) {
532541
if (!dyn_cast<BitCastInst>(Extract->getVectorOperand()))
533-
return;
542+
return false;
534543
auto *VecTy = dyn_cast<FixedVectorType>(Extract->getVectorOperandType());
535544
if (VecTy && VecTy->getElementType()->isIntegerTy(32) &&
536545
VecTy->getNumElements() == 2) {
@@ -558,14 +567,16 @@ legalizeGetHighLowi64Bytes(Instruction &I,
558567
}
559568
ToRemove.push_back(Extract);
560569
Extract->replaceAllUsesWith(ReplacedValues[Extract]);
570+
return true;
561571
}
562572
}
563573
}
574+
return false;
564575
}
565576

566-
static void legalizeScalarLoadStoreOnArrays(
577+
static bool legalizeScalarLoadStoreOnArrays(
567578
Instruction &I, SmallVectorImpl<Instruction *> &ToRemove,
568-
DenseMap<Value *, Value *> &, bool &MadeChange) {
579+
DenseMap<Value *, Value *> &) {
569580

570581
Value *PtrOp;
571582
unsigned PtrOpIndex;
@@ -579,25 +590,25 @@ static void legalizeScalarLoadStoreOnArrays(
579590
PtrOpIndex = SI->getPointerOperandIndex();
580591
LoadStoreTy = SI->getValueOperand()->getType();
581592
} else
582-
return;
593+
return false;
583594

584595
// If the load/store is not of a single-value type (i.e., scalar or vector)
585596
// then we do not modify it. It shouldn't be a vector either because the
586597
// dxil-data-scalarization pass is expected to run before this, but it's not
587598
// incorrect to apply this transformation to vector load/stores.
588599
if (!LoadStoreTy->isSingleValueType())
589-
return;
600+
return false;
590601

591602
Type *ArrayTy;
592603
if (auto *GlobalVarPtrOp = dyn_cast<GlobalVariable>(PtrOp))
593604
ArrayTy = GlobalVarPtrOp->getValueType();
594605
else if (auto *AllocaPtrOp = dyn_cast<AllocaInst>(PtrOp))
595606
ArrayTy = AllocaPtrOp->getAllocatedType();
596607
else
597-
return;
608+
return false;
598609

599610
if (!isa<ArrayType>(ArrayTy))
600-
return;
611+
return false;
601612

602613
assert(ArrayTy->getArrayElementType() == LoadStoreTy &&
603614
"Expected array element type to be the same as to the scalar load or "
@@ -607,7 +618,7 @@ static void legalizeScalarLoadStoreOnArrays(
607618
Value *GEP = GetElementPtrInst::Create(
608619
ArrayTy, PtrOp, {Zero, Zero}, GEPNoWrapFlags::all(), "", I.getIterator());
609620
I.setOperand(PtrOpIndex, GEP);
610-
MadeChange = true;
621+
return true;
611622
}
612623

613624
namespace {
@@ -624,17 +635,12 @@ class DXILLegalizationPipeline {
624635
ToRemove.clear();
625636
ReplacedValues.clear();
626637
for (auto &I : instructions(F)) {
627-
for (auto &LegalizationFn : LegalizationPipeline[Stage]) {
628-
bool PerLegalizationChange = false;
629-
LegalizationFn(I, ToRemove, ReplacedValues, PerLegalizationChange);
630-
MadeChange |= PerLegalizationChange;
631-
}
638+
for (auto &LegalizationFn : LegalizationPipeline[Stage])
639+
MadeChange |= LegalizationFn(I, ToRemove, ReplacedValues);
632640
}
633641

634642
for (auto *Inst : reverse(ToRemove))
635643
Inst->eraseFromParent();
636-
637-
MadeChange |= !ToRemove.empty();
638644
}
639645
return MadeChange;
640646
}
@@ -643,8 +649,8 @@ class DXILLegalizationPipeline {
643649
enum LegalizationStage { Stage1 = 0, Stage2 = 1, NumStages };
644650

645651
using LegalizationFnTy =
646-
std::function<void(Instruction &, SmallVectorImpl<Instruction *> &,
647-
DenseMap<Value *, Value *> &, bool &)>;
652+
std::function<bool(Instruction &, SmallVectorImpl<Instruction *> &,
653+
DenseMap<Value *, Value *> &)>;
648654

649655
SmallVector<LegalizationFnTy> LegalizationPipeline[NumStages];
650656

0 commit comments

Comments
 (0)