55// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66//
77// ===---------------------------------------------------------------------===//
8- // ===---------------------------------------------------------------------===//
9- // /
10- // / \file This file contains a pass to remove i8 truncations and i64 extract
11- // / and insert elements.
12- // /
13- // ===----------------------------------------------------------------------===//
8+
149#include " DXILLegalizePass.h"
1510#include " DirectX.h"
1611#include " llvm/IR/Function.h"
2015#include " llvm/Pass.h"
2116#include " llvm/Transforms/Utils/BasicBlockUtils.h"
2217#include < functional>
23- #include < map>
24- #include < stack>
25- #include < vector>
2618
2719#define DEBUG_TYPE " dxil-legalize"
2820
2921using namespace llvm ;
30- namespace {
3122
3223static void fixI8TruncUseChain (Instruction &I,
33- std::stack<Instruction *> &ToRemove,
34- std::map<Value *, Value *> &ReplacedValues) {
35-
36- auto *Cmp = dyn_cast<CmpInst>(&I);
24+ SmallVectorImpl<Instruction *> &ToRemove,
25+ DenseMap<Value *, Value *> &ReplacedValues) {
3726
38- if (auto *Trunc = dyn_cast<TruncInst>(&I)) {
39- if (Trunc->getDestTy ()->isIntegerTy (8 )) {
40- ReplacedValues[Trunc] = Trunc->getOperand (0 );
41- ToRemove.push (Trunc);
42- }
43- } else if (I.getType ()->isIntegerTy (8 ) ||
44- (Cmp && Cmp->getOperand (0 )->getType ()->isIntegerTy (8 ))) {
45- IRBuilder<> Builder (&I);
46-
47- std::vector<Value *> NewOperands;
27+ auto ProcessOperands = [&](SmallVector<Value *> &NewOperands) {
4828 Type *InstrType = IntegerType::get (I.getContext (), 32 );
29+
4930 for (unsigned OpIdx = 0 ; OpIdx < I.getNumOperands (); ++OpIdx) {
5031 Value *Op = I.getOperand (OpIdx);
5132 if (ReplacedValues.count (Op))
5233 InstrType = ReplacedValues[Op]->getType ();
5334 }
35+
5436 for (unsigned OpIdx = 0 ; OpIdx < I.getNumOperands (); ++OpIdx) {
5537 Value *Op = I.getOperand (OpIdx);
5638 if (ReplacedValues.count (Op))
@@ -61,47 +43,68 @@ static void fixI8TruncUseChain(Instruction &I,
6143 // Note: options here are sext or sextOrTrunc.
6244 // Since i8 isn't supported, we assume new values
6345 // will always have a higher bitness.
46+ assert (NewBitWidth > Value.getBitWidth () &&
47+ " Replacement's BitWidth should be larger than Current." );
6448 APInt NewValue = Value.sext (NewBitWidth);
6549 NewOperands.push_back (ConstantInt::get (InstrType, NewValue));
6650 } else {
6751 assert (!Op->getType ()->isIntegerTy (8 ));
6852 NewOperands.push_back (Op);
6953 }
7054 }
71-
72- Value *NewInst = nullptr ;
73- if (auto *BO = dyn_cast<BinaryOperator>(&I)) {
74- NewInst =
75- Builder.CreateBinOp (BO->getOpcode (), NewOperands[0 ], NewOperands[1 ]);
76-
77- if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(&I)) {
78- if (OBO->hasNoSignedWrap ())
79- cast<BinaryOperator>(NewInst)->setHasNoSignedWrap ();
80- if (OBO->hasNoUnsignedWrap ())
81- cast<BinaryOperator>(NewInst)->setHasNoUnsignedWrap ();
82- }
83- } else if (Cmp) {
84- NewInst = Builder.CreateCmp (Cmp->getPredicate (), NewOperands[0 ],
85- NewOperands[1 ]);
86- Cmp->replaceAllUsesWith (NewInst);
55+ };
56+ IRBuilder<> Builder (&I);
57+ if (auto *Trunc = dyn_cast<TruncInst>(&I)) {
58+ if (Trunc->getDestTy ()->isIntegerTy (8 )) {
59+ ReplacedValues[Trunc] = Trunc->getOperand (0 );
60+ ToRemove.push_back (Trunc);
61+ return ;
8762 }
63+ }
8864
89- if (NewInst) {
90- ReplacedValues[&I] = NewInst;
91- ToRemove.push (&I);
65+ if (auto *BO = dyn_cast<BinaryOperator>(&I)) {
66+ if (!I.getType ()->isIntegerTy (8 ))
67+ return ;
68+ SmallVector<Value *> NewOperands;
69+ ProcessOperands (NewOperands);
70+ Value *NewInst =
71+ Builder.CreateBinOp (BO->getOpcode (), NewOperands[0 ], NewOperands[1 ]);
72+ if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(&I)) {
73+ if (OBO->hasNoSignedWrap ())
74+ cast<BinaryOperator>(NewInst)->setHasNoSignedWrap ();
75+ if (OBO->hasNoUnsignedWrap ())
76+ cast<BinaryOperator>(NewInst)->setHasNoUnsignedWrap ();
9277 }
93- } else if (auto *Cast = dyn_cast<CastInst>(&I)) {
78+ ReplacedValues[BO] = NewInst;
79+ ToRemove.push_back (BO);
80+ return ;
81+ }
82+
83+ if (auto *Cmp = dyn_cast<CmpInst>(&I)) {
84+ if (!Cmp->getOperand (0 )->getType ()->isIntegerTy (8 ))
85+ return ;
86+ SmallVector<Value *> NewOperands;
87+ ProcessOperands (NewOperands);
88+ Value *NewInst =
89+ Builder.CreateCmp (Cmp->getPredicate (), NewOperands[0 ], NewOperands[1 ]);
90+ Cmp->replaceAllUsesWith (NewInst);
91+ ReplacedValues[Cmp] = NewInst;
92+ ToRemove.push_back (Cmp);
93+ return ;
94+ }
95+
96+ if (auto *Cast = dyn_cast<CastInst>(&I)) {
9497 if (Cast->getSrcTy ()->isIntegerTy (8 )) {
95- ToRemove.push (Cast);
98+ ToRemove.push_back (Cast);
9699 Cast->replaceAllUsesWith (ReplacedValues[Cast->getOperand (0 )]);
97100 }
98101 }
99102}
100103
101104static void
102105downcastI64toI32InsertExtractElements (Instruction &I,
103- std::stack <Instruction *> &ToRemove,
104- std::map <Value *, Value *> &) {
106+ SmallVectorImpl <Instruction *> &ToRemove,
107+ DenseMap <Value *, Value *> &) {
105108
106109 if (auto *Extract = dyn_cast<ExtractElementInst>(&I)) {
107110 Value *Idx = Extract->getIndexOperand ();
@@ -115,7 +118,7 @@ downcastI64toI32InsertExtractElements(Instruction &I,
115118 Extract->getVectorOperand (), Idx32, Extract->getName ());
116119
117120 Extract->replaceAllUsesWith (NewExtract);
118- ToRemove.push (Extract);
121+ ToRemove.push_back (Extract);
119122 }
120123 }
121124
@@ -132,38 +135,35 @@ downcastI64toI32InsertExtractElements(Instruction &I,
132135 Insert->getName ());
133136
134137 Insert->replaceAllUsesWith (Insert32Index);
135- ToRemove.push (Insert);
138+ ToRemove.push_back (Insert);
136139 }
137140 }
138141}
139142
143+ namespace {
140144class DXILLegalizationPipeline {
141145
142146public:
143147 DXILLegalizationPipeline () { initializeLegalizationPipeline (); }
144148
145149 bool runLegalizationPipeline (Function &F) {
146- std::stack <Instruction *> ToRemove;
147- std::map <Value *, Value *> ReplacedValues;
150+ SmallVector <Instruction *> ToRemove;
151+ DenseMap <Value *, Value *> ReplacedValues;
148152 for (auto &I : instructions (F)) {
149- for (auto &LegalizationFn : LegalizationPipeline) {
153+ for (auto &LegalizationFn : LegalizationPipeline)
150154 LegalizationFn (I, ToRemove, ReplacedValues);
151- }
152155 }
153- bool MadeChanges = !ToRemove.empty ();
154156
155- while (!ToRemove.empty ()) {
156- Instruction *I = ToRemove.top ();
157- I->eraseFromParent ();
158- ToRemove.pop ();
159- }
157+ for (auto *Inst : reverse (ToRemove))
158+ Inst->eraseFromParent ();
160159
161- return MadeChanges ;
160+ return !ToRemove. empty () ;
162161 }
163162
164163private:
165- std::vector<std::function<void (Instruction &, std::stack<Instruction *> &,
166- std::map<Value *, Value *> &)>>
164+ SmallVector<
165+ std::function<void (Instruction &, SmallVectorImpl<Instruction *> &,
166+ DenseMap<Value *, Value *> &)>>
167167 LegalizationPipeline;
168168
169169 void initializeLegalizationPipeline () {
0 commit comments