@@ -50,6 +50,8 @@ class AMDGPULateCodeGenPrepare
   AssumptionCache *AC = nullptr;
   UniformityInfo *UA = nullptr;
 
+  SmallVector<WeakTrackingVH, 8> DeadInsts;
+
 public:
   static char ID;
 
@@ -81,6 +83,69 @@ class AMDGPULateCodeGenPrepare
   bool visitLoadInst(LoadInst &LI);
 };
 
+using ValueToValueMap = DenseMap<const Value *, Value *>;
+
+class LiveRegOptimizer {
+private:
+  Module *Mod = nullptr;
+  const DataLayout *DL = nullptr;
+  const GCNSubtarget *ST;
+  /// The scalar type to convert to
+  Type *ConvertToScalar;
+  /// The set of visited Instructions
+  SmallPtrSet<Instruction *, 4> Visited;
+  /// Map of Value -> Converted Value
+  ValueToValueMap ValMap;
+  /// Per-BB map containing the conversions from optimal type -> original type.
+  DenseMap<BasicBlock *, ValueToValueMap> BBUseValMap;
+
+public:
+  /// Calculate and return the type to convert to, given a problematic \p
+  /// OriginalType. In some instances, we may widen the type (e.g. v2i8 -> i32).
+  Type *calculateConvertType(Type *OriginalType);
+  /// Convert the virtual register defined by \p V to the compatible vector of
+  /// legal type.
+  Value *convertToOptType(Instruction *V, BasicBlock::iterator &InstPt);
+  /// Convert the virtual register defined by \p V back to the original type \p
+  /// ConvertType, stripping away the MSBs in cases where there was an imperfect
+  /// fit (e.g. v2i32 -> v7i8).
+  Value *convertFromOptType(Type *ConvertType, Instruction *V,
+                            BasicBlock::iterator &InstPt,
+                            BasicBlock *InsertBlock);
+  /// Check for problematic PHI nodes or cross-bb values based on the value
+  /// defined by \p I, and coerce to legal types if necessary. For problematic
+  /// PHI nodes, we coerce all incoming values in a single invocation.
+  bool optimizeLiveType(Instruction *I,
+                        SmallVectorImpl<WeakTrackingVH> &DeadInsts);
+
+  // Whether or not the type should be replaced to avoid inefficient
+  // legalization code.
+  bool shouldReplace(Type *ITy) {
+    FixedVectorType *VTy = dyn_cast<FixedVectorType>(ITy);
+    if (!VTy)
+      return false;
+
+    const auto *TLI = ST->getTargetLowering();
+
+    Type *EltTy = VTy->getElementType();
+    // Bit packing is only possible for integer elements no wider than the
+    // scalar type we convert to.
+    if (!EltTy->isIntegerTy() ||
+        EltTy->getScalarSizeInBits() > ConvertToScalar->getScalarSizeInBits())
+      return false;
+
+    // Only coerce illegal types.
+    TargetLoweringBase::LegalizeKind LK =
+        TLI->getTypeConversion(EltTy->getContext(), EVT::getEVT(EltTy, false));
+    return LK.first != TargetLoweringBase::TypeLegal;
+  }
+
+  LiveRegOptimizer(Module *Mod, const GCNSubtarget *ST) : Mod(Mod), ST(ST) {
+    DL = &Mod->getDataLayout();
+    ConvertToScalar = Type::getInt32Ty(Mod->getContext());
+  }
+};
+
 } // end anonymous namespace
 
 bool AMDGPULateCodeGenPrepare::doInitialization(Module &M) {
@@ -96,20 +161,243 @@ bool AMDGPULateCodeGenPrepare::runOnFunction(Function &F) {
   const TargetPassConfig &TPC = getAnalysis<TargetPassConfig>();
   const TargetMachine &TM = TPC.getTM<TargetMachine>();
   const GCNSubtarget &ST = TM.getSubtarget<GCNSubtarget>(F);
-  if (ST.hasScalarSubwordLoads())
-    return false;
 
   AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
   UA = &getAnalysis<UniformityInfoWrapperPass>().getUniformityInfo();
 
+  // "Optimize" the virtual regs that cross basic block boundaries. When
+  // building the SelectionDAG, vectors of illegal types that cross basic
+  // blocks will be scalarized and widened, with each scalar living in its own
+  // register. To work around this, this optimization converts the vectors to
+  // equivalent vectors of legal type (which are converted back before uses in
+  // subsequent blocks), to pack the bits into fewer physical registers (used
+  // in CopyToReg/CopyFromReg pairs).
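+  // For example, a v4i8 that is live across a block boundary can be copied as
+  // a single packed i32 and bitcast back to v4i8 before its uses.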
+  LiveRegOptimizer LRO(Mod, &ST);
+
   bool Changed = false;
-  for (auto &BB : F)
-    for (Instruction &I : llvm::make_early_inc_range(BB))
-      Changed |= visit(I);
 
+  bool HasScalarSubwordLoads = ST.hasScalarSubwordLoads();
+
+  for (auto &BB : reverse(F))
+    for (Instruction &I : make_early_inc_range(reverse(BB))) {
+      Changed |= !HasScalarSubwordLoads && visit(I);
+      Changed |= LRO.optimizeLiveType(&I, DeadInsts);
+    }
+
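+  // Instructions queued in DeadInsts are held as weak handles and removed in
+  // one batch; the permissive variant tolerates entries that are null or not
+  // trivially dead.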
+  RecursivelyDeleteTriviallyDeadInstructionsPermissive(DeadInsts);
   return Changed;
 }
 
+Type *LiveRegOptimizer::calculateConvertType(Type *OriginalType) {
+  assert(OriginalType->getScalarSizeInBits() <=
+         ConvertToScalar->getScalarSizeInBits());
+
+  FixedVectorType *VTy = cast<FixedVectorType>(OriginalType);
+
+  TypeSize OriginalSize = DL->getTypeSizeInBits(VTy);
+  TypeSize ConvertScalarSize = DL->getTypeSizeInBits(ConvertToScalar);
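+  // Round up so the converted type covers every original bit
+  // (e.g. 48-bit v3i16 -> 64-bit v2i32).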
+  unsigned ConvertEltCount =
+      (OriginalSize + ConvertScalarSize - 1) / ConvertScalarSize;
+
+  if (OriginalSize <= ConvertScalarSize)
+    return IntegerType::get(Mod->getContext(), ConvertScalarSize);
+
+  return VectorType::get(Type::getIntNTy(Mod->getContext(), ConvertScalarSize),
+                         ConvertEltCount, false);
+}
+
+Value *LiveRegOptimizer::convertToOptType(Instruction *V,
+                                          BasicBlock::iterator &InsertPt) {
+  FixedVectorType *VTy = cast<FixedVectorType>(V->getType());
+  Type *NewTy = calculateConvertType(V->getType());
+
+  TypeSize OriginalSize = DL->getTypeSizeInBits(VTy);
+  TypeSize NewSize = DL->getTypeSizeInBits(NewTy);
+
+  IRBuilder<> Builder(V->getParent(), InsertPt);
+  // If there is a bitsize match, we can fit the old vector into a new vector
+  // of desired type.
+  if (OriginalSize == NewSize)
+    return Builder.CreateBitCast(V, NewTy, V->getName() + ".bc");
+
+  // If there is a bitsize mismatch, we must use a wider vector.
+  assert(NewSize > OriginalSize);
+  uint64_t ExpandedVecElementCount = NewSize / VTy->getScalarSizeInBits();
+
+  SmallVector<int, 8> ShuffleMask;
+  uint64_t OriginalElementCount = VTy->getElementCount().getFixedValue();
+  for (unsigned I = 0; I < OriginalElementCount; I++)
+    ShuffleMask.push_back(I);
+
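+  // Mask indices >= OriginalElementCount select lanes past the end of the
+  // source vector, padding the widened vector with unspecified values.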
+  for (uint64_t I = OriginalElementCount; I < ExpandedVecElementCount; I++)
+    ShuffleMask.push_back(OriginalElementCount);
+
+  Value *ExpandedVec = Builder.CreateShuffleVector(V, ShuffleMask);
+  return Builder.CreateBitCast(ExpandedVec, NewTy, V->getName() + ".bc");
+}
+
+Value *LiveRegOptimizer::convertFromOptType(Type *ConvertType, Instruction *V,
+                                            BasicBlock::iterator &InsertPt,
+                                            BasicBlock *InsertBB) {
+  FixedVectorType *NewVTy = cast<FixedVectorType>(ConvertType);
+
+  TypeSize OriginalSize = DL->getTypeSizeInBits(V->getType());
+  TypeSize NewSize = DL->getTypeSizeInBits(NewVTy);
+
+  IRBuilder<> Builder(InsertBB, InsertPt);
+  // If there is a bitsize match, we simply convert back to the original type.
+  if (OriginalSize == NewSize)
+    return Builder.CreateBitCast(V, NewVTy, V->getName() + ".bc");
+
+  // If there is a bitsize mismatch, then we must have used a wider value to
+  // hold the bits.
+  assert(OriginalSize > NewSize);
+  // For wide scalars, we can just truncate the value.
+  if (!V->getType()->isVectorTy()) {
+    Instruction *Trunc = cast<Instruction>(
+        Builder.CreateTrunc(V, IntegerType::get(Mod->getContext(), NewSize)));
+    return cast<Instruction>(Builder.CreateBitCast(Trunc, NewVTy));
+  }
+
+  // For wider vectors, we must strip the MSBs to convert back to the original
+  // type.
+  VectorType *ExpandedVT = VectorType::get(
+      Type::getIntNTy(Mod->getContext(), NewVTy->getScalarSizeInBits()),
+      (OriginalSize / NewVTy->getScalarSizeInBits()), false);
+  Instruction *Converted =
+      cast<Instruction>(Builder.CreateBitCast(V, ExpandedVT));
+
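+  // An identity mask over the low NarrowElementCount lanes drops the padding
+  // MSB elements that were added when the value was widened.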
+  unsigned NarrowElementCount = NewVTy->getElementCount().getFixedValue();
+  SmallVector<int, 8> ShuffleMask(NarrowElementCount);
+  std::iota(ShuffleMask.begin(), ShuffleMask.end(), 0);
+
+  return Builder.CreateShuffleVector(Converted, ShuffleMask);
+}
+
+bool LiveRegOptimizer::optimizeLiveType(
+    Instruction *I, SmallVectorImpl<WeakTrackingVH> &DeadInsts) {
+  SmallVector<Instruction *, 4> Worklist;
+  SmallPtrSet<PHINode *, 4> PhiNodes;
+  SmallPtrSet<Instruction *, 4> Defs;
+  SmallPtrSet<Instruction *, 4> Uses;
+
+  Worklist.push_back(cast<Instruction>(I));
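+  // Walk the web of PHI nodes and cross-block values reachable from I,
+  // collecting the defs that need packing and the uses that need unpacking.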
+  while (!Worklist.empty()) {
+    Instruction *II = Worklist.pop_back_val();
+
+    if (!Visited.insert(II).second)
+      continue;
+
+    if (!shouldReplace(II->getType()))
+      continue;
+
+    if (PHINode *Phi = dyn_cast<PHINode>(II)) {
+      PhiNodes.insert(Phi);
+      // Collect all the incoming values of problematic PHI nodes.
+      for (Value *V : Phi->incoming_values()) {
+        // Repeat the collection process for newly found PHI nodes.
+        if (PHINode *OpPhi = dyn_cast<PHINode>(V)) {
+          if (!PhiNodes.count(OpPhi) && !Visited.count(OpPhi))
+            Worklist.push_back(OpPhi);
+          continue;
+        }
+
+        Instruction *IncInst = dyn_cast<Instruction>(V);
+        // Other incoming value types (e.g. vector literals) are unhandled.
+        if (!IncInst && !isa<ConstantAggregateZero>(V))
+          return false;
+
+        // Collect all other incoming values for coercion.
+        if (IncInst)
+          Defs.insert(IncInst);
+      }
+    }
+
+    // Collect all relevant uses.
+    for (User *V : II->users()) {
+      // Repeat the collection process for problematic PHI nodes.
+      if (PHINode *OpPhi = dyn_cast<PHINode>(V)) {
+        if (!PhiNodes.count(OpPhi) && !Visited.count(OpPhi))
+          Worklist.push_back(OpPhi);
+        continue;
+      }
+
+      Instruction *UseInst = cast<Instruction>(V);
+      // Collect all uses of PHI nodes and any use that crosses a BB boundary.
+      if (UseInst->getParent() != II->getParent() || isa<PHINode>(II)) {
+        Uses.insert(UseInst);
+        if (!Defs.count(II) && !isa<PHINode>(II))
+          Defs.insert(II);
+      }
+    }
+  }
+
+  // Coerce and track the defs.
+  for (Instruction *D : Defs) {
+    if (!ValMap.contains(D)) {
+      BasicBlock::iterator InsertPt = std::next(D->getIterator());
+      Value *ConvertVal = convertToOptType(D, InsertPt);
+      assert(ConvertVal);
+      ValMap[D] = ConvertVal;
+    }
+  }
+
+  // Construct new-typed PHI nodes.
+  for (PHINode *Phi : PhiNodes) {
+    ValMap[Phi] = PHINode::Create(calculateConvertType(Phi->getType()),
+                                  Phi->getNumIncomingValues(),
+                                  Phi->getName() + ".tc", Phi->getIterator());
+  }
+
+  // Connect all the PHI nodes with their new incoming values.
+  for (PHINode *Phi : PhiNodes) {
+    PHINode *NewPhi = cast<PHINode>(ValMap[Phi]);
+    bool MissingIncVal = false;
+    for (int I = 0, E = Phi->getNumIncomingValues(); I < E; I++) {
+      Value *IncVal = Phi->getIncomingValue(I);
+      if (isa<ConstantAggregateZero>(IncVal)) {
+        Type *NewType = calculateConvertType(Phi->getType());
+        NewPhi->addIncoming(ConstantInt::get(NewType, 0, false),
+                            Phi->getIncomingBlock(I));
+      } else if (ValMap.contains(IncVal))
+        NewPhi->addIncoming(ValMap[IncVal], Phi->getIncomingBlock(I));
+      else
+        MissingIncVal = true;
+    }
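+    // Normally the replaced original phi becomes dead; if any incoming value
+    // had no converted counterpart, the new phi is unusable and is the one
+    // marked dead instead.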
+    Instruction *DeadInst = Phi;
+    if (MissingIncVal) {
+      DeadInst = cast<Instruction>(ValMap[Phi]);
+      // Do not use the dead phi.
+      ValMap[Phi] = Phi;
+    }
+    DeadInsts.emplace_back(DeadInst);
+  }
+  // Coerce back to the original type and replace the uses.
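+  // Conversions are cached per block in BBUseValMap so that multiple uses in
+  // the same block share a single convertFromOptType call.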
+  for (Instruction *U : Uses) {
+    // Replace all converted operands for a use.
+    for (auto [OpIdx, Op] : enumerate(U->operands())) {
+      if (ValMap.contains(Op)) {
+        Value *NewVal = nullptr;
+        if (BBUseValMap.contains(U->getParent()) &&
+            BBUseValMap[U->getParent()].contains(ValMap[Op]))
+          NewVal = BBUseValMap[U->getParent()][ValMap[Op]];
+        else {
+          BasicBlock::iterator InsertPt = U->getParent()->getFirstNonPHIIt();
+          NewVal =
+              convertFromOptType(Op->getType(), cast<Instruction>(ValMap[Op]),
+                                 InsertPt, U->getParent());
+          BBUseValMap[U->getParent()][ValMap[Op]] = NewVal;
+        }
+        assert(NewVal);
+        U->setOperand(OpIdx, NewVal);
+      }
+    }
+  }
+
+  return true;
+}
+
 bool AMDGPULateCodeGenPrepare::canWidenScalarExtLoad(LoadInst &LI) const {
   unsigned AS = LI.getPointerAddressSpace();
   // Skip non-constant address space.
@@ -119,7 +407,7 @@ bool AMDGPULateCodeGenPrepare::canWidenScalarExtLoad(LoadInst &LI) const {
   // Skip non-simple loads.
   if (!LI.isSimple())
     return false;
-  auto *Ty = LI.getType();
+  Type *Ty = LI.getType();
   // Skip aggregate types.
   if (Ty->isAggregateType())
     return false;
@@ -181,7 +469,7 @@ bool AMDGPULateCodeGenPrepare::visitLoadInst(LoadInst &LI) {
   auto *NewVal = IRB.CreateBitCast(
       IRB.CreateTrunc(IRB.CreateLShr(NewLd, ShAmt), IntNTy), LI.getType());
   LI.replaceAllUsesWith(NewVal);
-  RecursivelyDeleteTriviallyDeadInstructions(&LI);
+  DeadInsts.emplace_back(&LI);
 
   return true;
 }