Skip to content

Commit e812231

Browse files
committed
[RDF] RegisterRef/RegisterId improvements. NFC
RegisterId can represent a physical register, a MCRegUnit, or an index into a side structure that stores register masks. These 3 types were encoded by using the physical reg, stack slot, and virtual register encoding partitions from the Register class. This encoding scheme alias wasn't well contained so Register::index2StackSlot and Register::stackSlotIndex appeared in multiple places. This patch gives RegisterRef its own encoding defines and separates it from Register. I've removed the generic idx() method in favor of getAsMCReg(), getAsMCRegUnit(), and getMaskIdx() for some degree of type safety. Some places used the RegisterId field of RegisterRef directly as a register. Those have been updated to use getAsMCReg. Some special cases for RegisterId 0 have been removed as it can be treated like a MCRegister by existing code. I think I want to rename the Reg field of RegisterRef to Id, but I'll do that in another patch. Additionally, callers of the RegisterRef constructor need to be audited for implicit conversions from Register/MCRegister to unsigned.
1 parent 80ae168 commit e812231

File tree

3 files changed

+53
-45
lines changed

3 files changed

+53
-45
lines changed

llvm/include/llvm/CodeGen/RDFRegisters.h

Lines changed: 33 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,11 @@ template <typename T, unsigned N = 32> struct IndexedSet {
8686
};
8787

