Skip to content

Commit 7108b12

Browse files
authored
[RDF] RegisterRef/RegisterId improvements. NFC (#168030)
RegisterId can represent a physical register, a MCRegUnit, or an index into a side structure that stores register masks. These 3 types were encoded by using the physical reg, stack slot, and virtual register encoding partitions from the Register class. This encoding scheme alias wasn't well contained so Register::index2StackSlot and Register::stackSlotIndex appeared in multiple places. This patch gives RegisterRef its own encoding defines and separates it from Register. I've removed the generic idx() method in favor of getAsMCReg(), getAsMCRegUnit(), and getMaskIdx() for some degree of type safety. Some places used the RegisterId field of RegisterRef directly as a register. Those have been updated to use getAsMCReg. Some special cases for RegisterId 0 have been removed as it can be treated like a MCRegister by existing code. I think I want to rename the Reg field of RegisterRef to Id, but I'll do that in another patch. Additionally, callers of the RegisterRef constructor need to be audited for implicit conversions from Register/MCRegister to unsigned.
1 parent 8b105cb commit 7108b12

File tree

3 files changed

+50
-46
lines changed

3 files changed

+50
-46
lines changed

llvm/include/llvm/CodeGen/RDFRegisters.h

Lines changed: 29 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,11 @@ template <typename T, unsigned N = 32> struct IndexedSet {
8686
};
8787

8888
struct RegisterRef {
89+
private:
90+
static constexpr RegisterId MaskFlag = 1u << 30;
91+
static constexpr RegisterId UnitFlag = 1u << 31;
92+
93+
public:
8994
RegisterId Reg = 0;
9095
LaneBitmask Mask = LaneBitmask::getNone(); // Only for registers.
9196

@@ -99,7 +104,20 @@ struct RegisterRef {
99104
constexpr bool isUnit() const { return isUnitId(Reg); }
100105
constexpr bool isMask() const { return isMaskId(Reg); }
101106

102-
constexpr unsigned idx() const { return toIdx(Reg); }
107+
constexpr MCRegister asMCReg() const {
108+
assert(isReg());
109+
return Reg;
110+
}
111+
112+
constexpr MCRegUnit asMCRegUnit() const {
113+
assert(isUnit());
114+
return Reg & ~UnitFlag;
115+
}
116+
117+
constexpr unsigned asMaskIdx() const {
118+
assert(isMask());
119+
return Reg & ~MaskFlag;
120+
}
103121

104122
constexpr operator bool() const {
105123
return !isReg() || (Reg != 0 && Mask.any());
@@ -110,26 +128,15 @@ struct RegisterRef {
110128
std::hash<LaneBitmask::Type>{}(Mask.getAsInteger());
111129
}
112130

113-
static constexpr bool isRegId(unsigned Id) {
114-
return Register::isPhysicalRegister(Id);
115-
}
116-
static constexpr bool isUnitId(unsigned Id) {
117-
return Register::isVirtualRegister(Id);
131+
static constexpr bool isRegId(RegisterId Id) {
132+
return !(Id & UnitFlag) && !(Id & MaskFlag);
118133
}
119-
static constexpr bool isMaskId(unsigned Id) { return Register(Id).isStack(); }
134+
static constexpr bool isUnitId(RegisterId Id) { return Id & UnitFlag; }
135+
static constexpr bool isMaskId(RegisterId Id) { return Id & MaskFlag; }
120136

121-
static constexpr RegisterId toUnitId(unsigned Idx) {
122-
return Idx | Register::VirtualRegFlag;
123-
}
137+
static constexpr RegisterId toUnitId(unsigned Idx) { return Idx | UnitFlag; }
124138

125-
static constexpr unsigned toIdx(RegisterId Id) {
126-
// Not using virtReg2Index or stackSlot2Index, because they are
127-
// not constexpr.
128-
if (isUnitId(Id))
129-
return Id & ~Register::VirtualRegFlag;
130-
// RegId and MaskId are unchanged.
131-
return Id;
132-
}
139+
static constexpr RegisterId toMaskId(unsigned Idx) { return Idx | MaskFlag; }
133140

134141
bool operator<(RegisterRef) const = delete;
135142
bool operator==(RegisterRef) const = delete;
@@ -141,11 +148,11 @@ struct PhysicalRegisterInfo {
141148
const MachineFunction &mf);
142149

143150
RegisterId getRegMaskId(const uint32_t *RM) const {
144-
return Register::index2StackSlot(RegMasks.find(RM));
151+
return RegisterRef::toMaskId(RegMasks.find(RM));
145152
}
146153

147154
const uint32_t *getRegMaskBits(RegisterId R) const {
148-
return RegMasks.get(Register(R).stackSlotIndex());
155+
return RegMasks.get(RegisterRef(R).asMaskIdx());
149156
}
150157

151158
bool alias(RegisterRef RA, RegisterRef RB) const;
@@ -158,7 +165,7 @@ struct PhysicalRegisterInfo {
158165
}
159166

160167
const BitVector &getMaskUnits(RegisterId MaskId) const {
161-
return MaskInfos[Register(MaskId).stackSlotIndex()].Units;
168+
return MaskInfos[RegisterRef(MaskId).asMaskIdx()].Units;
162169
}
163170

164171
std::set<RegisterId> getUnits(RegisterRef RR) const;
@@ -167,7 +174,7 @@ struct PhysicalRegisterInfo {
167174
return AliasInfos[U].Regs;
168175
}
169176

170-
RegisterRef mapTo(RegisterRef RR, unsigned R) const;
177+
RegisterRef mapTo(RegisterRef RR, RegisterId R) const;
171178
const TargetRegisterInfo &getTRI() const { return TRI; }
172179

173180
bool equal_to(RegisterRef A, RegisterRef B) const;

llvm/lib/CodeGen/RDFGraph.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1827,7 +1827,7 @@ bool DataFlowGraph::hasUntrackedRef(Stmt S, bool IgnoreReserved) const {
18271827
for (Ref R : S.Addr->members(*this)) {
18281828
Ops.push_back(&R.Addr->getOp());
18291829
RegisterRef RR = R.Addr->getRegRef(*this);
1830-
if (IgnoreReserved && RR.isReg() && ReservedRegs[RR.idx()])
1830+
if (IgnoreReserved && RR.isReg() && ReservedRegs[RR.asMCReg().id()])
18311831
continue;
18321832
if (!isTracked(RR))
18331833
return true;

llvm/lib/CodeGen/RDFRegisters.cpp

Lines changed: 20 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -126,13 +126,10 @@ std::set<RegisterId> PhysicalRegisterInfo::getAliasSet(RegisterId Reg) const {
126126
std::set<RegisterId> PhysicalRegisterInfo::getUnits(RegisterRef RR) const {
127127
std::set<RegisterId> Units;
128128

129-
if (RR.Reg == 0)
130-
return Units; // Empty
131-
132129
if (RR.isReg()) {
133130
if (RR.Mask.none())
134131
return Units; // Empty
135-
for (MCRegUnitMaskIterator UM(RR.idx(), &TRI); UM.isValid(); ++UM) {
132+
for (MCRegUnitMaskIterator UM(RR.asMCReg(), &TRI); UM.isValid(); ++UM) {
136133
auto [U, M] = *UM;
137134
if ((M & RR.Mask).any())
138135
Units.insert(U);
@@ -142,7 +139,7 @@ std::set<RegisterId> PhysicalRegisterInfo::getUnits(RegisterRef RR) const {
142139

143140
assert(RR.isMask());
144141
unsigned NumRegs = TRI.getNumRegs();
145-
const uint32_t *MB = getRegMaskBits(RR.idx());
142+
const uint32_t *MB = getRegMaskBits(RR.Reg);
146143
for (unsigned I = 0, E = (NumRegs + 31) / 32; I != E; ++I) {
147144
uint32_t C = ~MB[I]; // Clobbered regs
148145
if (I == 0) // Reg 0 should be ignored
@@ -162,12 +159,13 @@ std::set<RegisterId> PhysicalRegisterInfo::getUnits(RegisterRef RR) const {
162159
return Units;
163160
}
164161

165-
RegisterRef PhysicalRegisterInfo::mapTo(RegisterRef RR, unsigned R) const {
162+
RegisterRef PhysicalRegisterInfo::mapTo(RegisterRef RR, RegisterId R) const {
166163
if (RR.Reg == R)
167164
return RR;
168-
if (unsigned Idx = TRI.getSubRegIndex(R, RR.Reg))
165+
if (unsigned Idx = TRI.getSubRegIndex(RegisterRef(R).asMCReg(), RR.asMCReg()))
169166
return RegisterRef(R, TRI.composeSubRegIndexLaneMask(Idx, RR.Mask));
170-
if (unsigned Idx = TRI.getSubRegIndex(RR.Reg, R)) {
167+
if (unsigned Idx =
168+
TRI.getSubRegIndex(RR.asMCReg(), RegisterRef(R).asMCReg())) {
171169
const RegInfo &RI = RegInfos[R];
172170
LaneBitmask RCM =
173171
RI.RegClass ? RI.RegClass->LaneMask : LaneBitmask::getAll();
@@ -187,8 +185,8 @@ bool PhysicalRegisterInfo::equal_to(RegisterRef A, RegisterRef B) const {
187185
return A.Mask == B.Mask;
188186

189187
// Compare reg units lexicographically.
190-
MCRegUnitMaskIterator AI(A.Reg, &getTRI());
191-
MCRegUnitMaskIterator BI(B.Reg, &getTRI());
188+
MCRegUnitMaskIterator AI(A.asMCReg(), &getTRI());
189+
MCRegUnitMaskIterator BI(B.asMCReg(), &getTRI());
192190
while (AI.isValid() && BI.isValid()) {
193191
auto [AReg, AMask] = *AI;
194192
auto [BReg, BMask] = *BI;
@@ -225,8 +223,8 @@ bool PhysicalRegisterInfo::less(RegisterRef A, RegisterRef B) const {
225223
return A.Reg < B.Reg;
226224

227225
// Compare reg units lexicographically.
228-
llvm::MCRegUnitMaskIterator AI(A.Reg, &getTRI());
229-
llvm::MCRegUnitMaskIterator BI(B.Reg, &getTRI());
226+
llvm::MCRegUnitMaskIterator AI(A.asMCReg(), &getTRI());
227+
llvm::MCRegUnitMaskIterator BI(B.asMCReg(), &getTRI());
230228
while (AI.isValid() && BI.isValid()) {
231229
auto [AReg, AMask] = *AI;
232230
auto [BReg, BMask] = *BI;
@@ -252,18 +250,17 @@ bool PhysicalRegisterInfo::less(RegisterRef A, RegisterRef B) const {
252250
}
253251

254252
void PhysicalRegisterInfo::print(raw_ostream &OS, RegisterRef A) const {
255-
if (A.Reg == 0 || A.isReg()) {
256-
if (0 < A.idx() && A.idx() < TRI.getNumRegs())
257-
OS << TRI.getName(A.idx());
253+
if (A.isReg()) {
254+
MCRegister Reg = A.asMCReg();
255+
if (Reg && Reg.id() < TRI.getNumRegs())
256+
OS << TRI.getName(Reg);
258257
else
259-
OS << printReg(A.idx(), &TRI);
258+
OS << printReg(Reg, &TRI);
260259
OS << PrintLaneMaskShort(A.Mask);
261260
} else if (A.isUnit()) {
262-
OS << printRegUnit(A.idx(), &TRI);
261+
OS << printRegUnit(A.asMCRegUnit(), &TRI);
263262
} else {
264-
assert(A.isMask());
265-
// RegMask SS flag is preserved by idx().
266-
unsigned Idx = Register(A.idx()).stackSlotIndex();
263+
unsigned Idx = A.asMaskIdx();
267264
const char *Fmt = Idx < 0x10000 ? "%04x" : "%08x";
268265
OS << "M#" << format(Fmt, Idx);
269266
}
@@ -280,7 +277,7 @@ bool RegisterAggr::hasAliasOf(RegisterRef RR) const {
280277
if (RR.isMask())
281278
return Units.anyCommon(PRI.getMaskUnits(RR.Reg));
282279

283-
for (MCRegUnitMaskIterator U(RR.Reg, &PRI.getTRI()); U.isValid(); ++U) {
280+
for (MCRegUnitMaskIterator U(RR.asMCReg(), &PRI.getTRI()); U.isValid(); ++U) {
284281
auto [Unit, LaneMask] = *U;
285282
if ((LaneMask & RR.Mask).any())
286283
if (Units.test(Unit))
@@ -295,7 +292,7 @@ bool RegisterAggr::hasCoverOf(RegisterRef RR) const {
295292
return T.reset(Units).none();
296293
}
297294

298-
for (MCRegUnitMaskIterator U(RR.Reg, &PRI.getTRI()); U.isValid(); ++U) {
295+
for (MCRegUnitMaskIterator U(RR.asMCReg(), &PRI.getTRI()); U.isValid(); ++U) {
299296
auto [Unit, LaneMask] = *U;
300297
if ((LaneMask & RR.Mask).any())
301298
if (!Units.test(Unit))
@@ -310,7 +307,7 @@ RegisterAggr &RegisterAggr::insert(RegisterRef RR) {
310307
return *this;
311308
}
312309

313-
for (MCRegUnitMaskIterator U(RR.Reg, &PRI.getTRI()); U.isValid(); ++U) {
310+
for (MCRegUnitMaskIterator U(RR.asMCReg(), &PRI.getTRI()); U.isValid(); ++U) {
314311
auto [Unit, LaneMask] = *U;
315312
if ((LaneMask & RR.Mask).any())
316313
Units.set(Unit);

0 commit comments

Comments
 (0)