Skip to content

Commit ef9a02c

Browse files
authored
[CodeGen] Use VirtRegOrUnit where appropriate (NFCI) (#167730)
Use it in `printVRegOrUnit()`, `getPressureSets()`/`PSetIterator`, and in functions/classes dealing with register pressure. Static type checking revealed several bugs, mainly in MachinePipeliner. I'm not very familiar with this pass, so I left a bunch of FIXMEs. There is one bug in `findUseBetween()` in RegisterPressure.cpp, also annotated with a FIXME.
1 parent a25daa3 commit ef9a02c

13 files changed

+313
-261
lines changed

llvm/include/llvm/CodeGen/MachineRegisterInfo.h

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -634,10 +634,9 @@ class MachineRegisterInfo {
634634
/// function. Writing to a constant register has no effect.
635635
LLVM_ABI bool isConstantPhysReg(MCRegister PhysReg) const;
636636

637-
/// Get an iterator over the pressure sets affected by the given physical or
638-
/// virtual register. If RegUnit is physical, it must be a register unit (from
639-
/// MCRegUnitIterator).
640-
PSetIterator getPressureSets(Register RegUnit) const;
637+
/// Get an iterator over the pressure sets affected by the virtual register
638+
/// or register unit.
639+
PSetIterator getPressureSets(VirtRegOrUnit VRegOrUnit) const;
641640

642641
//===--------------------------------------------------------------------===//
643642
// Virtual Register Info
@@ -1249,15 +1248,16 @@ class PSetIterator {
12491248
public:
12501249
PSetIterator() = default;
12511250

1252-
PSetIterator(Register RegUnit, const MachineRegisterInfo *MRI) {
1251+
PSetIterator(VirtRegOrUnit VRegOrUnit, const MachineRegisterInfo *MRI) {
12531252
const TargetRegisterInfo *TRI = MRI->getTargetRegisterInfo();
1254-
if (RegUnit.isVirtual()) {
1255-
const TargetRegisterClass *RC = MRI->getRegClass(RegUnit);
1253+
if (VRegOrUnit.isVirtualReg()) {
1254+
const TargetRegisterClass *RC =
1255+
MRI->getRegClass(VRegOrUnit.asVirtualReg());
12561256
PSet = TRI->getRegClassPressureSets(RC);
12571257
Weight = TRI->getRegClassWeight(RC).RegWeight;
12581258
} else {
1259-
PSet = TRI->getRegUnitPressureSets(RegUnit);
1260-
Weight = TRI->getRegUnitWeight(RegUnit);
1259+
PSet = TRI->getRegUnitPressureSets(VRegOrUnit.asMCRegUnit());
1260+
Weight = TRI->getRegUnitWeight(VRegOrUnit.asMCRegUnit());
12611261
}
12621262
if (*PSet == -1)
12631263
PSet = nullptr;
@@ -1278,8 +1278,8 @@ class PSetIterator {
12781278
};
12791279

12801280
inline PSetIterator
1281-
MachineRegisterInfo::getPressureSets(Register RegUnit) const {
1282-
return PSetIterator(RegUnit, this);
1281+
MachineRegisterInfo::getPressureSets(VirtRegOrUnit VRegOrUnit) const {
1282+
return PSetIterator(VRegOrUnit, this);
12831283
}
12841284

12851285
} // end namespace llvm

llvm/include/llvm/CodeGen/Register.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,10 @@ class VirtRegOrUnit {
206206
constexpr bool operator==(const VirtRegOrUnit &Other) const {
207207
return VRegOrUnit == Other.VRegOrUnit;
208208
}
209+
210+
constexpr bool operator<(const VirtRegOrUnit &Other) const {
211+
return VRegOrUnit < Other.VRegOrUnit;
212+
}
209213
};
210214

211215
} // namespace llvm

llvm/include/llvm/CodeGen/RegisterPressure.h

Lines changed: 28 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,11 @@ class MachineRegisterInfo;
3737
class RegisterClassInfo;
3838

3939
struct VRegMaskOrUnit {
40-
Register RegUnit; ///< Virtual register or register unit.
40+
VirtRegOrUnit VRegOrUnit;
4141
LaneBitmask LaneMask;
4242

43-
VRegMaskOrUnit(Register RegUnit, LaneBitmask LaneMask)
44-
: RegUnit(RegUnit), LaneMask(LaneMask) {}
43+
VRegMaskOrUnit(VirtRegOrUnit VRegOrUnit, LaneBitmask LaneMask)
44+
: VRegOrUnit(VRegOrUnit), LaneMask(LaneMask) {}
4545
};
4646

4747
/// Base class for register pressure results.
@@ -157,7 +157,7 @@ class PressureDiff {
157157
const_iterator begin() const { return &PressureChanges[0]; }
158158
const_iterator end() const { return &PressureChanges[MaxPSets]; }
159159

160-
LLVM_ABI void addPressureChange(Register RegUnit, bool IsDec,
160+
LLVM_ABI void addPressureChange(VirtRegOrUnit VRegOrUnit, bool IsDec,
161161
const MachineRegisterInfo *MRI);
162162

163163
LLVM_ABI void dump(const TargetRegisterInfo &TRI) const;
@@ -279,25 +279,25 @@ class LiveRegSet {
279279
RegSet Regs;
280280
unsigned NumRegUnits = 0u;
281281

282-
unsigned getSparseIndexFromReg(Register Reg) const {
283-
if (Reg.isVirtual())
284-
return Reg.virtRegIndex() + NumRegUnits;
285-
assert(Reg < NumRegUnits);
286-
return Reg.id();
282+
unsigned getSparseIndexFromVirtRegOrUnit(VirtRegOrUnit VRegOrUnit) const {
283+
if (VRegOrUnit.isVirtualReg())
284+
return VRegOrUnit.asVirtualReg().virtRegIndex() + NumRegUnits;
285+
assert(VRegOrUnit.asMCRegUnit() < NumRegUnits);
286+
return VRegOrUnit.asMCRegUnit();
287287
}
288288

289-
Register getRegFromSparseIndex(unsigned SparseIndex) const {
289+
VirtRegOrUnit getVirtRegOrUnitFromSparseIndex(unsigned SparseIndex) const {
290290
if (SparseIndex >= NumRegUnits)
291-
return Register::index2VirtReg(SparseIndex - NumRegUnits);
292-
return Register(SparseIndex);
291+
return VirtRegOrUnit(Register::index2VirtReg(SparseIndex - NumRegUnits));
292+
return VirtRegOrUnit(SparseIndex);
293293
}
294294

295295
public:
296296
LLVM_ABI void clear();
297297
LLVM_ABI void init(const MachineRegisterInfo &MRI);
298298

299-
LaneBitmask contains(Register Reg) const {
300-
unsigned SparseIndex = getSparseIndexFromReg(Reg);
299+
LaneBitmask contains(VirtRegOrUnit VRegOrUnit) const {
300+
unsigned SparseIndex = getSparseIndexFromVirtRegOrUnit(VRegOrUnit);
301301
RegSet::const_iterator I = Regs.find(SparseIndex);
302302
if (I == Regs.end())
303303
return LaneBitmask::getNone();
@@ -307,7 +307,7 @@ class LiveRegSet {
307307
/// Mark the \p Pair.LaneMask lanes of \p Pair.Reg as live.
308308
/// Returns the previously live lanes of \p Pair.Reg.
309309
LaneBitmask insert(VRegMaskOrUnit Pair) {
310-
unsigned SparseIndex = getSparseIndexFromReg(Pair.RegUnit);
310+
unsigned SparseIndex = getSparseIndexFromVirtRegOrUnit(Pair.VRegOrUnit);
311311
auto InsertRes = Regs.insert(IndexMaskPair(SparseIndex, Pair.LaneMask));
312312
if (!InsertRes.second) {
313313
LaneBitmask PrevMask = InsertRes.first->LaneMask;
@@ -320,7 +320,7 @@ class LiveRegSet {
320320
/// Clears the \p Pair.LaneMask lanes of \p Pair.Reg (mark them as dead).
321321
/// Returns the previously live lanes of \p Pair.Reg.
322322
LaneBitmask erase(VRegMaskOrUnit Pair) {
323-
unsigned SparseIndex = getSparseIndexFromReg(Pair.RegUnit);
323+
unsigned SparseIndex = getSparseIndexFromVirtRegOrUnit(Pair.VRegOrUnit);
324324
RegSet::iterator I = Regs.find(SparseIndex);
325325
if (I == Regs.end())
326326
return LaneBitmask::getNone();
@@ -335,9 +335,9 @@ class LiveRegSet {
335335

336336
void appendTo(SmallVectorImpl<VRegMaskOrUnit> &To) const {
337337
for (const IndexMaskPair &P : Regs) {
338-
Register Reg = getRegFromSparseIndex(P.Index);
338+
VirtRegOrUnit VRegOrUnit = getVirtRegOrUnitFromSparseIndex(P.Index);
339339
if (P.LaneMask.any())
340-
To.emplace_back(Reg, P.LaneMask);
340+
To.emplace_back(VRegOrUnit, P.LaneMask);
341341
}
342342
}
343343
};
@@ -541,9 +541,11 @@ class RegPressureTracker {
541541

542542
LLVM_ABI void dump() const;
543543

544-
LLVM_ABI void increaseRegPressure(Register RegUnit, LaneBitmask PreviousMask,
544+
LLVM_ABI void increaseRegPressure(VirtRegOrUnit VRegOrUnit,
545+
LaneBitmask PreviousMask,
545546
LaneBitmask NewMask);
546-
LLVM_ABI void decreaseRegPressure(Register RegUnit, LaneBitmask PreviousMask,
547+
LLVM_ABI void decreaseRegPressure(VirtRegOrUnit VRegOrUnit,
548+
LaneBitmask PreviousMask,
547549
LaneBitmask NewMask);
548550

549551
protected:
@@ -565,9 +567,12 @@ class RegPressureTracker {
565567
discoverLiveInOrOut(VRegMaskOrUnit Pair,
566568
SmallVectorImpl<VRegMaskOrUnit> &LiveInOrOut);
567569

568-
LLVM_ABI LaneBitmask getLastUsedLanes(Register RegUnit, SlotIndex Pos) const;
569-
LLVM_ABI LaneBitmask getLiveLanesAt(Register RegUnit, SlotIndex Pos) const;
570-
LLVM_ABI LaneBitmask getLiveThroughAt(Register RegUnit, SlotIndex Pos) const;
570+
LLVM_ABI LaneBitmask getLastUsedLanes(VirtRegOrUnit VRegOrUnit,
571+
SlotIndex Pos) const;
572+
LLVM_ABI LaneBitmask getLiveLanesAt(VirtRegOrUnit VRegOrUnit,
573+
SlotIndex Pos) const;
574+
LLVM_ABI LaneBitmask getLiveThroughAt(VirtRegOrUnit VRegOrUnit,
575+
SlotIndex Pos) const;
571576
};
572577

573578
LLVM_ABI void dumpRegSetPressure(ArrayRef<unsigned> SetPressure,

llvm/include/llvm/CodeGen/TargetRegisterInfo.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1450,7 +1450,7 @@ LLVM_ABI Printable printRegUnit(MCRegUnit Unit, const TargetRegisterInfo *TRI);
14501450

14511451
/// Create Printable object to print virtual registers and physical
14521452
/// registers on a \ref raw_ostream.
1453-
LLVM_ABI Printable printVRegOrUnit(unsigned VRegOrUnit,
1453+
LLVM_ABI Printable printVRegOrUnit(VirtRegOrUnit VRegOrUnit,
14541454
const TargetRegisterInfo *TRI);
14551455

14561456
/// Create Printable object to print register classes or register banks

llvm/lib/CodeGen/MachinePipeliner.cpp

Lines changed: 48 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1509,7 +1509,11 @@ class HighRegisterPressureDetector {
15091509

15101510
void dumpPSet(Register Reg) const {
15111511
dbgs() << "Reg=" << printReg(Reg, TRI, 0, &MRI) << " PSet=";
1512-
for (auto PSetIter = MRI.getPressureSets(Reg); PSetIter.isValid();
1512+
// FIXME: The static_cast is a bug compensating bugs in the callers.
1513+
VirtRegOrUnit VRegOrUnit =
1514+
Reg.isVirtual() ? VirtRegOrUnit(Reg)
1515+
: VirtRegOrUnit(static_cast<MCRegUnit>(Reg.id()));
1516+
for (auto PSetIter = MRI.getPressureSets(VRegOrUnit); PSetIter.isValid();
15131517
++PSetIter) {
15141518
dbgs() << *PSetIter << ' ';
15151519
}
@@ -1518,15 +1522,19 @@ class HighRegisterPressureDetector {
15181522

15191523
void increaseRegisterPressure(std::vector<unsigned> &Pressure,
15201524
Register Reg) const {
1521-
auto PSetIter = MRI.getPressureSets(Reg);
1525+
// FIXME: The static_cast is a bug compensating bugs in the callers.
1526+
VirtRegOrUnit VRegOrUnit =
1527+
Reg.isVirtual() ? VirtRegOrUnit(Reg)
1528+
: VirtRegOrUnit(static_cast<MCRegUnit>(Reg.id()));
1529+
auto PSetIter = MRI.getPressureSets(VRegOrUnit);
15221530
unsigned Weight = PSetIter.getWeight();
15231531
for (; PSetIter.isValid(); ++PSetIter)
15241532
Pressure[*PSetIter] += Weight;
15251533
}
15261534

15271535
void decreaseRegisterPressure(std::vector<unsigned> &Pressure,
15281536
Register Reg) const {
1529-
auto PSetIter = MRI.getPressureSets(Reg);
1537+
auto PSetIter = MRI.getPressureSets(VirtRegOrUnit(Reg));
15301538
unsigned Weight = PSetIter.getWeight();
15311539
for (; PSetIter.isValid(); ++PSetIter) {
15321540
auto &P = Pressure[*PSetIter];
@@ -1559,7 +1567,11 @@ class HighRegisterPressureDetector {
15591567
if (MI.isDebugInstr())
15601568
continue;
15611569
for (auto &Use : ROMap[&MI].Uses) {
1562-
auto Reg = Use.RegUnit;
1570+
// FIXME: The static_cast is a bug.
1571+
Register Reg =
1572+
Use.VRegOrUnit.isVirtualReg()
1573+
? Use.VRegOrUnit.asVirtualReg()
1574+
: Register(static_cast<unsigned>(Use.VRegOrUnit.asMCRegUnit()));
15631575
// Ignore the variable that appears only on one side of phi instruction
15641576
// because it's used only at the first iteration.
15651577
if (MI.isPHI() && Reg != getLoopPhiReg(MI, OrigMBB))
@@ -1609,8 +1621,14 @@ class HighRegisterPressureDetector {
16091621
Register Reg = getLoopPhiReg(*MI, OrigMBB);
16101622
UpdateTargetRegs(Reg);
16111623
} else {
1612-
for (auto &Use : ROMap.find(MI)->getSecond().Uses)
1613-
UpdateTargetRegs(Use.RegUnit);
1624+
for (auto &Use : ROMap.find(MI)->getSecond().Uses) {
1625+
// FIXME: The static_cast is a bug.
1626+
Register Reg = Use.VRegOrUnit.isVirtualReg()
1627+
? Use.VRegOrUnit.asVirtualReg()
1628+
: Register(static_cast<unsigned>(
1629+
Use.VRegOrUnit.asMCRegUnit()));
1630+
UpdateTargetRegs(Reg);
1631+
}
16141632
}
16151633
}
16161634

@@ -1621,7 +1639,11 @@ class HighRegisterPressureDetector {
16211639
DenseMap<Register, MachineInstr *> LastUseMI;
16221640
for (MachineInstr *MI : llvm::reverse(OrderedInsts)) {
16231641
for (auto &Use : ROMap.find(MI)->getSecond().Uses) {
1624-
auto Reg = Use.RegUnit;
1642+
// FIXME: The static_cast is a bug.
1643+
Register Reg =
1644+
Use.VRegOrUnit.isVirtualReg()
1645+
? Use.VRegOrUnit.asVirtualReg()
1646+
: Register(static_cast<unsigned>(Use.VRegOrUnit.asMCRegUnit()));
16251647
if (!TargetRegs.contains(Reg))
16261648
continue;
16271649
auto [Ite, Inserted] = LastUseMI.try_emplace(Reg, MI);
@@ -1635,8 +1657,8 @@ class HighRegisterPressureDetector {
16351657
}
16361658

16371659
Instr2LastUsesTy LastUses;
1638-
for (auto &Entry : LastUseMI)
1639-
LastUses[Entry.second].insert(Entry.first);
1660+
for (auto [Reg, MI] : LastUseMI)
1661+
LastUses[MI].insert(Reg);
16401662
return LastUses;
16411663
}
16421664

@@ -1675,7 +1697,12 @@ class HighRegisterPressureDetector {
16751697
});
16761698

16771699
const auto InsertReg = [this, &CurSetPressure](RegSetTy &RegSet,
1678-
Register Reg) {
1700+
VirtRegOrUnit VRegOrUnit) {
1701+
// FIXME: The static_cast is a bug.
1702+
Register Reg =
1703+
VRegOrUnit.isVirtualReg()
1704+
? VRegOrUnit.asVirtualReg()
1705+
: Register(static_cast<unsigned>(VRegOrUnit.asMCRegUnit()));
16791706
if (!Reg.isValid() || isReservedRegister(Reg))
16801707
return;
16811708

@@ -1712,7 +1739,7 @@ class HighRegisterPressureDetector {
17121739
const unsigned Iter = I - Stage;
17131740

17141741
for (auto &Def : ROMap.find(MI)->getSecond().Defs)
1715-
InsertReg(LiveRegSets[Iter], Def.RegUnit);
1742+
InsertReg(LiveRegSets[Iter], Def.VRegOrUnit);
17161743

17171744
for (auto LastUse : LastUses[MI]) {
17181745
if (MI->isPHI()) {
@@ -2235,30 +2262,33 @@ static void computeLiveOuts(MachineFunction &MF, RegPressureTracker &RPTracker,
22352262
const TargetRegisterInfo *TRI = MF.getSubtarget().getRegisterInfo();
22362263
MachineRegisterInfo &MRI = MF.getRegInfo();
22372264
SmallVector<VRegMaskOrUnit, 8> LiveOutRegs;
2238-
SmallSet<Register, 4> Uses;
2265+
SmallSet<VirtRegOrUnit, 4> Uses;
22392266
for (SUnit *SU : NS) {
22402267
const MachineInstr *MI = SU->getInstr();
22412268
if (MI->isPHI())
22422269
continue;
22432270
for (const MachineOperand &MO : MI->all_uses()) {
22442271
Register Reg = MO.getReg();
22452272
if (Reg.isVirtual())
2246-
Uses.insert(Reg);
2273+
Uses.insert(VirtRegOrUnit(Reg));
22472274
else if (MRI.isAllocatable(Reg))
2248-
Uses.insert_range(TRI->regunits(Reg.asMCReg()));
2275+
for (MCRegUnit Unit : TRI->regunits(Reg.asMCReg()))
2276+
Uses.insert(VirtRegOrUnit(Unit));
22492277
}
22502278
}
22512279
for (SUnit *SU : NS)
22522280
for (const MachineOperand &MO : SU->getInstr()->all_defs())
22532281
if (!MO.isDead()) {
22542282
Register Reg = MO.getReg();
22552283
if (Reg.isVirtual()) {
2256-
if (!Uses.count(Reg))
2257-
LiveOutRegs.emplace_back(Reg, LaneBitmask::getNone());
2284+
if (!Uses.count(VirtRegOrUnit(Reg)))
2285+
LiveOutRegs.emplace_back(VirtRegOrUnit(Reg),
2286+
LaneBitmask::getNone());
22582287
} else if (MRI.isAllocatable(Reg)) {
22592288
for (MCRegUnit Unit : TRI->regunits(Reg.asMCReg()))
2260-
if (!Uses.count(Unit))
2261-
LiveOutRegs.emplace_back(Unit, LaneBitmask::getNone());
2289+
if (!Uses.count(VirtRegOrUnit(Unit)))
2290+
LiveOutRegs.emplace_back(VirtRegOrUnit(Unit),
2291+
LaneBitmask::getNone());
22622292
}
22632293
}
22642294
RPTracker.addLiveRegs(LiveOutRegs);

llvm/lib/CodeGen/MachineScheduler.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1580,10 +1580,10 @@ updateScheduledPressure(const SUnit *SU,
15801580
/// instruction.
15811581
void ScheduleDAGMILive::updatePressureDiffs(ArrayRef<VRegMaskOrUnit> LiveUses) {
15821582
for (const VRegMaskOrUnit &P : LiveUses) {
1583-
Register Reg = P.RegUnit;
15841583
/// FIXME: Currently assuming single-use physregs.
1585-
if (!Reg.isVirtual())
1584+
if (!P.VRegOrUnit.isVirtualReg())
15861585
continue;
1586+
Register Reg = P.VRegOrUnit.asVirtualReg();
15871587

15881588
if (ShouldTrackLaneMasks) {
15891589
// If the register has just become live then other uses won't change
@@ -1599,7 +1599,7 @@ void ScheduleDAGMILive::updatePressureDiffs(ArrayRef<VRegMaskOrUnit> LiveUses) {
15991599
continue;
16001600

16011601
PressureDiff &PDiff = getPressureDiff(&SU);
1602-
PDiff.addPressureChange(Reg, Decrement, &MRI);
1602+
PDiff.addPressureChange(VirtRegOrUnit(Reg), Decrement, &MRI);
16031603
if (llvm::any_of(PDiff, [](const PressureChange &Change) {
16041604
return Change.isValid();
16051605
}))
@@ -1611,7 +1611,7 @@ void ScheduleDAGMILive::updatePressureDiffs(ArrayRef<VRegMaskOrUnit> LiveUses) {
16111611
}
16121612
} else {
16131613
assert(P.LaneMask.any());
1614-
LLVM_DEBUG(dbgs() << " LiveReg: " << printVRegOrUnit(Reg, TRI) << "\n");
1614+
LLVM_DEBUG(dbgs() << " LiveReg: " << printReg(Reg, TRI) << "\n");
16151615
// This may be called before CurrentBottom has been initialized. However,
16161616
// BotRPTracker must have a valid position. We want the value live into the
16171617
// instruction or live out of the block, so ask for the previous
@@ -1638,7 +1638,7 @@ void ScheduleDAGMILive::updatePressureDiffs(ArrayRef<VRegMaskOrUnit> LiveUses) {
16381638
LI.Query(LIS->getInstructionIndex(*SU->getInstr()));
16391639
if (LRQ.valueIn() == VNI) {
16401640
PressureDiff &PDiff = getPressureDiff(SU);
1641-
PDiff.addPressureChange(Reg, true, &MRI);
1641+
PDiff.addPressureChange(VirtRegOrUnit(Reg), true, &MRI);
16421642
if (llvm::any_of(PDiff, [](const PressureChange &Change) {
16431643
return Change.isValid();
16441644
}))
@@ -1814,9 +1814,9 @@ unsigned ScheduleDAGMILive::computeCyclicCriticalPath() {
18141814
unsigned MaxCyclicLatency = 0;
18151815
// Visit each live out vreg def to find def/use pairs that cross iterations.
18161816
for (const VRegMaskOrUnit &P : RPTracker.getPressure().LiveOutRegs) {
1817-
Register Reg = P.RegUnit;
1818-
if (!Reg.isVirtual())
1817+
if (!P.VRegOrUnit.isVirtualReg())
18191818
continue;
1819+
Register Reg = P.VRegOrUnit.asVirtualReg();
18201820
const LiveInterval &LI = LIS->getInterval(Reg);
18211821
const VNInfo *DefVNI = LI.getVNInfoBefore(LIS->getMBBEndIdx(BB));
18221822
if (!DefVNI)

0 commit comments

Comments
 (0)