8888
struct RegisterRef {
89+
private:
90+
static constexpr unsigned MaskFlag = 1u << 30;
91+
static constexpr unsigned UnitFlag = 1u << 31;
92+
93+
public:
8994
RegisterId Reg = 0;
9095
LaneBitmask Mask = LaneBitmask::getNone(); // Only for registers.
9196

@@ -99,7 +104,20 @@ struct RegisterRef {
99104
constexpr bool isUnit() const { return isUnitId(Reg); }
100105
constexpr bool isMask() const { return isMaskId(Reg); }
101106

102-
constexpr unsigned idx() const { return toIdx(Reg); }
107+
constexpr MCRegister asMCReg() const {
108+
assert(isReg());
109+
return Reg;
110+
}
111+
112+
constexpr MCRegUnit asMCRegUnit() const {
113+
assert(isUnit());
114+
return Reg & ~UnitFlag;
115+
}
116+
117+
constexpr unsigned getMaskIdx() const {
118+
assert(isMask());
119+
return toMaskIdx(Reg);
120+
}
103121

104122
constexpr operator bool() const {
105123
return !isReg() || (Reg != 0 && Mask.any());
@@ -110,25 +128,19 @@ struct RegisterRef {
110128
std::hash<LaneBitmask::Type>{}(Mask.getAsInteger());
111129
}
112130

113-
static constexpr bool isRegId(unsigned Id) {
114-
return Register::isPhysicalRegister(Id);
115-
}
116-
static constexpr bool isUnitId(unsigned Id) {
117-
return Register::isVirtualRegister(Id);
131+
static constexpr bool isRegId(RegisterId Id) {
132+
return !(Id & UnitFlag) && !(Id & MaskFlag);
118133
}
119-
static constexpr bool isMaskId(unsigned Id) { return Register(Id).isStack(); }
134+
static constexpr bool isUnitId(RegisterId Id) { return Id & UnitFlag; }
135+
static constexpr bool isMaskId(RegisterId Id) { return Id & MaskFlag; }
120136

121-
static constexpr RegisterId toUnitId(unsigned Idx) {
122-
return Idx | Register::VirtualRegFlag;
123-
}
137+
static constexpr RegisterId toUnitId(unsigned Idx) { return Idx | UnitFlag; }
138+
139+
static constexpr RegisterId toMaskId(unsigned Idx) { return Idx | MaskFlag; }
124140

125-
static constexpr unsigned toIdx(RegisterId Id) {
126-
// Not using virtReg2Index or stackSlot2Index, because they are
127-
// not constexpr.
128-
if (isUnitId(Id))
129-
return Id & ~Register::VirtualRegFlag;
130-
// RegId and MaskId are unchanged.
131-
return Id;
141+
static constexpr unsigned toMaskIdx(RegisterId Id) {
142+
assert(isMaskId(Id));
143+
return Id & ~MaskFlag;
132144
}
133145

134146
bool operator<(RegisterRef) const = delete;
@@ -141,11 +153,11 @@ struct PhysicalRegisterInfo {
141153
const MachineFunction &mf);
142154

143155
RegisterId getRegMaskId(const uint32_t *RM) const {
144-
return Register::index2StackSlot(RegMasks.find(RM));
156+
return RegisterRef::toMaskId(RegMasks.find(RM));
145157
}
146158

147159
const uint32_t *getRegMaskBits(RegisterId R) const {
148-
return RegMasks.get(Register(R).stackSlotIndex());
160+
return RegMasks.get(RegisterRef::toMaskIdx(R));
149161
}
150162

151163
bool alias(RegisterRef RA, RegisterRef RB) const;
@@ -158,7 +170,7 @@ struct PhysicalRegisterInfo {
158170
}
159171

160172
const BitVector &getMaskUnits(RegisterId MaskId) const {
161-
return MaskInfos[Register(MaskId).stackSlotIndex()].Units;
173+
return MaskInfos[RegisterRef::toMaskIdx(MaskId)].Units;
162174
}
163175

164176
std::set<RegisterId> getUnits(RegisterRef RR) const;
@@ -167,7 +179,7 @@ struct PhysicalRegisterInfo {
167179
return AliasInfos[U].Regs;
168180
}
169181

170-
RegisterRef mapTo(RegisterRef RR, unsigned R) const;
182+
RegisterRef mapTo(RegisterRef RR, RegisterId R) const;
171183
const TargetRegisterInfo &getTRI() const { return TRI; }
172184

173185
bool equal_to(RegisterRef A, RegisterRef B) const;

llvm/lib/CodeGen/RDFGraph.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1827,7 +1827,7 @@ bool DataFlowGraph::hasUntrackedRef(Stmt S, bool IgnoreReserved) const {
18271827
for (Ref R : S.Addr->members(*this)) {
18281828
Ops.push_back(&R.Addr->getOp());
18291829
RegisterRef RR = R.Addr->getRegRef(*this);
1830-
if (IgnoreReserved && RR.isReg() && ReservedRegs[RR.idx()])
1830+
if (IgnoreReserved && RR.isReg() && ReservedRegs[RR.asMCReg().id()])
18311831
continue;
18321832
if (!isTracked(RR))
18331833
return true;

llvm/lib/CodeGen/RDFRegisters.cpp

Lines changed: 19 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -126,13 +126,10 @@ std::set<RegisterId> PhysicalRegisterInfo::getAliasSet(RegisterId Reg) const {
126126
std::set<RegisterId> PhysicalRegisterInfo::getUnits(RegisterRef RR) const {
127127
std::set<RegisterId> Units;
128128

129-
if (RR.Reg == 0)
130-
return Units; // Empty
131-
132129
if (RR.isReg()) {
133130
if (RR.Mask.none())
134131
return Units; // Empty
135-
for (MCRegUnitMaskIterator UM(RR.idx(), &TRI); UM.isValid(); ++UM) {
132+
for (MCRegUnitMaskIterator UM(RR.asMCReg(), &TRI); UM.isValid(); ++UM) {
136133
auto [U, M] = *UM;
137134
if ((M & RR.Mask).any())
138135
Units.insert(U);
@@ -142,7 +139,7 @@ std::set<RegisterId> PhysicalRegisterInfo::getUnits(RegisterRef RR) const {
142139

143140
assert(RR.isMask());
144141
unsigned NumRegs = TRI.getNumRegs();
145-
const uint32_t *MB = getRegMaskBits(RR.idx());
142+
const uint32_t *MB = getRegMaskBits(RR.Reg);
146143
for (unsigned I = 0, E = (NumRegs + 31) / 32; I != E; ++I) {
147144
uint32_t C = ~MB[I]; // Clobbered regs
148145
if (I == 0) // Reg 0 should be ignored
@@ -162,12 +159,12 @@ std::set<RegisterId> PhysicalRegisterInfo::getUnits(RegisterRef RR) const {
162159
return Units;
163160
}
164161

165-
RegisterRef PhysicalRegisterInfo::mapTo(RegisterRef RR, unsigned R) const {
162+
RegisterRef PhysicalRegisterInfo::mapTo(RegisterRef RR, RegisterId R) const {
166163
if (RR.Reg == R)
167164
return RR;
168-
if (unsigned Idx = TRI.getSubRegIndex(R, RR.Reg))
165+
if (unsigned Idx = TRI.getSubRegIndex(RegisterRef(R).asMCReg(), RR.asMCReg()))
169166
return RegisterRef(R, TRI.composeSubRegIndexLaneMask(Idx, RR.Mask));
170-
if (unsigned Idx = TRI.getSubRegIndex(RR.Reg, R)) {
167+
if (unsigned Idx = TRI.getSubRegIndex(RR.asMCReg(), RegisterRef(R).asMCReg())) {
171168
const RegInfo &RI = RegInfos[R];
172169
LaneBitmask RCM =
173170
RI.RegClass ? RI.RegClass->LaneMask : LaneBitmask::getAll();
@@ -187,8 +184,8 @@ bool PhysicalRegisterInfo::equal_to(RegisterRef A, RegisterRef B) const {
187184
return A.Mask == B.Mask;
188185

189186
// Compare reg units lexicographically.
190-
MCRegUnitMaskIterator AI(A.Reg, &getTRI());
191-
MCRegUnitMaskIterator BI(B.Reg, &getTRI());
187+
MCRegUnitMaskIterator AI(A.asMCReg(), &getTRI());
188+
MCRegUnitMaskIterator BI(B.asMCReg(), &getTRI());
192189
while (AI.isValid() && BI.isValid()) {
193190
auto [AReg, AMask] = *AI;
194191
auto [BReg, BMask] = *BI;
@@ -225,8 +222,8 @@ bool PhysicalRegisterInfo::less(RegisterRef A, RegisterRef B) const {
225222
return A.Reg < B.Reg;
226223

227224
// Compare reg units lexicographically.
228-
llvm::MCRegUnitMaskIterator AI(A.Reg, &getTRI());
229-
llvm::MCRegUnitMaskIterator BI(B.Reg, &getTRI());
225+
llvm::MCRegUnitMaskIterator AI(A.asMCReg(), &getTRI());
226+
llvm::MCRegUnitMaskIterator BI(B.asMCReg(), &getTRI());
230227
while (AI.isValid() && BI.isValid()) {
231228
auto [AReg, AMask] = *AI;
232229
auto [BReg, BMask] = *BI;
@@ -252,18 +249,17 @@ bool PhysicalRegisterInfo::less(RegisterRef A, RegisterRef B) const {
252249
}
253250

254251
void PhysicalRegisterInfo::print(raw_ostream &OS, RegisterRef A) const {
255-
if (A.Reg == 0 || A.isReg()) {
256-
if (0 < A.idx() && A.idx() < TRI.getNumRegs())
257-
OS << TRI.getName(A.idx());
252+
if (A.isReg()) {
253+
MCRegister Reg = A.asMCReg();
254+
if (Reg && Reg.id() < TRI.getNumRegs())
255+
OS << TRI.getName(Reg);
258256
else
259-
OS << printReg(A.idx(), &TRI);
257+
OS << printReg(Reg, &TRI);
260258
OS << PrintLaneMaskShort(A.Mask);
261259
} else if (A.isUnit()) {
262-
OS << printRegUnit(A.idx(), &TRI);
260+
OS << printRegUnit(A.asMCRegUnit(), &TRI);
263261
} else {
264-
assert(A.isMask());
265-
// RegMask SS flag is preserved by idx().
266-
unsigned Idx = Register(A.idx()).stackSlotIndex();
262+
unsigned Idx = A.getMaskIdx();
267263
const char *Fmt = Idx < 0x10000 ? "%04x" : "%08x";
268264
OS << "M#" << format(Fmt, Idx);
269265
}
@@ -280,7 +276,7 @@ bool RegisterAggr::hasAliasOf(RegisterRef RR) const {
280276
if (RR.isMask())
281277
return Units.anyCommon(PRI.getMaskUnits(RR.Reg));
282278

283-
for (MCRegUnitMaskIterator U(RR.Reg, &PRI.getTRI()); U.isValid(); ++U) {
279+
for (MCRegUnitMaskIterator U(RR.asMCReg(), &PRI.getTRI()); U.isValid(); ++U) {
284280
auto [Unit, LaneMask] = *U;
285281
if ((LaneMask & RR.Mask).any())
286282
if (Units.test(Unit))
@@ -295,7 +291,7 @@ bool RegisterAggr::hasCoverOf(RegisterRef RR) const {
295291
return T.reset(Units).none();
296292
}
297293

298-
for (MCRegUnitMaskIterator U(RR.Reg, &PRI.getTRI()); U.isValid(); ++U) {
294+
for (MCRegUnitMaskIterator U(RR.asMCReg(), &PRI.getTRI()); U.isValid(); ++U) {
299295
auto [Unit, LaneMask] = *U;
300296
if ((LaneMask & RR.Mask).any())
301297
if (!Units.test(Unit))
@@ -310,7 +306,7 @@ RegisterAggr &RegisterAggr::insert(RegisterRef RR) {
310306
return *this;
311307
}
312308

313-
for (MCRegUnitMaskIterator U(RR.Reg, &PRI.getTRI()); U.isValid(); ++U) {
309+
for (MCRegUnitMaskIterator U(RR.asMCReg(), &PRI.getTRI()); U.isValid(); ++U) {
314310
auto [Unit, LaneMask] = *U;
315311
if ((LaneMask & RR.Mask).any())
316312
Units.set(Unit);

0 commit comments

Comments
 (0)