@@ -50,6 +50,83 @@ void RegBankLegalizeHelper::findRuleAndApplyMapping(MachineInstr &MI) {
5050 lower (MI, Mapping, WaterfallSgprs);
5151}
5252
53+ void RegBankLegalizeHelper::splitLoad (MachineInstr &MI,
54+ ArrayRef<LLT> LLTBreakdown, LLT MergeTy) {
55+ MachineFunction &MF = B.getMF ();
56+ assert (MI.getNumMemOperands () == 1 );
57+ MachineMemOperand &BaseMMO = **MI.memoperands_begin ();
58+ Register Dst = MI.getOperand (0 ).getReg ();
59+ const RegisterBank *DstRB = MRI.getRegBankOrNull (Dst);
60+ Register Base = MI.getOperand (1 ).getReg ();
61+ LLT PtrTy = MRI.getType (Base);
62+ const RegisterBank *PtrRB = MRI.getRegBankOrNull (Base);
63+ LLT OffsetTy = LLT::scalar (PtrTy.getSizeInBits ());
64+ SmallVector<Register, 4 > LoadPartRegs;
65+
66+ unsigned ByteOffset = 0 ;
67+ for (LLT PartTy : LLTBreakdown) {
68+ Register BasePlusOffset;
69+ if (ByteOffset == 0 ) {
70+ BasePlusOffset = Base;
71+ } else {
72+ auto Offset = B.buildConstant ({PtrRB, OffsetTy}, ByteOffset);
73+ BasePlusOffset = B.buildPtrAdd ({PtrRB, PtrTy}, Base, Offset).getReg (0 );
74+ }
75+ auto *OffsetMMO = MF.getMachineMemOperand (&BaseMMO, ByteOffset, PartTy);
76+ auto LoadPart = B.buildLoad ({DstRB, PartTy}, BasePlusOffset, *OffsetMMO);
77+ LoadPartRegs.push_back (LoadPart.getReg (0 ));
78+ ByteOffset += PartTy.getSizeInBytes ();
79+ }
80+
81+ if (!MergeTy.isValid ()) {
82+ // Loads are of same size, concat or merge them together.
83+ B.buildMergeLikeInstr (Dst, LoadPartRegs);
84+ } else {
85+ // Loads are not all of same size, need to unmerge them to smaller pieces
86+ // of MergeTy type, then merge pieces to Dst.
87+ SmallVector<Register, 4 > MergeTyParts;
88+ for (Register Reg : LoadPartRegs) {
89+ if (MRI.getType (Reg) == MergeTy) {
90+ MergeTyParts.push_back (Reg);
91+ } else {
92+ auto Unmerge = B.buildUnmerge ({DstRB, MergeTy}, Reg);
93+ for (unsigned i = 0 ; i < Unmerge->getNumOperands () - 1 ; ++i)
94+ MergeTyParts.push_back (Unmerge.getReg (i));
95+ }
96+ }
97+ B.buildMergeLikeInstr (Dst, MergeTyParts);
98+ }
99+ MI.eraseFromParent ();
100+ }
101+
102+ void RegBankLegalizeHelper::widenLoad (MachineInstr &MI, LLT WideTy,
103+ LLT MergeTy) {
104+ MachineFunction &MF = B.getMF ();
105+ assert (MI.getNumMemOperands () == 1 );
106+ MachineMemOperand &BaseMMO = **MI.memoperands_begin ();
107+ Register Dst = MI.getOperand (0 ).getReg ();
108+ const RegisterBank *DstRB = MRI.getRegBankOrNull (Dst);
109+ Register Base = MI.getOperand (1 ).getReg ();
110+
111+ MachineMemOperand *WideMMO = MF.getMachineMemOperand (&BaseMMO, 0 , WideTy);
112+ auto WideLoad = B.buildLoad ({DstRB, WideTy}, Base, *WideMMO);
113+
114+ if (WideTy.isScalar ()) {
115+ B.buildTrunc (Dst, WideLoad);
116+ } else {
117+ SmallVector<Register, 4 > MergeTyParts;
118+ auto Unmerge = B.buildUnmerge ({DstRB, MergeTy}, WideLoad);
119+
120+ LLT DstTy = MRI.getType (Dst);
121+ unsigned NumElts = DstTy.getSizeInBits () / MergeTy.getSizeInBits ();
122+ for (unsigned i = 0 ; i < NumElts; ++i) {
123+ MergeTyParts.push_back (Unmerge.getReg (i));
124+ }
125+ B.buildMergeLikeInstr (Dst, MergeTyParts);
126+ }
127+ MI.eraseFromParent ();
128+ }
129+
53130void RegBankLegalizeHelper::lower (MachineInstr &MI,
54131 const RegBankLLTMapping &Mapping,
55132 SmallSet<Register, 4 > &WaterfallSgprs) {
@@ -128,6 +205,54 @@ void RegBankLegalizeHelper::lower(MachineInstr &MI,
128205 MI.eraseFromParent ();
129206 break ;
130207 }
208+ case SplitLoad: {
209+ LLT DstTy = MRI.getType (MI.getOperand (0 ).getReg ());
210+ unsigned Size = DstTy.getSizeInBits ();
211+ // Even split to 128-bit loads
212+ if (Size > 128 ) {
213+ LLT B128;
214+ if (DstTy.isVector ()) {
215+ LLT EltTy = DstTy.getElementType ();
216+ B128 = LLT::fixed_vector (128 / EltTy.getSizeInBits (), EltTy);
217+ } else {
218+ B128 = LLT::scalar (128 );
219+ }
220+ if (Size / 128 == 2 )
221+ splitLoad (MI, {B128, B128});
222+ else if (Size / 128 == 4 )
223+ splitLoad (MI, {B128, B128, B128, B128});
224+ else {
225+ LLVM_DEBUG (dbgs () << " MI: " ; MI.dump (););
226+ llvm_unreachable (" SplitLoad type not supported for MI" );
227+ }
228+ }
229+ // 64 and 32 bit load
230+ else if (DstTy == S96)
231+ splitLoad (MI, {S64, S32}, S32);
232+ else if (DstTy == V3S32)
233+ splitLoad (MI, {V2S32, S32}, S32);
234+ else if (DstTy == V6S16)
235+ splitLoad (MI, {V4S16, V2S16}, V2S16);
236+ else {
237+ LLVM_DEBUG (dbgs () << " MI: " ; MI.dump (););
238+ llvm_unreachable (" SplitLoad type not supported for MI" );
239+ }
240+ break ;
241+ }
242+ case WidenLoad: {
243+ LLT DstTy = MRI.getType (MI.getOperand (0 ).getReg ());
244+ if (DstTy == S96)
245+ widenLoad (MI, S128);
246+ else if (DstTy == V3S32)
247+ widenLoad (MI, V4S32, S32);
248+ else if (DstTy == V6S16)
249+ widenLoad (MI, V8S16, V2S16);
250+ else {
251+ LLVM_DEBUG (dbgs () << " MI: " ; MI.dump (););
252+ llvm_unreachable (" WidenLoad type not supported for MI" );
253+ }
254+ break ;
255+ }
131256 }
132257
133258 // TODO: executeInWaterfallLoop(... WaterfallSgprs)
@@ -151,12 +276,73 @@ LLT RegBankLegalizeHelper::getTyFromID(RegBankLLTMappingApplyID ID) {
151276 case Sgpr64:
152277 case Vgpr64:
153278 return LLT::scalar (64 );
279+ case SgprP1:
280+ case VgprP1:
281+ return LLT::pointer (1 , 64 );
282+ case SgprP3:
283+ case VgprP3:
284+ return LLT::pointer (3 , 32 );
285+ case SgprP4:
286+ case VgprP4:
287+ return LLT::pointer (4 , 64 );
288+ case SgprP5:
289+ case VgprP5:
290+ return LLT::pointer (5 , 32 );
154291 case SgprV4S32:
155292 case VgprV4S32:
156293 case UniInVgprV4S32:
157294 return LLT::fixed_vector (4 , 32 );
158- case VgprP1:
159- return LLT::pointer (1 , 64 );
295+ default :
296+ return LLT ();
297+ }
298+ }
299+
/// Check \p Ty against the set of types accepted by the B-type mapping \p ID.
///
/// Each group of IDs below (B32, B64, B96, B128, B256, B512, in their Sgpr,
/// Vgpr and UniInVgpr flavours) lists the concrete types it accepts.
/// \returns \p Ty itself when it is one of the types listed for \p ID, and a
/// default-constructed LLT otherwise (also for any ID that is not a B-type
/// ID). Callers in this file compare the result against \p Ty to validate an
/// operand's type.
300+ LLT RegBankLegalizeHelper::getBTyFromID (RegBankLLTMappingApplyID ID, LLT Ty) {
301+ switch (ID) {
// B32: 32-bit scalar, 2 x 16-bit vector and the listed 32-bit pointers.
302+ case SgprB32:
303+ case VgprB32:
304+ case UniInVgprB32:
305+ if (Ty == LLT::scalar (32 ) || Ty == LLT::fixed_vector (2 , 16 ) ||
306+ Ty == LLT::pointer (3 , 32 ) || Ty == LLT::pointer (5 , 32 ) ||
307+ Ty == LLT::pointer (6 , 32 ))
308+ return Ty;
309+ return LLT ();
// B64: 64-bit scalar, 2 x 32 / 4 x 16 vectors and the listed 64-bit pointers.
310+ case SgprB64:
311+ case VgprB64:
312+ case UniInVgprB64:
313+ if (Ty == LLT::scalar (64 ) || Ty == LLT::fixed_vector (2 , 32 ) ||
314+ Ty == LLT::fixed_vector (4 , 16 ) || Ty == LLT::pointer (0 , 64 ) ||
315+ Ty == LLT::pointer (1 , 64 ) || Ty == LLT::pointer (4 , 64 ))
316+ return Ty;
317+ return LLT ();
// B96: 96-bit scalar and 3 x 32 / 6 x 16 vectors.
318+ case SgprB96:
319+ case VgprB96:
320+ case UniInVgprB96:
321+ if (Ty == LLT::scalar (96 ) || Ty == LLT::fixed_vector (3 , 32 ) ||
322+ Ty == LLT::fixed_vector (6 , 16 ))
323+ return Ty;
324+ return LLT ();
// B128: 128-bit scalar and 4 x 32 / 2 x 64 vectors.
325+ case SgprB128:
326+ case VgprB128:
327+ case UniInVgprB128:
328+ if (Ty == LLT::scalar (128 ) || Ty == LLT::fixed_vector (4 , 32 ) ||
329+ Ty == LLT::fixed_vector (2 , 64 ))
330+ return Ty;
331+ return LLT ();
// B256: 256-bit scalar and 8 x 32 / 4 x 64 / 16 x 16 vectors.
332+ case SgprB256:
333+ case VgprB256:
334+ case UniInVgprB256:
335+ if (Ty == LLT::scalar (256 ) || Ty == LLT::fixed_vector (8 , 32 ) ||
336+ Ty == LLT::fixed_vector (4 , 64 ) || Ty == LLT::fixed_vector (16 , 16 ))
337+ return Ty;
338+ return LLT ();
// B512: 512-bit scalar and 16 x 32 / 8 x 64 vectors.
339+ case SgprB512:
340+ case VgprB512:
341+ case UniInVgprB512:
342+ if (Ty == LLT::scalar (512 ) || Ty == LLT::fixed_vector (16 , 32 ) ||
343+ Ty == LLT::fixed_vector (8 , 64 ))
344+ return Ty;
345+ return LLT ();
// Any ID that is not a B-type ID has no accepted type.
160346 default :
161347 return LLT ();
162348 }
@@ -170,10 +356,26 @@ RegBankLegalizeHelper::getRegBankFromID(RegBankLLTMappingApplyID ID) {
170356 case Sgpr16:
171357 case Sgpr32:
172358 case Sgpr64:
359+ case SgprP1:
360+ case SgprP3:
361+ case SgprP4:
362+ case SgprP5:
173363 case SgprV4S32:
364+ case SgprB32:
365+ case SgprB64:
366+ case SgprB96:
367+ case SgprB128:
368+ case SgprB256:
369+ case SgprB512:
174370 case UniInVcc:
175371 case UniInVgprS32:
176372 case UniInVgprV4S32:
373+ case UniInVgprB32:
374+ case UniInVgprB64:
375+ case UniInVgprB96:
376+ case UniInVgprB128:
377+ case UniInVgprB256:
378+ case UniInVgprB512:
177379 case Sgpr32Trunc:
178380 case Sgpr32AExt:
179381 case Sgpr32AExtBoolInReg:
@@ -182,7 +384,16 @@ RegBankLegalizeHelper::getRegBankFromID(RegBankLLTMappingApplyID ID) {
182384 case Vgpr32:
183385 case Vgpr64:
184386 case VgprP1:
387+ case VgprP3:
388+ case VgprP4:
389+ case VgprP5:
185390 case VgprV4S32:
391+ case VgprB32:
392+ case VgprB64:
393+ case VgprB96:
394+ case VgprB128:
395+ case VgprB256:
396+ case VgprB512:
186397 return VgprRB;
187398 default :
188399 return nullptr ;
@@ -207,16 +418,40 @@ void RegBankLegalizeHelper::applyMappingDst(
207418 case Sgpr16:
208419 case Sgpr32:
209420 case Sgpr64:
421+ case SgprP1:
422+ case SgprP3:
423+ case SgprP4:
424+ case SgprP5:
210425 case SgprV4S32:
211426 case Vgpr32:
212427 case Vgpr64:
213428 case VgprP1:
429+ case VgprP3:
430+ case VgprP4:
431+ case VgprP5:
214432 case VgprV4S32: {
215433 assert (Ty == getTyFromID (MethodIDs[OpIdx]));
216434 assert (RB == getRegBankFromID (MethodIDs[OpIdx]));
217435 break ;
218436 }
219- // uniform in vcc/vgpr: scalars and vectors
437+ // sgpr and vgpr B-types
438+ case SgprB32:
439+ case SgprB64:
440+ case SgprB96:
441+ case SgprB128:
442+ case SgprB256:
443+ case SgprB512:
444+ case VgprB32:
445+ case VgprB64:
446+ case VgprB96:
447+ case VgprB128:
448+ case VgprB256:
449+ case VgprB512: {
450+ assert (Ty == getBTyFromID (MethodIDs[OpIdx], Ty));
451+ assert (RB == getRegBankFromID (MethodIDs[OpIdx]));
452+ break ;
453+ }
454+ // uniform in vcc/vgpr: scalars, vectors and B-types
220455 case UniInVcc: {
221456 assert (Ty == S1);
222457 assert (RB == SgprRB);
@@ -236,6 +471,19 @@ void RegBankLegalizeHelper::applyMappingDst(
236471 buildReadAnyLane (B, Reg, NewVgprDst, RBI);
237472 break ;
238473 }
474+ case UniInVgprB32:
475+ case UniInVgprB64:
476+ case UniInVgprB96:
477+ case UniInVgprB128:
478+ case UniInVgprB256:
479+ case UniInVgprB512: {
480+ assert (Ty == getBTyFromID (MethodIDs[OpIdx], Ty));
481+ assert (RB == SgprRB);
482+ Register NewVgprDst = MRI.createVirtualRegister ({VgprRB, Ty});
483+ Op.setReg (NewVgprDst);
484+ AMDGPU::buildReadAnyLane (B, Reg, NewVgprDst, RBI);
485+ break ;
486+ }
239487 // sgpr trunc
240488 case Sgpr32Trunc: {
241489 assert (Ty.getSizeInBits () < 32 );
@@ -284,15 +532,33 @@ void RegBankLegalizeHelper::applyMappingSrc(
284532 case Sgpr16:
285533 case Sgpr32:
286534 case Sgpr64:
535+ case SgprP1:
536+ case SgprP3:
537+ case SgprP4:
538+ case SgprP5:
287539 case SgprV4S32: {
288540 assert (Ty == getTyFromID (MethodIDs[i]));
289541 assert (RB == getRegBankFromID (MethodIDs[i]));
290542 break ;
291543 }
544+ // sgpr B-types
545+ case SgprB32:
546+ case SgprB64:
547+ case SgprB96:
548+ case SgprB128:
549+ case SgprB256:
550+ case SgprB512: {
551+ assert (Ty == getBTyFromID (MethodIDs[i], Ty));
552+ assert (RB == getRegBankFromID (MethodIDs[i]));
553+ break ;
554+ }
292555 // vgpr scalars, pointers and vectors
293556 case Vgpr32:
294557 case Vgpr64:
295558 case VgprP1:
559+ case VgprP3:
560+ case VgprP4:
561+ case VgprP5:
296562 case VgprV4S32: {
297563 assert (Ty == getTyFromID (MethodIDs[i]));
298564 if (RB != VgprRB) {
@@ -301,6 +567,20 @@ void RegBankLegalizeHelper::applyMappingSrc(
301567 }
302568 break ;
303569 }
570+ // vgpr B-types
571+ case VgprB32:
572+ case VgprB64:
573+ case VgprB96:
574+ case VgprB128:
575+ case VgprB256:
576+ case VgprB512: {
577+ assert (Ty == getBTyFromID (MethodIDs[i], Ty));
578+ if (RB != VgprRB) {
579+ auto CopyToVgpr = B.buildCopy ({VgprRB, Ty}, Reg);
580+ Op.setReg (CopyToVgpr.getReg (0 ));
581+ }
582+ break ;
583+ }
304584 // sgpr and vgpr scalars with extend
305585 case Sgpr32AExt: {
306586 // Note: this ext allows S1, and it is meant to be combined away.
@@ -373,7 +653,7 @@ void RegBankLegalizeHelper::applyMappingPHI(MachineInstr &MI) {
373653 // We accept all types that can fit in some register class.
374654 // Uniform G_PHIs have all sgpr registers.
375655 // Divergent G_PHIs have vgpr dst but inputs can be sgpr or vgpr.
376- if (Ty == LLT::scalar (32 )) {
656+ if (Ty == LLT::scalar (32 ) || Ty == LLT::pointer ( 4 , 64 ) ) {
377657 return ;
378658 }
379659
0 commit comments