Skip to content

Commit 5ba9a78

Browse files
4vtomatmahesh-attarde
authored andcommitted
[RISCV][llvm] Handle vector callee saved register correctly (llvm#149467)
In TargetFrameLowering::determineCalleeSaves, any vector register is marked as saved if any of its subregister is clobbered, this is not correct in vector registers. We only want the vector register to be marked as saved only if all of its subregisters are clobbered. This patch handles vector callee saved registers in target hook.
1 parent 7551ab7 commit 5ba9a78

File tree

5 files changed

+466
-6035
lines changed

5 files changed

+466
-6035
lines changed

llvm/lib/Target/RISCV/RISCVCallingConv.td

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -56,19 +56,21 @@ def CSR_XLEN_F32_Interrupt: CalleeSavedRegs<(add CSR_Interrupt,
5656
def CSR_XLEN_F64_Interrupt: CalleeSavedRegs<(add CSR_Interrupt,
5757
(sequence "F%u_D", 0, 31))>;
5858

59+
defvar VREGS = (add (sequence "V%u", 0, 31),
60+
(sequence "V%uM2", 0, 31, 2),
61+
(sequence "V%uM4", 0, 31, 4),
62+
(sequence "V%uM8", 0, 31, 8));
63+
5964
// Same as CSR_Interrupt, but including all vector registers.
60-
def CSR_XLEN_V_Interrupt: CalleeSavedRegs<(add CSR_Interrupt,
61-
(sequence "V%u", 0, 31))>;
65+
def CSR_XLEN_V_Interrupt: CalleeSavedRegs<(add CSR_Interrupt, VREGS)>;
6266

6367
// Same as CSR_Interrupt, but including all 32-bit FP registers and all vector
6468
// registers.
65-
def CSR_XLEN_F32_V_Interrupt: CalleeSavedRegs<(add CSR_XLEN_F32_Interrupt,
66-
(sequence "V%u", 0, 31))>;
69+
def CSR_XLEN_F32_V_Interrupt: CalleeSavedRegs<(add CSR_XLEN_F32_Interrupt, VREGS)>;
6770

6871
// Same as CSR_Interrupt, but including all 64-bit FP registers and all vector
6972
// registers.
70-
def CSR_XLEN_F64_V_Interrupt: CalleeSavedRegs<(add CSR_XLEN_F64_Interrupt,
71-
(sequence "V%u", 0, 31))>;
73+
def CSR_XLEN_F64_V_Interrupt: CalleeSavedRegs<(add CSR_XLEN_F64_Interrupt, VREGS)>;
7274

7375
// Same as CSR_Interrupt, but excluding X16-X31.
7476
def CSR_Interrupt_RVE : CalleeSavedRegs<(sub CSR_Interrupt,

llvm/lib/Target/RISCV/RISCVFrameLowering.cpp

Lines changed: 43 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1544,10 +1544,53 @@ RISCVFrameLowering::getFrameIndexReference(const MachineFunction &MF, int FI,
15441544
return Offset;
15451545
}
15461546

1547+
static MCRegister getRVVBaseRegister(const RISCVRegisterInfo &TRI,
1548+
const Register &Reg) {
1549+
MCRegister BaseReg = TRI.getSubReg(Reg, RISCV::sub_vrm1_0);
1550+
// If it's not a grouped vector register, it doesn't have subregister, so
1551+
// the base register is just itself.
1552+
if (BaseReg == RISCV::NoRegister)
1553+
BaseReg = Reg;
1554+
return BaseReg;
1555+
}
1556+
15471557
void RISCVFrameLowering::determineCalleeSaves(MachineFunction &MF,
15481558
BitVector &SavedRegs,
15491559
RegScavenger *RS) const {
15501560
TargetFrameLowering::determineCalleeSaves(MF, SavedRegs, RS);
1561+
1562+
// In TargetFrameLowering::determineCalleeSaves, any vector register is marked
1563+
// as saved if any of its subregister is clobbered, this is not correct in
1564+
// vector registers. We only want the vector register to be marked as saved
1565+
// if all of its subregisters are clobbered.
1566+
// For example:
1567+
// Original behavior: If v24 is marked, v24m2, v24m4, v24m8 are also marked.
1568+
// Correct behavior: v24m2 is marked only if v24 and v25 are marked.
1569+
const MachineRegisterInfo &MRI = MF.getRegInfo();
1570+
const MCPhysReg *CSRegs = MRI.getCalleeSavedRegs();
1571+
const RISCVRegisterInfo &TRI = *STI.getRegisterInfo();
1572+
for (unsigned i = 0; CSRegs[i]; ++i) {
1573+
unsigned CSReg = CSRegs[i];
1574+
// Only vector registers need special care.
1575+
if (!RISCV::VRRegClass.contains(getRVVBaseRegister(TRI, CSReg)))
1576+
continue;
1577+
1578+
SavedRegs.reset(CSReg);
1579+
1580+
auto SubRegs = TRI.subregs(CSReg);
1581+
// Set the register and all its subregisters.
1582+
if (!MRI.def_empty(CSReg) || MRI.getUsedPhysRegsMask().test(CSReg)) {
1583+
SavedRegs.set(CSReg);
1584+
llvm::for_each(SubRegs, [&](unsigned Reg) { return SavedRegs.set(Reg); });
1585+
}
1586+
1587+
// Combine to super register if all of its subregisters are marked.
1588+
if (!SubRegs.empty() && llvm::all_of(SubRegs, [&](unsigned Reg) {
1589+
return SavedRegs.test(Reg);
1590+
}))
1591+
SavedRegs.set(CSReg);
1592+
}
1593+
15511594
// Unconditionally spill RA and FP only if the function uses a frame
15521595
// pointer.
15531596
if (hasFP(MF)) {
@@ -2137,16 +2180,6 @@ static unsigned getCalleeSavedRVVNumRegs(const Register &BaseReg) {
21372180
: 8;
21382181
}
21392182

2140-
static MCRegister getRVVBaseRegister(const RISCVRegisterInfo &TRI,
2141-
const Register &Reg) {
2142-
MCRegister BaseReg = TRI.getSubReg(Reg, RISCV::sub_vrm1_0);
2143-
// If it's not a grouped vector register, it doesn't have subregister, so
2144-
// the base register is just itself.
2145-
if (BaseReg == RISCV::NoRegister)
2146-
BaseReg = Reg;
2147-
return BaseReg;
2148-
}
2149-
21502183
void RISCVFrameLowering::emitCalleeSavedRVVPrologCFI(
21512184
MachineBasicBlock &MBB, MachineBasicBlock::iterator MI, bool HasFP) const {
21522185
MachineFunction *MF = MBB.getParent();

0 commit comments

Comments
 (0)