1212#include " llvm/IR/IRBuilder.h"
1313#include " llvm/IR/InstIterator.h"
1414#include " llvm/IR/Instruction.h"
15+ #include " llvm/IR/Instructions.h"
1516#include " llvm/Pass.h"
1617#include " llvm/Transforms/Utils/BasicBlockUtils.h"
1718#include < functional>
@@ -31,16 +32,17 @@ static void legalizeFreeze(Instruction &I,
3132 ToRemove.push_back (FI);
3233}
3334
34- static void fixI8TruncUseChain (Instruction &I,
35- SmallVectorImpl<Instruction *> &ToRemove,
36- DenseMap<Value *, Value *> &ReplacedValues) {
35+ static void fixI8UseChain (Instruction &I,
36+ SmallVectorImpl<Instruction *> &ToRemove,
37+ DenseMap<Value *, Value *> &ReplacedValues) {
3738
3839 auto ProcessOperands = [&](SmallVector<Value *> &NewOperands) {
3940 Type *InstrType = IntegerType::get (I.getContext (), 32 );
4041
4142 for (unsigned OpIdx = 0 ; OpIdx < I.getNumOperands (); ++OpIdx) {
4243 Value *Op = I.getOperand (OpIdx);
43- if (ReplacedValues.count (Op))
44+ if (ReplacedValues.count (Op) &&
45+ ReplacedValues[Op]->getType ()->isIntegerTy ())
4446 InstrType = ReplacedValues[Op]->getType ();
4547 }
4648
@@ -73,6 +75,31 @@ static void fixI8TruncUseChain(Instruction &I,
7375 }
7476 }
7577
78+ if (auto *Store = dyn_cast<StoreInst>(&I)) {
79+ if (!Store->getValueOperand ()->getType ()->isIntegerTy (8 ))
80+ return ;
81+ SmallVector<Value *> NewOperands;
82+ ProcessOperands (NewOperands);
83+ Value *NewStore = Builder.CreateStore (NewOperands[0 ], NewOperands[1 ]);
84+ ReplacedValues[Store] = NewStore;
85+ ToRemove.push_back (Store);
86+ return ;
87+ }
88+
89+ if (auto *Load = dyn_cast<LoadInst>(&I)) {
90+ if (!I.getType ()->isIntegerTy (8 ))
91+ return ;
92+ SmallVector<Value *> NewOperands;
93+ ProcessOperands (NewOperands);
94+ Type *ElementType = NewOperands[0 ]->getType ();
95+ if (auto *AI = dyn_cast<AllocaInst>(NewOperands[0 ]))
96+ ElementType = AI->getAllocatedType ();
97+ LoadInst *NewLoad = Builder.CreateLoad (ElementType, NewOperands[0 ]);
98+ ReplacedValues[Load] = NewLoad;
99+ ToRemove.push_back (Load);
100+ return ;
101+ }
102+
76103 if (auto *BO = dyn_cast<BinaryOperator>(&I)) {
77104 if (!I.getType ()->isIntegerTy (8 ))
78105 return ;
@@ -81,16 +108,29 @@ static void fixI8TruncUseChain(Instruction &I,
81108 Value *NewInst =
82109 Builder.CreateBinOp (BO->getOpcode (), NewOperands[0 ], NewOperands[1 ]);
83110 if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(&I)) {
84- if (OBO->hasNoSignedWrap ())
85- cast<BinaryOperator>(NewInst)->setHasNoSignedWrap ();
86- if (OBO->hasNoUnsignedWrap ())
87- cast<BinaryOperator>(NewInst)->setHasNoUnsignedWrap ();
111+ auto *NewBO = dyn_cast<BinaryOperator>(NewInst);
112+ if (NewBO && OBO->hasNoSignedWrap ())
113+ NewBO->setHasNoSignedWrap ();
114+ if (NewBO && OBO->hasNoUnsignedWrap ())
115+ NewBO->setHasNoUnsignedWrap ();
88116 }
89117 ReplacedValues[BO] = NewInst;
90118 ToRemove.push_back (BO);
91119 return ;
92120 }
93121
122+ if (auto *Sel = dyn_cast<SelectInst>(&I)) {
123+ if (!I.getType ()->isIntegerTy (8 ))
124+ return ;
125+ SmallVector<Value *> NewOperands;
126+ ProcessOperands (NewOperands);
127+ Value *NewInst = Builder.CreateSelect (Sel->getCondition (), NewOperands[1 ],
128+ NewOperands[2 ]);
129+ ReplacedValues[Sel] = NewInst;
130+ ToRemove.push_back (Sel);
131+ return ;
132+ }
133+
94134 if (auto *Cmp = dyn_cast<CmpInst>(&I)) {
95135 if (!Cmp->getOperand (0 )->getType ()->isIntegerTy (8 ))
96136 return ;
@@ -105,13 +145,61 @@ static void fixI8TruncUseChain(Instruction &I,
105145 }
106146
107147 if (auto *Cast = dyn_cast<CastInst>(&I)) {
108- if (Cast->getSrcTy ()->isIntegerTy (8 )) {
109- ToRemove.push_back (Cast);
110- Cast->replaceAllUsesWith (ReplacedValues[Cast->getOperand (0 )]);
148+ if (!Cast->getSrcTy ()->isIntegerTy (8 ))
149+ return ;
150+
151+ ToRemove.push_back (Cast);
152+ auto *Replacement = ReplacedValues[Cast->getOperand (0 )];
153+ if (Cast->getType () == Replacement->getType ()) {
154+ Cast->replaceAllUsesWith (Replacement);
155+ return ;
111156 }
157+ Value *AdjustedCast = nullptr ;
158+ if (Cast->getOpcode () == Instruction::ZExt)
159+ AdjustedCast = Builder.CreateZExtOrTrunc (Replacement, Cast->getType ());
160+ if (Cast->getOpcode () == Instruction::SExt)
161+ AdjustedCast = Builder.CreateSExtOrTrunc (Replacement, Cast->getType ());
162+
163+ if (AdjustedCast)
164+ Cast->replaceAllUsesWith (AdjustedCast);
112165 }
113166}
114167
168+ static void upcastI8AllocasAndUses (Instruction &I,
169+ SmallVectorImpl<Instruction *> &ToRemove,
170+ DenseMap<Value *, Value *> &ReplacedValues) {
171+ auto *AI = dyn_cast<AllocaInst>(&I);
172+ if (!AI || !AI->getAllocatedType ()->isIntegerTy (8 ))
173+ return ;
174+
175+ Type *SmallestType = nullptr ;
176+
177+ // Gather all cast targets
178+ for (User *U : AI->users ()) {
179+ auto *Load = dyn_cast<LoadInst>(U);
180+ if (!Load)
181+ continue ;
182+ for (User *LU : Load->users ()) {
183+ auto *Cast = dyn_cast<CastInst>(LU);
184+ if (!Cast)
185+ continue ;
186+ Type *Ty = Cast->getType ();
187+ if (!SmallestType ||
188+ Ty->getPrimitiveSizeInBits () < SmallestType->getPrimitiveSizeInBits ())
189+ SmallestType = Ty;
190+ }
191+ }
192+
193+ if (!SmallestType)
194+ return ; // no valid casts found
195+
196+ // Replace alloca
197+ IRBuilder<> Builder (AI);
198+ auto *NewAlloca = Builder.CreateAlloca (SmallestType);
199+ ReplacedValues[AI] = NewAlloca;
200+ ToRemove.push_back (AI);
201+ }
202+
115203static void
116204downcastI64toI32InsertExtractElements (Instruction &I,
117205 SmallVectorImpl<Instruction *> &ToRemove,
@@ -178,7 +266,8 @@ class DXILLegalizationPipeline {
178266 LegalizationPipeline;
179267
180268 void initializeLegalizationPipeline () {
181- LegalizationPipeline.push_back (fixI8TruncUseChain);
269+ LegalizationPipeline.push_back (upcastI8AllocasAndUses);
270+ LegalizationPipeline.push_back (fixI8UseChain);
182271 LegalizationPipeline.push_back (downcastI64toI32InsertExtractElements);
183272 LegalizationPipeline.push_back (legalizeFreeze);
184273 }
0 commit comments