@@ -81,73 +81,6 @@ class AMDGPULateCodeGenPrepare
   bool visitLoadInst(LoadInst &LI);
 };
 
-using ValueToValueMap = DenseMap<const Value *, Value *>;
-
-class LiveRegOptimizer {
-private:
-  Module *Mod = nullptr;
-  const DataLayout *DL = nullptr;
-  const GCNSubtarget *ST;
-  /// The scalar type to convert to
-  Type *ConvertToScalar;
-  /// The set of visited Instructions
-  SmallPtrSet<Instruction *, 4> Visited;
-  /// The set of Instructions to be deleted
-  SmallPtrSet<Instruction *, 4> DeadInstrs;
-  /// Map of Value -> Converted Value
-  ValueToValueMap ValMap;
-  /// Map of containing conversions from Optimal Type -> Original Type per BB.
-  DenseMap<BasicBlock *, ValueToValueMap> BBUseValMap;
-
-public:
-  /// Calculate the and \p return the type to convert to given a problematic \p
-  /// OriginalType. In some instances, we may widen the type (e.g. v2i8 -> i32).
-  Type *calculateConvertType(Type *OriginalType);
-  /// Convert the virtual register defined by \p V to the compatible vector of
-  /// legal type
-  Value *convertToOptType(Instruction *V, BasicBlock::iterator &InstPt);
-  /// Convert the virtual register defined by \p V back to the original type \p
-  /// ConvertType, stripping away the MSBs in cases where there was an imperfect
-  /// fit (e.g. v2i32 -> v7i8)
-  Value *convertFromOptType(Type *ConvertType, Instruction *V,
-                            BasicBlock::iterator &InstPt,
-                            BasicBlock *InsertBlock);
-  /// Check for problematic PHI nodes or cross-bb values based on the value
-  /// defined by \p I, and coerce to legal types if necessary. For problematic
-  /// PHI node, we coerce all incoming values in a single invocation.
-  bool optimizeLiveType(Instruction *I);
-
-  /// Remove all instructions that have become dead (i.e. all the re-typed PHIs)
-  void removeDeadInstrs();
-
-  // Whether or not the type should be replaced to avoid inefficient
-  // legalization code
-  bool shouldReplace(Type *ITy) {
-    FixedVectorType *VTy = dyn_cast<FixedVectorType>(ITy);
-    if (!VTy)
-      return false;
-
-    auto TLI = ST->getTargetLowering();
-
-    Type *EltTy = VTy->getElementType();
-    // If the element size is not less than the convert to scalar size, then we
-    // can't do any bit packing
-    if (!EltTy->isIntegerTy() ||
-        EltTy->getScalarSizeInBits() > ConvertToScalar->getScalarSizeInBits())
-      return false;
-
-    // Only coerce illegal types
-    TargetLoweringBase::LegalizeKind LK =
-        TLI->getTypeConversion(EltTy->getContext(), EVT::getEVT(EltTy, false));
-    return LK.first != TargetLoweringBase::TypeLegal;
-  }
-
-  LiveRegOptimizer(Module *Mod, const GCNSubtarget *ST) : Mod(Mod), ST(ST) {
-    DL = &Mod->getDataLayout();
-    ConvertToScalar = Type::getInt32Ty(Mod->getContext());
-  }
-};
-
 } // end anonymous namespace
 
 bool AMDGPULateCodeGenPrepare::doInitialization(Module &M) {
@@ -169,238 +102,14 @@ bool AMDGPULateCodeGenPrepare::runOnFunction(Function &F) {
   AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
   UA = &getAnalysis<UniformityInfoWrapperPass>().getUniformityInfo();
 
-  // "Optimize" the virtual regs that cross basic block boundaries. When
-  // building the SelectionDAG, vectors of illegal types that cross basic blocks
-  // will be scalarized and widened, with each scalar living in its
-  // own register. To work around this, this optimization converts the
-  // vectors to equivalent vectors of legal type (which are converted back
-  // before uses in subsequent blocks), to pack the bits into fewer physical
-  // registers (used in CopyToReg/CopyFromReg pairs).
-  LiveRegOptimizer LRO(Mod, &ST);
-
   bool Changed = false;
-
   for (auto &BB : F)
-    for (Instruction &I : make_early_inc_range(BB)) {
+    for (Instruction &I : llvm::make_early_inc_range(BB))
       Changed |= visit(I);
-      Changed |= LRO.optimizeLiveType(&I);
-    }
 
-  LRO.removeDeadInstrs();
   return Changed;
 }
 
-Type *LiveRegOptimizer::calculateConvertType(Type *OriginalType) {
-  assert(OriginalType->getScalarSizeInBits() <=
-         ConvertToScalar->getScalarSizeInBits());
-
-  FixedVectorType *VTy = cast<FixedVectorType>(OriginalType);
-
-  TypeSize OriginalSize = DL->getTypeSizeInBits(VTy);
-  TypeSize ConvertScalarSize = DL->getTypeSizeInBits(ConvertToScalar);
-  unsigned ConvertEltCount =
-      (OriginalSize + ConvertScalarSize - 1) / ConvertScalarSize;
-
-  if (OriginalSize <= ConvertScalarSize)
-    return IntegerType::get(Mod->getContext(), ConvertScalarSize);
-
-  return VectorType::get(Type::getIntNTy(Mod->getContext(), ConvertScalarSize),
-                         ConvertEltCount, false);
-}
-
-Value *LiveRegOptimizer::convertToOptType(Instruction *V,
-                                          BasicBlock::iterator &InsertPt) {
-  FixedVectorType *VTy = cast<FixedVectorType>(V->getType());
-  Type *NewTy = calculateConvertType(V->getType());
-
-  TypeSize OriginalSize = DL->getTypeSizeInBits(VTy);
-  TypeSize NewSize = DL->getTypeSizeInBits(NewTy);
-
-  IRBuilder<> Builder(V->getParent(), InsertPt);
-  // If there is a bitsize match, we can fit the old vector into a new vector of
-  // desired type.
-  if (OriginalSize == NewSize)
-    return Builder.CreateBitCast(V, NewTy, V->getName() + ".bc");
-
-  // If there is a bitsize mismatch, we must use a wider vector.
-  assert(NewSize > OriginalSize);
-  uint64_t ExpandedVecElementCount = NewSize / VTy->getScalarSizeInBits();
-
-  SmallVector<int, 8> ShuffleMask;
-  uint64_t OriginalElementCount = VTy->getElementCount().getFixedValue();
-  for (unsigned I = 0; I < OriginalElementCount; I++)
-    ShuffleMask.push_back(I);
-
-  for (uint64_t I = OriginalElementCount; I < ExpandedVecElementCount; I++)
-    ShuffleMask.push_back(OriginalElementCount);
-
-  Value *ExpandedVec = Builder.CreateShuffleVector(V, ShuffleMask);
-  return Builder.CreateBitCast(ExpandedVec, NewTy, V->getName() + ".bc");
-}
-
-Value *LiveRegOptimizer::convertFromOptType(Type *ConvertType, Instruction *V,
-                                            BasicBlock::iterator &InsertPt,
-                                            BasicBlock *InsertBB) {
-  FixedVectorType *NewVTy = cast<FixedVectorType>(ConvertType);
-
-  TypeSize OriginalSize = DL->getTypeSizeInBits(V->getType());
-  TypeSize NewSize = DL->getTypeSizeInBits(NewVTy);
-
-  IRBuilder<> Builder(InsertBB, InsertPt);
-  // If there is a bitsize match, we simply convert back to the original type.
-  if (OriginalSize == NewSize)
-    return Builder.CreateBitCast(V, NewVTy, V->getName() + ".bc");
-
-  // If there is a bitsize mismatch, then we must have used a wider value to
-  // hold the bits.
-  assert(OriginalSize > NewSize);
-  // For wide scalars, we can just truncate the value.
-  if (!V->getType()->isVectorTy()) {
-    Instruction *Trunc = cast<Instruction>(
-        Builder.CreateTrunc(V, IntegerType::get(Mod->getContext(), NewSize)));
-    return cast<Instruction>(Builder.CreateBitCast(Trunc, NewVTy));
-  }
-
-  // For wider vectors, we must strip the MSBs to convert back to the original
-  // type.
-  VectorType *ExpandedVT = VectorType::get(
-      Type::getIntNTy(Mod->getContext(), NewVTy->getScalarSizeInBits()),
-      (OriginalSize / NewVTy->getScalarSizeInBits()), false);
-  Instruction *Converted =
-      cast<Instruction>(Builder.CreateBitCast(V, ExpandedVT));
-
-  unsigned NarrowElementCount = NewVTy->getElementCount().getFixedValue();
-  SmallVector<int, 8> ShuffleMask(NarrowElementCount);
-  std::iota(ShuffleMask.begin(), ShuffleMask.end(), 0);
-
-  return Builder.CreateShuffleVector(Converted, ShuffleMask);
-}
-
-bool LiveRegOptimizer::optimizeLiveType(Instruction *I) {
-  SmallVector<Instruction *, 4> Worklist;
-  SmallPtrSet<PHINode *, 4> PhiNodes;
-  SmallPtrSet<Instruction *, 4> Defs;
-  SmallPtrSet<Instruction *, 4> Uses;
-
-  Worklist.push_back(cast<Instruction>(I));
-  while (!Worklist.empty()) {
-    Instruction *II = Worklist.pop_back_val();
-
-    if (!Visited.insert(II).second)
-      continue;
-
-    if (!shouldReplace(II->getType()))
-      continue;
-
-    if (PHINode *Phi = dyn_cast<PHINode>(II)) {
-      PhiNodes.insert(Phi);
-      // Collect all the incoming values of problematic PHI nodes.
-      for (Value *V : Phi->incoming_values()) {
-        // Repeat the collection process for newly found PHI nodes.
-        if (PHINode *OpPhi = dyn_cast<PHINode>(V)) {
-          if (!PhiNodes.count(OpPhi) && !Visited.count(OpPhi))
-            Worklist.push_back(OpPhi);
-          continue;
-        }
-
-        Instruction *IncInst = dyn_cast<Instruction>(V);
-        // Other incoming value types (e.g. vector literals) are unhandled
-        if (!IncInst && !isa<ConstantAggregateZero>(V))
-          return false;
-
-        // Collect all other incoming values for coercion.
-        if (IncInst)
-          Defs.insert(IncInst);
-      }
-    }
-
-    // Collect all relevant uses.
-    for (User *V : II->users()) {
-      // Repeat the collection process for problematic PHI nodes.
-      if (PHINode *OpPhi = dyn_cast<PHINode>(V)) {
-        if (!PhiNodes.count(OpPhi) && !Visited.count(OpPhi))
-          Worklist.push_back(OpPhi);
-        continue;
-      }
-
-      Instruction *UseInst = cast<Instruction>(V);
-      // Collect all uses of PHINodes and any use the crosses BB boundaries.
-      if (UseInst->getParent() != II->getParent() || isa<PHINode>(II)) {
-        Uses.insert(UseInst);
-        if (!Defs.count(II) && !isa<PHINode>(II)) {
-          Defs.insert(II);
-        }
-      }
-    }
-  }
-
-  // Coerce and track the defs.
-  for (Instruction *D : Defs) {
-    if (!ValMap.contains(D)) {
-      BasicBlock::iterator InsertPt = std::next(D->getIterator());
-      Value *ConvertVal = convertToOptType(D, InsertPt);
-      assert(ConvertVal);
-      ValMap[D] = ConvertVal;
-    }
-  }
-
-  // Construct new-typed PHI nodes.
-  for (PHINode *Phi : PhiNodes) {
-    ValMap[Phi] = PHINode::Create(calculateConvertType(Phi->getType()),
-                                  Phi->getNumIncomingValues(),
-                                  Phi->getName() + ".tc", Phi->getIterator());
-  }
-
-  // Connect all the PHI nodes with their new incoming values.
-  for (PHINode *Phi : PhiNodes) {
-    PHINode *NewPhi = cast<PHINode>(ValMap[Phi]);
-    bool MissingIncVal = false;
-    for (int I = 0, E = Phi->getNumIncomingValues(); I < E; I++) {
-      Value *IncVal = Phi->getIncomingValue(I);
-      if (isa<ConstantAggregateZero>(IncVal)) {
-        Type *NewType = calculateConvertType(Phi->getType());
-        NewPhi->addIncoming(ConstantInt::get(NewType, 0, false),
-                            Phi->getIncomingBlock(I));
-      } else if (ValMap.contains(IncVal))
-        NewPhi->addIncoming(ValMap[IncVal], Phi->getIncomingBlock(I));
-      else
-        MissingIncVal = true;
-    }
-    DeadInstrs.insert(MissingIncVal ? cast<Instruction>(ValMap[Phi]) : Phi);
-  }
-  // Coerce back to the original type and replace the uses.
-  for (Instruction *U : Uses) {
-    // Replace all converted operands for a use.
-    for (auto [OpIdx, Op] : enumerate(U->operands())) {
-      if (ValMap.contains(Op)) {
-        Value *NewVal = nullptr;
-        if (BBUseValMap.contains(U->getParent()) &&
-            BBUseValMap[U->getParent()].contains(ValMap[Op]))
-          NewVal = BBUseValMap[U->getParent()][ValMap[Op]];
-        else {
-          BasicBlock::iterator InsertPt = U->getParent()->getFirstNonPHIIt();
-          NewVal =
-              convertFromOptType(Op->getType(), cast<Instruction>(ValMap[Op]),
-                                 InsertPt, U->getParent());
-          BBUseValMap[U->getParent()][ValMap[Op]] = NewVal;
-        }
-        assert(NewVal);
-        U->setOperand(OpIdx, NewVal);
-      }
-    }
-  }
-
-  return true;
-}
-
-void LiveRegOptimizer::removeDeadInstrs() {
-  // Remove instrs that have been marked dead after type-coercion.
-  for (auto *I : DeadInstrs) {
-    I->replaceAllUsesWith(PoisonValue::get(I->getType()));
-    I->eraseFromParent();
-  }
-}
-
 bool AMDGPULateCodeGenPrepare::canWidenScalarExtLoad(LoadInst &LI) const {
   unsigned AS = LI.getPointerAddressSpace();
   // Skip non-constant address space.
@@ -410,7 +119,7 @@ bool AMDGPULateCodeGenPrepare::canWidenScalarExtLoad(LoadInst &LI) const {
   // Skip non-simple loads.
   if (!LI.isSimple())
     return false;
-  Type *Ty = LI.getType();
+  auto *Ty = LI.getType();
   // Skip aggregate types.
   if (Ty->isAggregateType())
     return false;
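
For readers who want a concrete picture of the bit packing the deleted LiveRegOptimizer performed, here is a minimal sketch of the widen-then-bitcast step that convertToOptType emitted, written against the LLVM C++ IRBuilder API. It is not part of the patch: the function packV3I8ToI32 and its parameters are illustrative only, and it hard-codes the <3 x i8> -> i32 case instead of computing the target type the way calculateConvertType did.

// Illustrative sketch only (not from the patch): pack an illegal <3 x i8>
// value into a legal i32 with the same shufflevector + bitcast sequence
// the removed convertToOptType produced.
#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/IRBuilder.h"

using namespace llvm;

static Value *packV3I8ToI32(IRBuilder<> &Builder, Value *V3I8) {
  // Widen <3 x i8> to <4 x i8>: mask indices 0-2 keep the original lanes and
  // index 3 (== OriginalElementCount) selects a lane of the implicit poison
  // operand, matching how the removed code padded ShuffleMask.
  SmallVector<int, 4> ShuffleMask = {0, 1, 2, 3};
  Value *Padded = Builder.CreateShuffleVector(V3I8, ShuffleMask, "padded");
  // <4 x i8> and i32 have the same bit width, so one bitcast packs the vector
  // into a single 32-bit value for the cross-block copy.
  return Builder.CreateBitCast(Padded, Builder.getInt32Ty(), "packed");
}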