Skip to content

Commit c8cc345

Browse files
committed
[RISCV][llvm] Handle vector callee saved register correctly
In TargetFrameLowering::determineCalleeSaves, any vector register is marked as saved if any of its subregisters is clobbered; this is not correct for vector registers. We want a vector register to be marked as saved only if all of its subregisters are clobbered. This patch handles vector callee-saved registers in the target hook.
1 parent 90f733c commit c8cc345

File tree

5 files changed

+466
-6035
lines changed

5 files changed

+466
-6035
lines changed

llvm/lib/Target/RISCV/RISCVCallingConv.td

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -56,19 +56,21 @@ def CSR_XLEN_F32_Interrupt: CalleeSavedRegs<(add CSR_Interrupt,
5656
def CSR_XLEN_F64_Interrupt: CalleeSavedRegs<(add CSR_Interrupt,
5757
(sequence "F%u_D", 0, 31))>;
5858

59+
defvar VREGS = (add (sequence "V%u", 0, 31),
60+
(sequence "V%uM2", 0, 31, 2),
61+
(sequence "V%uM4", 0, 31, 4),
62+
(sequence "V%uM8", 0, 31, 8));
63+
5964
// Same as CSR_Interrupt, but including all vector registers.
60-
def CSR_XLEN_V_Interrupt: CalleeSavedRegs<(add CSR_Interrupt,
61-
(sequence "V%u", 0, 31))>;
65+
def CSR_XLEN_V_Interrupt: CalleeSavedRegs<(add CSR_Interrupt, VREGS)>;
6266

6367
// Same as CSR_Interrupt, but including all 32-bit FP registers and all vector
6468
// registers.
65-
def CSR_XLEN_F32_V_Interrupt: CalleeSavedRegs<(add CSR_XLEN_F32_Interrupt,
66-
(sequence "V%u", 0, 31))>;
69+
def CSR_XLEN_F32_V_Interrupt: CalleeSavedRegs<(add CSR_XLEN_F32_Interrupt, VREGS)>;
6770

6871
// Same as CSR_Interrupt, but including all 64-bit FP registers and all vector
6972
// registers.
70-
def CSR_XLEN_F64_V_Interrupt: CalleeSavedRegs<(add CSR_XLEN_F64_Interrupt,
71-
(sequence "V%u", 0, 31))>;
73+
def CSR_XLEN_F64_V_Interrupt: CalleeSavedRegs<(add CSR_XLEN_F64_Interrupt, VREGS)>;
7274

7375
// Same as CSR_Interrupt, but excluding X16-X31.
7476
def CSR_Interrupt_RVE : CalleeSavedRegs<(sub CSR_Interrupt,

llvm/lib/Target/RISCV/RISCVFrameLowering.cpp

Lines changed: 43 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1515,10 +1515,53 @@ RISCVFrameLowering::getFrameIndexReference(const MachineFunction &MF, int FI,
15151515
return Offset;
15161516
}
15171517

1518+
static MCRegister getRVVBaseRegister(const RISCVRegisterInfo &TRI,
1519+
const Register &Reg) {
1520+
MCRegister BaseReg = TRI.getSubReg(Reg, RISCV::sub_vrm1_0);
1521+
// If it's not a grouped vector register, it has no subregisters, so
1522+
// the base register is just itself.
1523+
if (BaseReg == RISCV::NoRegister)
1524+
BaseReg = Reg;
1525+
return BaseReg;
1526+
}
1527+
15181528
void RISCVFrameLowering::determineCalleeSaves(MachineFunction &MF,
15191529
BitVector &SavedRegs,
15201530
RegScavenger *RS) const {
15211531
TargetFrameLowering::determineCalleeSaves(MF, SavedRegs, RS);
1532+
1533+
// In TargetFrameLowering::determineCalleeSaves, any vector register is marked
1534+
// as saved if any of its subregisters is clobbered; this is not correct for
1535+
// vector registers. We want a vector register to be marked as saved only if
1536+
// all of its subregisters are clobbered.
1537+
// For example:
1538+
// Original behavior: If v24 is marked, v24m2, v24m4, v24m8 are also marked.
1539+
// Correct behavior: v24m2 is marked only if v24 and v25 are marked.
1540+
const MachineRegisterInfo &MRI = MF.getRegInfo();
1541+
const MCPhysReg *CSRegs = MRI.getCalleeSavedRegs();
1542+
const RISCVRegisterInfo &TRI = *STI.getRegisterInfo();
1543+
for (unsigned i = 0; CSRegs[i]; ++i) {
1544+
unsigned CSReg = CSRegs[i];
1545+
// Only vector registers need special care.
1546+
if (!RISCV::VRRegClass.contains(getRVVBaseRegister(TRI, CSReg)))
1547+
continue;
1548+
1549+
SavedRegs.reset(CSReg);
1550+
1551+
auto SubRegs = TRI.subregs(CSReg);
1552+
// Set the register and all of its subregisters.
1553+
if (!MRI.def_empty(CSReg) || MRI.getUsedPhysRegsMask().test(CSReg)) {
1554+
SavedRegs.set(CSReg);
1555+
llvm::for_each(SubRegs, [&](unsigned Reg) { return SavedRegs.set(Reg); });
1556+
}
1557+
1558+
// Combine to super register if all of its subregisters are marked.
1559+
if (!SubRegs.empty() && llvm::all_of(SubRegs, [&](unsigned Reg) {
1560+
return SavedRegs.test(Reg);
1561+
}))
1562+
SavedRegs.set(CSReg);
1563+
}
1564+
15221565
// Unconditionally spill RA and FP only if the function uses a frame
15231566
// pointer.
15241567
if (hasFP(MF)) {
@@ -2107,16 +2150,6 @@ static unsigned getCalleeSavedRVVNumRegs(const Register &BaseReg) {
21072150
: 8;
21082151
}
21092152

2110-
static MCRegister getRVVBaseRegister(const RISCVRegisterInfo &TRI,
2111-
const Register &Reg) {
2112-
MCRegister BaseReg = TRI.getSubReg(Reg, RISCV::sub_vrm1_0);
2113-
// If it's not a grouped vector register, it has no subregisters, so
2114-
// the base register is just itself.
2115-
if (BaseReg == RISCV::NoRegister)
2116-
BaseReg = Reg;
2117-
return BaseReg;
2118-
}
2119-
21202153
void RISCVFrameLowering::emitCalleeSavedRVVPrologCFI(
21212154
MachineBasicBlock &MBB, MachineBasicBlock::iterator MI, bool HasFP) const {
21222155
MachineFunction *MF = MBB.getParent();

0 commit comments

Comments
 (0)