@@ -38,6 +38,83 @@ void RegBankLegalizeHelper::findRuleAndApplyMapping(MachineInstr &MI) {
3838 lower (MI, Mapping, WaterfallSgprs);
3939}
4040
41+ void RegBankLegalizeHelper::splitLoad (MachineInstr &MI,
42+ ArrayRef<LLT> LLTBreakdown, LLT MergeTy) {
43+ MachineFunction &MF = B.getMF ();
44+ assert (MI.getNumMemOperands () == 1 );
45+ MachineMemOperand &BaseMMO = **MI.memoperands_begin ();
46+ Register Dst = MI.getOperand (0 ).getReg ();
47+ const RegisterBank *DstRB = MRI.getRegBankOrNull (Dst);
48+ Register Base = MI.getOperand (1 ).getReg ();
49+ LLT PtrTy = MRI.getType (Base);
50+ const RegisterBank *PtrRB = MRI.getRegBankOrNull (Base);
51+ LLT OffsetTy = LLT::scalar (PtrTy.getSizeInBits ());
52+ SmallVector<Register, 4 > LoadPartRegs;
53+
54+ unsigned ByteOffset = 0 ;
55+ for (LLT PartTy : LLTBreakdown) {
56+ Register BasePlusOffset;
57+ if (ByteOffset == 0 ) {
58+ BasePlusOffset = Base;
59+ } else {
60+ auto Offset = B.buildConstant ({PtrRB, OffsetTy}, ByteOffset);
61+ BasePlusOffset = B.buildPtrAdd ({PtrRB, PtrTy}, Base, Offset).getReg (0 );
62+ }
63+ auto *OffsetMMO = MF.getMachineMemOperand (&BaseMMO, ByteOffset, PartTy);
64+ auto LoadPart = B.buildLoad ({DstRB, PartTy}, BasePlusOffset, *OffsetMMO);
65+ LoadPartRegs.push_back (LoadPart.getReg (0 ));
66+ ByteOffset += PartTy.getSizeInBytes ();
67+ }
68+
69+ if (!MergeTy.isValid ()) {
70+ // Loads are of same size, concat or merge them together.
71+ B.buildMergeLikeInstr (Dst, LoadPartRegs);
72+ } else {
73+ // Loads are not all of same size, need to unmerge them to smaller pieces
74+ // of MergeTy type, then merge pieces to Dst.
75+ SmallVector<Register, 4 > MergeTyParts;
76+ for (Register Reg : LoadPartRegs) {
77+ if (MRI.getType (Reg) == MergeTy) {
78+ MergeTyParts.push_back (Reg);
79+ } else {
80+ auto Unmerge = B.buildUnmerge ({DstRB, MergeTy}, Reg);
81+ for (unsigned i = 0 ; i < Unmerge->getNumOperands () - 1 ; ++i)
82+ MergeTyParts.push_back (Unmerge.getReg (i));
83+ }
84+ }
85+ B.buildMergeLikeInstr (Dst, MergeTyParts);
86+ }
87+ MI.eraseFromParent ();
88+ }
89+
90+ void RegBankLegalizeHelper::widenLoad (MachineInstr &MI, LLT WideTy,
91+ LLT MergeTy) {
92+ MachineFunction &MF = B.getMF ();
93+ assert (MI.getNumMemOperands () == 1 );
94+ MachineMemOperand &BaseMMO = **MI.memoperands_begin ();
95+ Register Dst = MI.getOperand (0 ).getReg ();
96+ const RegisterBank *DstRB = MRI.getRegBankOrNull (Dst);
97+ Register Base = MI.getOperand (1 ).getReg ();
98+
99+ MachineMemOperand *WideMMO = MF.getMachineMemOperand (&BaseMMO, 0 , WideTy);
100+ auto WideLoad = B.buildLoad ({DstRB, WideTy}, Base, *WideMMO);
101+
102+ if (WideTy.isScalar ()) {
103+ B.buildTrunc (Dst, WideLoad);
104+ } else {
105+ SmallVector<Register, 4 > MergeTyParts;
106+ auto Unmerge = B.buildUnmerge ({DstRB, MergeTy}, WideLoad);
107+
108+ LLT DstTy = MRI.getType (Dst);
109+ unsigned NumElts = DstTy.getSizeInBits () / MergeTy.getSizeInBits ();
110+ for (unsigned i = 0 ; i < NumElts; ++i) {
111+ MergeTyParts.push_back (Unmerge.getReg (i));
112+ }
113+ B.buildMergeLikeInstr (Dst, MergeTyParts);
114+ }
115+ MI.eraseFromParent ();
116+ }
117+
41118void RegBankLegalizeHelper::lower (MachineInstr &MI,
42119 const RegBankLLTMapping &Mapping,
43120 SmallSet<Register, 4 > &WaterfallSgprs) {
@@ -116,6 +193,54 @@ void RegBankLegalizeHelper::lower(MachineInstr &MI,
116193 MI.eraseFromParent ();
117194 break ;
118195 }
196+ case SplitLoad: {
197+ LLT DstTy = MRI.getType (MI.getOperand (0 ).getReg ());
198+ unsigned Size = DstTy.getSizeInBits ();
199+ // Even split to 128-bit loads
200+ if (Size > 128 ) {
201+ LLT B128;
202+ if (DstTy.isVector ()) {
203+ LLT EltTy = DstTy.getElementType ();
204+ B128 = LLT::fixed_vector (128 / EltTy.getSizeInBits (), EltTy);
205+ } else {
206+ B128 = LLT::scalar (128 );
207+ }
208+ if (Size / 128 == 2 )
209+ splitLoad (MI, {B128, B128});
210+ else if (Size / 128 == 4 )
211+ splitLoad (MI, {B128, B128, B128, B128});
212+ else {
213+ LLVM_DEBUG (dbgs () << " MI: " ; MI.dump (););
214+ llvm_unreachable (" SplitLoad type not supported for MI" );
215+ }
216+ }
217+ // 64 and 32 bit load
218+ else if (DstTy == S96)
219+ splitLoad (MI, {S64, S32}, S32);
220+ else if (DstTy == V3S32)
221+ splitLoad (MI, {V2S32, S32}, S32);
222+ else if (DstTy == V6S16)
223+ splitLoad (MI, {V4S16, V2S16}, V2S16);
224+ else {
225+ LLVM_DEBUG (dbgs () << " MI: " ; MI.dump (););
226+ llvm_unreachable (" SplitLoad type not supported for MI" );
227+ }
228+ break ;
229+ }
230+ case WidenLoad: {
231+ LLT DstTy = MRI.getType (MI.getOperand (0 ).getReg ());
232+ if (DstTy == S96)
233+ widenLoad (MI, S128);
234+ else if (DstTy == V3S32)
235+ widenLoad (MI, V4S32, S32);
236+ else if (DstTy == V6S16)
237+ widenLoad (MI, V8S16, V2S16);
238+ else {
239+ LLVM_DEBUG (dbgs () << " MI: " ; MI.dump (););
240+ llvm_unreachable (" WidenLoad type not supported for MI" );
241+ }
242+ break ;
243+ }
119244 }
120245
121246 // TODO: executeInWaterfallLoop(... WaterfallSgprs)
@@ -139,12 +264,73 @@ LLT RegBankLegalizeHelper::getTyFromID(RegBankLLTMapingApplyID ID) {
139264 case Sgpr64:
140265 case Vgpr64:
141266 return LLT::scalar (64 );
267+ case SgprP1:
268+ case VgprP1:
269+ return LLT::pointer (1 , 64 );
270+ case SgprP3:
271+ case VgprP3:
272+ return LLT::pointer (3 , 32 );
273+ case SgprP4:
274+ case VgprP4:
275+ return LLT::pointer (4 , 64 );
276+ case SgprP5:
277+ case VgprP5:
278+ return LLT::pointer (5 , 32 );
142279 case SgprV4S32:
143280 case VgprV4S32:
144281 case UniInVgprV4S32:
145282 return LLT::fixed_vector (4 , 32 );
146- case VgprP1:
147- return LLT::pointer (1 , 64 );
283+ default :
284+ return LLT ();
285+ }
286+ }
287+
288+ LLT RegBankLegalizeHelper::getBTyFromID (RegBankLLTMapingApplyID ID, LLT Ty) {
289+ switch (ID) {
290+ case SgprB32:
291+ case VgprB32:
292+ case UniInVgprB32:
293+ if (Ty == LLT::scalar (32 ) || Ty == LLT::fixed_vector (2 , 16 ) ||
294+ Ty == LLT::pointer (3 , 32 ) || Ty == LLT::pointer (5 , 32 ) ||
295+ Ty == LLT::pointer (6 , 32 ))
296+ return Ty;
297+ return LLT ();
298+ case SgprB64:
299+ case VgprB64:
300+ case UniInVgprB64:
301+ if (Ty == LLT::scalar (64 ) || Ty == LLT::fixed_vector (2 , 32 ) ||
302+ Ty == LLT::fixed_vector (4 , 16 ) || Ty == LLT::pointer (0 , 64 ) ||
303+ Ty == LLT::pointer (1 , 64 ) || Ty == LLT::pointer (4 , 64 ))
304+ return Ty;
305+ return LLT ();
306+ case SgprB96:
307+ case VgprB96:
308+ case UniInVgprB96:
309+ if (Ty == LLT::scalar (96 ) || Ty == LLT::fixed_vector (3 , 32 ) ||
310+ Ty == LLT::fixed_vector (6 , 16 ))
311+ return Ty;
312+ return LLT ();
313+ case SgprB128:
314+ case VgprB128:
315+ case UniInVgprB128:
316+ if (Ty == LLT::scalar (128 ) || Ty == LLT::fixed_vector (4 , 32 ) ||
317+ Ty == LLT::fixed_vector (2 , 64 ))
318+ return Ty;
319+ return LLT ();
320+ case SgprB256:
321+ case VgprB256:
322+ case UniInVgprB256:
323+ if (Ty == LLT::scalar (256 ) || Ty == LLT::fixed_vector (8 , 32 ) ||
324+ Ty == LLT::fixed_vector (4 , 64 ) || Ty == LLT::fixed_vector (16 , 16 ))
325+ return Ty;
326+ return LLT ();
327+ case SgprB512:
328+ case VgprB512:
329+ case UniInVgprB512:
330+ if (Ty == LLT::scalar (512 ) || Ty == LLT::fixed_vector (16 , 32 ) ||
331+ Ty == LLT::fixed_vector (8 , 64 ))
332+ return Ty;
333+ return LLT ();
148334 default :
149335 return LLT ();
150336 }
@@ -158,10 +344,26 @@ RegBankLegalizeHelper::getRegBankFromID(RegBankLLTMapingApplyID ID) {
158344 case Sgpr16:
159345 case Sgpr32:
160346 case Sgpr64:
347+ case SgprP1:
348+ case SgprP3:
349+ case SgprP4:
350+ case SgprP5:
161351 case SgprV4S32:
352+ case SgprB32:
353+ case SgprB64:
354+ case SgprB96:
355+ case SgprB128:
356+ case SgprB256:
357+ case SgprB512:
162358 case UniInVcc:
163359 case UniInVgprS32:
164360 case UniInVgprV4S32:
361+ case UniInVgprB32:
362+ case UniInVgprB64:
363+ case UniInVgprB96:
364+ case UniInVgprB128:
365+ case UniInVgprB256:
366+ case UniInVgprB512:
165367 case Sgpr32Trunc:
166368 case Sgpr32AExt:
167369 case Sgpr32AExtBoolInReg:
@@ -170,7 +372,16 @@ RegBankLegalizeHelper::getRegBankFromID(RegBankLLTMapingApplyID ID) {
170372 case Vgpr32:
171373 case Vgpr64:
172374 case VgprP1:
375+ case VgprP3:
376+ case VgprP4:
377+ case VgprP5:
173378 case VgprV4S32:
379+ case VgprB32:
380+ case VgprB64:
381+ case VgprB96:
382+ case VgprB128:
383+ case VgprB256:
384+ case VgprB512:
174385 return VgprRB;
175386 default :
176387 return nullptr ;
@@ -195,16 +406,40 @@ void RegBankLegalizeHelper::applyMappingDst(
195406 case Sgpr16:
196407 case Sgpr32:
197408 case Sgpr64:
409+ case SgprP1:
410+ case SgprP3:
411+ case SgprP4:
412+ case SgprP5:
198413 case SgprV4S32:
199414 case Vgpr32:
200415 case Vgpr64:
201416 case VgprP1:
417+ case VgprP3:
418+ case VgprP4:
419+ case VgprP5:
202420 case VgprV4S32: {
203421 assert (Ty == getTyFromID (MethodIDs[OpIdx]));
204422 assert (RB == getRegBankFromID (MethodIDs[OpIdx]));
205423 break ;
206424 }
207- // uniform in vcc/vgpr: scalars and vectors
425+ // sgpr and vgpr B-types
426+ case SgprB32:
427+ case SgprB64:
428+ case SgprB96:
429+ case SgprB128:
430+ case SgprB256:
431+ case SgprB512:
432+ case VgprB32:
433+ case VgprB64:
434+ case VgprB96:
435+ case VgprB128:
436+ case VgprB256:
437+ case VgprB512: {
438+ assert (Ty == getBTyFromID (MethodIDs[OpIdx], Ty));
439+ assert (RB == getRegBankFromID (MethodIDs[OpIdx]));
440+ break ;
441+ }
442+ // uniform in vcc/vgpr: scalars, vectors and B-types
208443 case UniInVcc: {
209444 assert (Ty == S1);
210445 assert (RB == SgprRB);
@@ -223,6 +458,19 @@ void RegBankLegalizeHelper::applyMappingDst(
223458 buildReadAnyLane (B, Reg, NewVgprDst, RBI);
224459 break ;
225460 }
461+ case UniInVgprB32:
462+ case UniInVgprB64:
463+ case UniInVgprB96:
464+ case UniInVgprB128:
465+ case UniInVgprB256:
466+ case UniInVgprB512: {
467+ assert (Ty == getBTyFromID (MethodIDs[OpIdx], Ty));
468+ assert (RB == SgprRB);
469+ Register NewVgprDst = MRI.createVirtualRegister ({VgprRB, Ty});
470+ Op.setReg (NewVgprDst);
471+ AMDGPU::buildReadAnyLane (B, Reg, NewVgprDst, RBI);
472+ break ;
473+ }
226474 // sgpr trunc
227475 case Sgpr32Trunc: {
228476 assert (Ty.getSizeInBits () < 32 );
@@ -270,15 +518,33 @@ void RegBankLegalizeHelper::applyMappingSrc(
270518 case Sgpr16:
271519 case Sgpr32:
272520 case Sgpr64:
521+ case SgprP1:
522+ case SgprP3:
523+ case SgprP4:
524+ case SgprP5:
273525 case SgprV4S32: {
274526 assert (Ty == getTyFromID (MethodIDs[i]));
275527 assert (RB == getRegBankFromID (MethodIDs[i]));
276528 break ;
277529 }
530+ // sgpr B-types
531+ case SgprB32:
532+ case SgprB64:
533+ case SgprB96:
534+ case SgprB128:
535+ case SgprB256:
536+ case SgprB512: {
537+ assert (Ty == getBTyFromID (MethodIDs[i], Ty));
538+ assert (RB == getRegBankFromID (MethodIDs[i]));
539+ break ;
540+ }
278541 // vgpr scalars, pointers and vectors
279542 case Vgpr32:
280543 case Vgpr64:
281544 case VgprP1:
545+ case VgprP3:
546+ case VgprP4:
547+ case VgprP5:
282548 case VgprV4S32: {
283549 assert (Ty == getTyFromID (MethodIDs[i]));
284550 if (RB != VgprRB) {
@@ -287,6 +553,20 @@ void RegBankLegalizeHelper::applyMappingSrc(
287553 }
288554 break ;
289555 }
556+ // vgpr B-types
557+ case VgprB32:
558+ case VgprB64:
559+ case VgprB96:
560+ case VgprB128:
561+ case VgprB256:
562+ case VgprB512: {
563+ assert (Ty == getBTyFromID (MethodIDs[i], Ty));
564+ if (RB != VgprRB) {
565+ auto CopyToVgpr = B.buildCopy ({VgprRB, Ty}, Reg);
566+ Op.setReg (CopyToVgpr.getReg (0 ));
567+ }
568+ break ;
569+ }
290570 // sgpr and vgpr scalars with extend
291571 case Sgpr32AExt: {
292572 // Note: this ext allows S1, and it is meant to be combined away.
@@ -359,7 +639,7 @@ void RegBankLegalizeHelper::applyMappingPHI(MachineInstr &MI) {
359639 // We accept all types that can fit in some register class.
360640 // Uniform G_PHIs have all sgpr registers.
361641 // Divergent G_PHIs have vgpr dst but inputs can be sgpr or vgpr.
362- if (Ty == LLT::scalar (32 )) {
642+ if (Ty == LLT::scalar (32 ) || Ty == LLT::pointer ( 4 , 64 ) ) {
363643 return ;
364644 }
365645
0 commit comments