Skip to content

Commit 22d2e47

Browse files
committed
FPInfo: AMDGPURegBankLegalize
1 parent 33062c2 commit 22d2e47

File tree

4 files changed

+71
-71
lines changed

4 files changed

+71
-71
lines changed

llvm/lib/Target/AMDGPU/AMDGPURegBankLegalize.cpp

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "AMDGPUGlobalISelUtils.h"
2222
#include "AMDGPURegBankLegalizeHelper.h"
2323
#include "GCNSubtarget.h"
24+
#include "MCTargetDesc/AMDGPUMCTargetDesc.h"
2425
#include "llvm/CodeGen/GlobalISel/CSEInfo.h"
2526
#include "llvm/CodeGen/GlobalISel/CSEMIRBuilder.h"
2627
#include "llvm/CodeGen/MachineFunctionPass.h"
@@ -106,10 +107,10 @@ class AMDGPURegBankLegalizeCombiner {
106107
const RegisterBank *VgprRB;
107108
const RegisterBank *VccRB;
108109

109-
static constexpr LLT S1 = LLT::scalar(1);
110-
static constexpr LLT S16 = LLT::scalar(16);
111-
static constexpr LLT S32 = LLT::scalar(32);
112-
static constexpr LLT S64 = LLT::scalar(64);
110+
static constexpr LLT I1 = LLT::integer(1);
111+
static constexpr LLT I16 = LLT::integer(16);
112+
static constexpr LLT I32 = LLT::integer(32);
113+
static constexpr LLT I64 = LLT::integer(64);
113114

114115
public:
115116
AMDGPURegBankLegalizeCombiner(MachineIRBuilder &B, const SIRegisterInfo &TRI,
@@ -125,7 +126,7 @@ class AMDGPURegBankLegalizeCombiner {
125126
return true;
126127

127128
const TargetRegisterClass *RC = MRI.getRegClassOrNull(Reg);
128-
return RC && TRI.isSGPRClass(RC) && MRI.getType(Reg) == LLT::scalar(1);
129+
return RC && TRI.isSGPRClass(RC) && MRI.getType(Reg).isScalar(1);
129130
}
130131

131132
void cleanUpAfterCombine(MachineInstr &MI, MachineInstr *Optional0) {
@@ -156,13 +157,13 @@ class AMDGPURegBankLegalizeCombiner {
156157
// %Dst:lane-mask(s1) = G_AMDGPU_COPY_VCC_SCC %TruncS32Src:sgpr(s32)
157158
if (isLaneMask(Dst) && MRI.getRegBankOrNull(Src) == SgprRB) {
158159
auto [Trunc, TruncS32Src] = tryMatch(Src, AMDGPU::G_TRUNC);
159-
assert(Trunc && MRI.getType(TruncS32Src) == S32 &&
160+
assert(Trunc && MRI.getType(TruncS32Src) == I32 &&
160161
"sgpr S1 must be result of G_TRUNC of sgpr S32");
161162

162163
B.setInstr(MI);
163164
// Ensure that truncated bits in BoolSrc are 0.
164-
auto One = B.buildConstant({SgprRB, S32}, 1);
165-
auto BoolSrc = B.buildAnd({SgprRB, S32}, TruncS32Src, One);
165+
auto One = B.buildConstant({SgprRB, I32}, 1);
166+
auto BoolSrc = B.buildAnd({SgprRB, I32}, TruncS32Src, One);
166167
B.buildInstr(AMDGPU::G_AMDGPU_COPY_VCC_SCC, {Dst}, {BoolSrc});
167168
cleanUpAfterCombine(MI, Trunc);
168169
return;
@@ -192,7 +193,7 @@ class AMDGPURegBankLegalizeCombiner {
192193
// %Dst = G_... %TruncSrc
193194
Register Dst = MI.getOperand(0).getReg();
194195
Register Src = MI.getOperand(1).getReg();
195-
if (MRI.getType(Src) != S1)
196+
if (MRI.getType(Src) != I1)
196197
return;
197198

198199
auto [Trunc, TruncSrc] = tryMatch(Src, AMDGPU::G_TRUNC);
@@ -210,20 +211,20 @@ class AMDGPURegBankLegalizeCombiner {
210211

211212
B.setInstr(MI);
212213

213-
if (DstTy == S32 && TruncSrcTy == S64) {
214-
auto Unmerge = B.buildUnmerge({SgprRB, S32}, TruncSrc);
214+
if (DstTy == I32 && TruncSrcTy == I64) {
215+
auto Unmerge = B.buildUnmerge({SgprRB, I32}, TruncSrc);
215216
MRI.replaceRegWith(Dst, Unmerge.getReg(0));
216217
cleanUpAfterCombine(MI, Trunc);
217218
return;
218219
}
219220

220-
if (DstTy == S32 && TruncSrcTy == S16) {
221+
if (DstTy == I32 && TruncSrcTy == I16) {
221222
B.buildAnyExt(Dst, TruncSrc);
222223
cleanUpAfterCombine(MI, Trunc);
223224
return;
224225
}
225226

226-
if (DstTy == S16 && TruncSrcTy == S32) {
227+
if (DstTy == I16 && TruncSrcTy == I32) {
227228
B.buildTrunc(Dst, TruncSrc);
228229
cleanUpAfterCombine(MI, Trunc);
229230
return;
@@ -235,10 +236,9 @@ class AMDGPURegBankLegalizeCombiner {
235236

236237
// Search through MRI for virtual registers with sgpr register bank and S1 LLT.
237238
[[maybe_unused]] static Register getAnySgprS1(const MachineRegisterInfo &MRI) {
238-
const LLT S1 = LLT::scalar(1);
239239
for (unsigned i = 0; i < MRI.getNumVirtRegs(); ++i) {
240240
Register Reg = Register::index2VirtReg(i);
241-
if (MRI.def_empty(Reg) || MRI.getType(Reg) != S1)
241+
if (MRI.def_empty(Reg) || !MRI.getType(Reg).isScalar(1))
242242
continue;
243243

244244
const RegisterBank *RB = MRI.getRegBankOrNull(Reg);
@@ -306,7 +306,7 @@ bool AMDGPURegBankLegalize::runOnMachineFunction(MachineFunction &MF) {
306306
// Opcodes that support pretty much all combinations of reg banks and LLTs
307307
// (except S1). There is no point in writing rules for them.
308308
if (Opc == AMDGPU::G_BUILD_VECTOR || Opc == AMDGPU::G_UNMERGE_VALUES ||
309-
Opc == AMDGPU::G_MERGE_VALUES) {
309+
Opc == AMDGPU::G_MERGE_VALUES || Opc == AMDGPU::G_BITCAST) {
310310
RBLHelper.applyMappingTrivial(*MI);
311311
continue;
312312
}
@@ -316,7 +316,7 @@ bool AMDGPURegBankLegalize::runOnMachineFunction(MachineFunction &MF) {
316316
Opc == AMDGPU::G_IMPLICIT_DEF)) {
317317
Register Dst = MI->getOperand(0).getReg();
318318
// Non S1 types are trivially accepted.
319-
if (MRI.getType(Dst) != LLT::scalar(1)) {
319+
if (!MRI.getType(Dst).isScalar(1)) {
320320
assert(MRI.getRegBank(Dst)->getID() == AMDGPU::SGPRRegBankID);
321321
continue;
322322
}

llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.cpp

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ void RegBankLegalizeHelper::lower(MachineInstr &MI,
215215
LLT EltTy = DstTy.getElementType();
216216
B128 = LLT::fixed_vector(128 / EltTy.getSizeInBits(), EltTy);
217217
} else {
218-
B128 = LLT::scalar(128);
218+
B128 = LLT::integer(128);
219219
}
220220
if (Size / 128 == 2)
221221
splitLoad(MI, {B128, B128});
@@ -258,42 +258,42 @@ void RegBankLegalizeHelper::lower(MachineInstr &MI,
258258
// TODO: executeInWaterfallLoop(... WaterfallSgprs)
259259
}
260260

261-
LLT RegBankLegalizeHelper::getTyFromID(RegBankLLTMappingApplyID ID) {
261+
bool RegBankLegalizeHelper::isValidTyForID(LLT Ty, RegBankLLTMappingApplyID ID) {
262262
switch (ID) {
263263
case Vcc:
264264
case UniInVcc:
265-
return LLT::scalar(1);
265+
return Ty.isScalar(1);
266266
case Sgpr16:
267-
return LLT::scalar(16);
267+
return Ty.isScalar(16);
268268
case Sgpr32:
269269
case Sgpr32Trunc:
270270
case Sgpr32AExt:
271271
case Sgpr32AExtBoolInReg:
272272
case Sgpr32SExt:
273273
case UniInVgprS32:
274274
case Vgpr32:
275-
return LLT::scalar(32);
275+
return Ty.isScalar(32);
276276
case Sgpr64:
277277
case Vgpr64:
278-
return LLT::scalar(64);
278+
return Ty.isScalar(64);
279279
case SgprP1:
280280
case VgprP1:
281-
return LLT::pointer(1, 64);
281+
return Ty == LLT::pointer(1, 64);
282282
case SgprP3:
283283
case VgprP3:
284-
return LLT::pointer(3, 32);
284+
return Ty == LLT::pointer(3, 32);
285285
case SgprP4:
286286
case VgprP4:
287-
return LLT::pointer(4, 64);
287+
return Ty == LLT::pointer(4, 64);
288288
case SgprP5:
289289
case VgprP5:
290-
return LLT::pointer(5, 32);
290+
return Ty == LLT::pointer(5, 32);
291291
case SgprV4S32:
292292
case VgprV4S32:
293293
case UniInVgprV4S32:
294-
return LLT::fixed_vector(4, 32);
294+
return Ty.isFixedVector(4, 32);
295295
default:
296-
return LLT();
296+
return Ty == LLT();
297297
}
298298
}
299299

@@ -302,45 +302,45 @@ LLT RegBankLegalizeHelper::getBTyFromID(RegBankLLTMappingApplyID ID, LLT Ty) {
302302
case SgprB32:
303303
case VgprB32:
304304
case UniInVgprB32:
305-
if (Ty == LLT::scalar(32) || Ty == LLT::fixed_vector(2, 16) ||
305+
if (Ty.isScalar(32) || Ty.isFixedVector(2, 16) ||
306306
Ty == LLT::pointer(3, 32) || Ty == LLT::pointer(5, 32) ||
307307
Ty == LLT::pointer(6, 32))
308308
return Ty;
309309
return LLT();
310310
case SgprB64:
311311
case VgprB64:
312312
case UniInVgprB64:
313-
if (Ty == LLT::scalar(64) || Ty == LLT::fixed_vector(2, 32) ||
314-
Ty == LLT::fixed_vector(4, 16) || Ty == LLT::pointer(0, 64) ||
313+
if (Ty.isScalar(64) || Ty.isFixedVector(2, 32) ||
314+
Ty.isFixedVector(4, 16) || Ty == LLT::pointer(0, 64) ||
315315
Ty == LLT::pointer(1, 64) || Ty == LLT::pointer(4, 64))
316316
return Ty;
317317
return LLT();
318318
case SgprB96:
319319
case VgprB96:
320320
case UniInVgprB96:
321-
if (Ty == LLT::scalar(96) || Ty == LLT::fixed_vector(3, 32) ||
322-
Ty == LLT::fixed_vector(6, 16))
321+
if (Ty.isScalar(96) || Ty.isFixedVector(3, 32) ||
322+
Ty.isFixedVector(6, 16))
323323
return Ty;
324324
return LLT();
325325
case SgprB128:
326326
case VgprB128:
327327
case UniInVgprB128:
328-
if (Ty == LLT::scalar(128) || Ty == LLT::fixed_vector(4, 32) ||
329-
Ty == LLT::fixed_vector(2, 64))
328+
if (Ty.isScalar(128) || Ty.isFixedVector(4, 32) ||
329+
Ty.isFixedVector(2, 64))
330330
return Ty;
331331
return LLT();
332332
case SgprB256:
333333
case VgprB256:
334334
case UniInVgprB256:
335-
if (Ty == LLT::scalar(256) || Ty == LLT::fixed_vector(8, 32) ||
336-
Ty == LLT::fixed_vector(4, 64) || Ty == LLT::fixed_vector(16, 16))
335+
if (Ty.isScalar(256) || Ty.isFixedVector(8, 32) ||
336+
Ty.isFixedVector(4, 64) || Ty.isFixedVector(16, 16))
337337
return Ty;
338338
return LLT();
339339
case SgprB512:
340340
case VgprB512:
341341
case UniInVgprB512:
342-
if (Ty == LLT::scalar(512) || Ty == LLT::fixed_vector(16, 32) ||
343-
Ty == LLT::fixed_vector(8, 64))
342+
if (Ty.isScalar(512) || Ty.isFixedVector(16, 32) ||
343+
Ty.isFixedVector(8, 64))
344344
return Ty;
345345
return LLT();
346346
default:
@@ -430,7 +430,7 @@ void RegBankLegalizeHelper::applyMappingDst(
430430
case VgprP4:
431431
case VgprP5:
432432
case VgprV4S32: {
433-
assert(Ty == getTyFromID(MethodIDs[OpIdx]));
433+
assert(isValidTyForID(Ty, MethodIDs[OpIdx]));
434434
assert(RB == getRegBankFromID(MethodIDs[OpIdx]));
435435
break;
436436
}
@@ -464,7 +464,7 @@ void RegBankLegalizeHelper::applyMappingDst(
464464
}
465465
case UniInVgprS32:
466466
case UniInVgprV4S32: {
467-
assert(Ty == getTyFromID(MethodIDs[OpIdx]));
467+
assert(isValidTyForID(Ty, MethodIDs[OpIdx]));
468468
assert(RB == SgprRB);
469469
Register NewVgprDst = MRI.createVirtualRegister({VgprRB, Ty});
470470
Op.setReg(NewVgprDst);
@@ -537,7 +537,7 @@ void RegBankLegalizeHelper::applyMappingSrc(
537537
case SgprP4:
538538
case SgprP5:
539539
case SgprV4S32: {
540-
assert(Ty == getTyFromID(MethodIDs[i]));
540+
assert(isValidTyForID(Ty, MethodIDs[i]));
541541
assert(RB == getRegBankFromID(MethodIDs[i]));
542542
break;
543543
}
@@ -560,7 +560,7 @@ void RegBankLegalizeHelper::applyMappingSrc(
560560
case VgprP4:
561561
case VgprP5:
562562
case VgprV4S32: {
563-
assert(Ty == getTyFromID(MethodIDs[i]));
563+
assert(isValidTyForID(Ty, MethodIDs[i]));
564564
if (RB != VgprRB) {
565565
auto CopyToVgpr = B.buildCopy({VgprRB, Ty}, Reg);
566566
Op.setReg(CopyToVgpr.getReg(0));
@@ -619,7 +619,7 @@ void RegBankLegalizeHelper::applyMappingPHI(MachineInstr &MI) {
619619
Register Dst = MI.getOperand(0).getReg();
620620
LLT Ty = MRI.getType(Dst);
621621

622-
if (Ty == LLT::scalar(1) && MUI.isUniform(Dst)) {
622+
if (Ty.isScalar(1) && MUI.isUniform(Dst)) {
623623
B.setInsertPt(*MI.getParent(), MI.getParent()->getFirstNonPHI());
624624

625625
Register NewDst = MRI.createVirtualRegister(SgprRB_S32);
@@ -644,7 +644,7 @@ void RegBankLegalizeHelper::applyMappingPHI(MachineInstr &MI) {
644644
// ALL divergent i1 phis should be already lowered and inst-selected into PHI
645645
// with sgpr reg class and S1 LLT.
646646
// Note: this includes divergent phis that don't require lowering.
647-
if (Ty == LLT::scalar(1) && MUI.isDivergent(Dst)) {
647+
if (Ty.isScalar(1) && MUI.isDivergent(Dst)) {
648648
LLVM_DEBUG(dbgs() << "Divergent S1 G_PHI: "; MI.dump(););
649649
llvm_unreachable("Make sure to run AMDGPUGlobalISelDivergenceLowering "
650650
"before RegBankLegalize to lower lane mask(vcc) phis");
@@ -653,7 +653,7 @@ void RegBankLegalizeHelper::applyMappingPHI(MachineInstr &MI) {
653653
// We accept all types that can fit in some register class.
654654
// Uniform G_PHIs have all sgpr registers.
655655
// Divergent G_PHIs have vgpr dst but inputs can be sgpr or vgpr.
656-
if (Ty == LLT::scalar(32) || Ty == LLT::pointer(4, 64)) {
656+
if (Ty.isScalar(32) || Ty == LLT::pointer(4, 64)) {
657657
return;
658658
}
659659

llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ class RegBankLegalizeHelper {
8888
iterator_range<MachineBasicBlock::iterator> Range,
8989
SmallSet<Register, 4> &SgprOperandRegs);
9090

91-
LLT getTyFromID(RegBankLLTMappingApplyID ID);
91+
bool isValidTyForID(LLT Ty, RegBankLLTMappingApplyID ID);
9292
LLT getBTyFromID(RegBankLLTMappingApplyID ID, LLT Ty);
9393

9494
const RegisterBank *getRegBankFromID(RegBankLLTMappingApplyID ID);

0 commit comments

Comments
 (0)