Skip to content

Commit 17a98f8

Browse files
authored
[RISCV] Optimize the spill/reload of segment registers (#153184)
The simplest way would be: 1. Save `vtype` to a scalar register. 2. Insert a `vsetvli`. 3. Use segment load/store. 4. Restore `vtype` via `vsetvl`. But `vsetvl` is usually slow, so this PR does not take that approach. Instead, we use wider whole-register load/store instructions when the register encoding is suitably aligned. We have done the same optimization for COPY in llvm/llvm-project#84455. We found this suboptimal implementation when porting some video codec kernels via RVV intrinsics.
1 parent 2e74cc6 commit 17a98f8

File tree

8 files changed

+242
-407
lines changed

8 files changed

+242
-407
lines changed

llvm/lib/Target/RISCV/RISCVInstrInfo.cpp

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,7 @@ void RISCVInstrInfo::copyPhysRegVector(
382382
MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
383383
const DebugLoc &DL, MCRegister DstReg, MCRegister SrcReg, bool KillSrc,
384384
const TargetRegisterClass *RegClass) const {
385-
const TargetRegisterInfo *TRI = STI.getRegisterInfo();
385+
const RISCVRegisterInfo *TRI = STI.getRegisterInfo();
386386
RISCVVType::VLMUL LMul = RISCVRI::getLMul(RegClass->TSFlags);
387387
unsigned NF = RISCVRI::getNF(RegClass->TSFlags);
388388

@@ -444,13 +444,7 @@ void RISCVInstrInfo::copyPhysRegVector(
444444
return {RISCVVType::LMUL_1, RISCV::VRRegClass, RISCV::VMV1R_V,
445445
RISCV::PseudoVMV_V_V_M1, RISCV::PseudoVMV_V_I_M1};
446446
};
447-
auto FindRegWithEncoding = [TRI](const TargetRegisterClass &RegClass,
448-
uint16_t Encoding) {
449-
MCRegister Reg = RISCV::V0 + Encoding;
450-
if (RISCVRI::getLMul(RegClass.TSFlags) == RISCVVType::LMUL_1)
451-
return Reg;
452-
return TRI->getMatchingSuperReg(Reg, RISCV::sub_vrm1_0, &RegClass);
453-
};
447+
454448
while (I != NumRegs) {
455449
// For non-segment copying, we only do this once as the registers are always
456450
// aligned.
@@ -470,9 +464,9 @@ void RISCVInstrInfo::copyPhysRegVector(
470464

471465
// Emit actual copying.
472466
// For reversed copying, the encoding should be decreased.
473-
MCRegister ActualSrcReg = FindRegWithEncoding(
467+
MCRegister ActualSrcReg = TRI->findVRegWithEncoding(
474468
RegClass, ReversedCopy ? (SrcEncoding - NumCopied + 1) : SrcEncoding);
475-
MCRegister ActualDstReg = FindRegWithEncoding(
469+
MCRegister ActualDstReg = TRI->findVRegWithEncoding(
476470
RegClass, ReversedCopy ? (DstEncoding - NumCopied + 1) : DstEncoding);
477471

478472
auto MIB = BuildMI(MBB, MBBI, DL, get(Opc), ActualDstReg);

llvm/lib/Target/RISCV/RISCVRegisterInfo.cpp

Lines changed: 87 additions & 137 deletions
Original file line numberDiff line numberDiff line change
@@ -389,9 +389,25 @@ void RISCVRegisterInfo::adjustReg(MachineBasicBlock &MBB,
389389
.setMIFlag(Flag);
390390
}
391391

392-
// Split a VSPILLx_Mx pseudo into multiple whole register stores separated by
393-
// LMUL*VLENB bytes.
394-
void RISCVRegisterInfo::lowerVSPILL(MachineBasicBlock::iterator II) const {
392+
static std::tuple<RISCVVType::VLMUL, const TargetRegisterClass &, unsigned>
393+
getSpillReloadInfo(unsigned NumRemaining, uint16_t RegEncoding, bool IsSpill) {
394+
if (NumRemaining >= 8 && RegEncoding % 8 == 0)
395+
return {RISCVVType::LMUL_8, RISCV::VRM8RegClass,
396+
IsSpill ? RISCV::VS8R_V : RISCV::VL8RE8_V};
397+
if (NumRemaining >= 4 && RegEncoding % 4 == 0)
398+
return {RISCVVType::LMUL_4, RISCV::VRM4RegClass,
399+
IsSpill ? RISCV::VS4R_V : RISCV::VL4RE8_V};
400+
if (NumRemaining >= 2 && RegEncoding % 2 == 0)
401+
return {RISCVVType::LMUL_2, RISCV::VRM2RegClass,
402+
IsSpill ? RISCV::VS2R_V : RISCV::VL2RE8_V};
403+
return {RISCVVType::LMUL_1, RISCV::VRRegClass,
404+
IsSpill ? RISCV::VS1R_V : RISCV::VL1RE8_V};
405+
}
406+
407+
// Split a VSPILLx_Mx/VRELOADx_Mx pseudo into multiple whole register stores
408+
// separated by LMUL*VLENB bytes.
409+
void RISCVRegisterInfo::lowerSegmentSpillReload(MachineBasicBlock::iterator II,
410+
bool IsSpill) const {
395411
DebugLoc DL = II->getDebugLoc();
396412
MachineBasicBlock &MBB = *II->getParent();
397413
MachineFunction &MF = *MBB.getParent();
@@ -403,148 +419,75 @@ void RISCVRegisterInfo::lowerVSPILL(MachineBasicBlock::iterator II) const {
403419
auto ZvlssegInfo = RISCV::isRVVSpillForZvlsseg(II->getOpcode());
404420
unsigned NF = ZvlssegInfo->first;
405421
unsigned LMUL = ZvlssegInfo->second;
406-
assert(NF * LMUL <= 8 && "Invalid NF/LMUL combinations.");
407-
unsigned Opcode, SubRegIdx;
408-
switch (LMUL) {
409-
default:
410-
llvm_unreachable("LMUL must be 1, 2, or 4.");
411-
case 1:
412-
Opcode = RISCV::VS1R_V;
413-
SubRegIdx = RISCV::sub_vrm1_0;
414-
break;
415-
case 2:
416-
Opcode = RISCV::VS2R_V;
417-
SubRegIdx = RISCV::sub_vrm2_0;
418-
break;
419-
case 4:
420-
Opcode = RISCV::VS4R_V;
421-
SubRegIdx = RISCV::sub_vrm4_0;
422-
break;
423-
}
424-
static_assert(RISCV::sub_vrm1_7 == RISCV::sub_vrm1_0 + 7,
425-
"Unexpected subreg numbering");
426-
static_assert(RISCV::sub_vrm2_3 == RISCV::sub_vrm2_0 + 3,
427-
"Unexpected subreg numbering");
428-
static_assert(RISCV::sub_vrm4_1 == RISCV::sub_vrm4_0 + 1,
429-
"Unexpected subreg numbering");
430-
431-
Register VL = MRI.createVirtualRegister(&RISCV::GPRRegClass);
432-
// Optimize for constant VLEN.
433-
if (auto VLEN = STI.getRealVLen()) {
434-
const int64_t VLENB = *VLEN / 8;
435-
int64_t Offset = VLENB * LMUL;
436-
STI.getInstrInfo()->movImm(MBB, II, DL, VL, Offset);
437-
} else {
438-
BuildMI(MBB, II, DL, TII->get(RISCV::PseudoReadVLENB), VL);
439-
uint32_t ShiftAmount = Log2_32(LMUL);
440-
if (ShiftAmount != 0)
441-
BuildMI(MBB, II, DL, TII->get(RISCV::SLLI), VL)
442-
.addReg(VL)
443-
.addImm(ShiftAmount);
444-
}
422+
unsigned NumRegs = NF * LMUL;
423+
assert(NumRegs <= 8 && "Invalid NF/LMUL combinations.");
445424

446-
Register SrcReg = II->getOperand(0).getReg();
425+
Register Reg = II->getOperand(0).getReg();
426+
uint16_t RegEncoding = TRI->getEncodingValue(Reg);
447427
Register Base = II->getOperand(1).getReg();
448428
bool IsBaseKill = II->getOperand(1).isKill();
449429
Register NewBase = MRI.createVirtualRegister(&RISCV::GPRRegClass);
450430

451431
auto *OldMMO = *(II->memoperands_begin());
452432
LocationSize OldLoc = OldMMO->getSize();
453433
assert(OldLoc.isPrecise() && OldLoc.getValue().isKnownMultipleOf(NF));
454-
TypeSize NewSize = OldLoc.getValue().divideCoefficientBy(NF);
455-
auto *NewMMO = MF.getMachineMemOperand(OldMMO, OldMMO->getOffset(), NewSize);
456-
for (unsigned I = 0; I < NF; ++I) {
457-
// Adding implicit-use of super register to describe we are using part of
458-
// super register, that prevents machine verifier complaining when part of
459-
// subreg is undef, see comment in MachineVerifier::checkLiveness for more
460-
// detail.
461-
BuildMI(MBB, II, DL, TII->get(Opcode))
462-
.addReg(TRI->getSubReg(SrcReg, SubRegIdx + I))
463-
.addReg(Base, getKillRegState(I == NF - 1))
464-
.addMemOperand(NewMMO)
465-
.addReg(SrcReg, RegState::Implicit);
466-
if (I != NF - 1)
434+
TypeSize VRegSize = OldLoc.getValue().divideCoefficientBy(NumRegs);
435+
436+
Register VLENB = 0;
437+
unsigned PreHandledNum = 0;
438+
unsigned I = 0;
439+
while (I != NumRegs) {
440+
auto [LMulHandled, RegClass, Opcode] =
441+
getSpillReloadInfo(NumRegs - I, RegEncoding, IsSpill);
442+
auto [RegNumHandled, _] = RISCVVType::decodeVLMUL(LMulHandled);
443+
bool IsLast = I + RegNumHandled == NumRegs;
444+
if (PreHandledNum) {
445+
Register Step;
446+
// Optimize for constant VLEN.
447+
if (auto VLEN = STI.getRealVLen()) {
448+
int64_t Offset = *VLEN / 8 * PreHandledNum;
449+
Step = MRI.createVirtualRegister(&RISCV::GPRRegClass);
450+
STI.getInstrInfo()->movImm(MBB, II, DL, Step, Offset);
451+
} else {
452+
if (!VLENB) {
453+
VLENB = MRI.createVirtualRegister(&RISCV::GPRRegClass);
454+
BuildMI(MBB, II, DL, TII->get(RISCV::PseudoReadVLENB), VLENB);
455+
}
456+
uint32_t ShiftAmount = Log2_32(PreHandledNum);
457+
if (ShiftAmount == 0)
458+
Step = VLENB;
459+
else {
460+
Step = MRI.createVirtualRegister(&RISCV::GPRRegClass);
461+
BuildMI(MBB, II, DL, TII->get(RISCV::SLLI), Step)
462+
.addReg(VLENB, getKillRegState(IsLast))
463+
.addImm(ShiftAmount);
464+
}
465+
}
466+
467467
BuildMI(MBB, II, DL, TII->get(RISCV::ADD), NewBase)
468468
.addReg(Base, getKillRegState(I != 0 || IsBaseKill))
469-
.addReg(VL, getKillRegState(I == NF - 2));
470-
Base = NewBase;
471-
}
472-
II->eraseFromParent();
473-
}
469+
.addReg(Step, getKillRegState(Step != VLENB || IsLast));
470+
Base = NewBase;
471+
}
474472

475-
// Split a VSPILLx_Mx pseudo into multiple whole register loads separated by
476-
// LMUL*VLENB bytes.
477-
void RISCVRegisterInfo::lowerVRELOAD(MachineBasicBlock::iterator II) const {
478-
DebugLoc DL = II->getDebugLoc();
479-
MachineBasicBlock &MBB = *II->getParent();
480-
MachineFunction &MF = *MBB.getParent();
481-
MachineRegisterInfo &MRI = MF.getRegInfo();
482-
const RISCVSubtarget &STI = MF.getSubtarget<RISCVSubtarget>();
483-
const TargetInstrInfo *TII = STI.getInstrInfo();
484-
const TargetRegisterInfo *TRI = STI.getRegisterInfo();
473+
MCRegister ActualReg = findVRegWithEncoding(RegClass, RegEncoding);
474+
MachineInstrBuilder MIB =
475+
BuildMI(MBB, II, DL, TII->get(Opcode))
476+
.addReg(ActualReg, getDefRegState(!IsSpill))
477+
.addReg(Base, getKillRegState(IsLast))
478+
.addMemOperand(MF.getMachineMemOperand(OldMMO, OldMMO->getOffset(),
479+
VRegSize * RegNumHandled));
485480

486-
auto ZvlssegInfo = RISCV::isRVVSpillForZvlsseg(II->getOpcode());
487-
unsigned NF = ZvlssegInfo->first;
488-
unsigned LMUL = ZvlssegInfo->second;
489-
assert(NF * LMUL <= 8 && "Invalid NF/LMUL combinations.");
490-
unsigned Opcode, SubRegIdx;
491-
switch (LMUL) {
492-
default:
493-
llvm_unreachable("LMUL must be 1, 2, or 4.");
494-
case 1:
495-
Opcode = RISCV::VL1RE8_V;
496-
SubRegIdx = RISCV::sub_vrm1_0;
497-
break;
498-
case 2:
499-
Opcode = RISCV::VL2RE8_V;
500-
SubRegIdx = RISCV::sub_vrm2_0;
501-
break;
502-
case 4:
503-
Opcode = RISCV::VL4RE8_V;
504-
SubRegIdx = RISCV::sub_vrm4_0;
505-
break;
506-
}
507-
static_assert(RISCV::sub_vrm1_7 == RISCV::sub_vrm1_0 + 7,
508-
"Unexpected subreg numbering");
509-
static_assert(RISCV::sub_vrm2_3 == RISCV::sub_vrm2_0 + 3,
510-
"Unexpected subreg numbering");
511-
static_assert(RISCV::sub_vrm4_1 == RISCV::sub_vrm4_0 + 1,
512-
"Unexpected subreg numbering");
513-
514-
Register VL = MRI.createVirtualRegister(&RISCV::GPRRegClass);
515-
// Optimize for constant VLEN.
516-
if (auto VLEN = STI.getRealVLen()) {
517-
const int64_t VLENB = *VLEN / 8;
518-
int64_t Offset = VLENB * LMUL;
519-
STI.getInstrInfo()->movImm(MBB, II, DL, VL, Offset);
520-
} else {
521-
BuildMI(MBB, II, DL, TII->get(RISCV::PseudoReadVLENB), VL);
522-
uint32_t ShiftAmount = Log2_32(LMUL);
523-
if (ShiftAmount != 0)
524-
BuildMI(MBB, II, DL, TII->get(RISCV::SLLI), VL)
525-
.addReg(VL)
526-
.addImm(ShiftAmount);
527-
}
481+
// Adding implicit-use of super register to describe we are using part of
482+
// super register, that prevents machine verifier complaining when part of
483+
// subreg is undef, see comment in MachineVerifier::checkLiveness for more
484+
// detail.
485+
if (IsSpill)
486+
MIB.addReg(Reg, RegState::Implicit);
528487

529-
Register DestReg = II->getOperand(0).getReg();
530-
Register Base = II->getOperand(1).getReg();
531-
bool IsBaseKill = II->getOperand(1).isKill();
532-
Register NewBase = MRI.createVirtualRegister(&RISCV::GPRRegClass);
533-
auto *OldMMO = *(II->memoperands_begin());
534-
LocationSize OldLoc = OldMMO->getSize();
535-
assert(OldLoc.isPrecise() && OldLoc.getValue().isKnownMultipleOf(NF));
536-
TypeSize NewSize = OldLoc.getValue().divideCoefficientBy(NF);
537-
auto *NewMMO = MF.getMachineMemOperand(OldMMO, OldMMO->getOffset(), NewSize);
538-
for (unsigned I = 0; I < NF; ++I) {
539-
BuildMI(MBB, II, DL, TII->get(Opcode),
540-
TRI->getSubReg(DestReg, SubRegIdx + I))
541-
.addReg(Base, getKillRegState(I == NF - 1))
542-
.addMemOperand(NewMMO);
543-
if (I != NF - 1)
544-
BuildMI(MBB, II, DL, TII->get(RISCV::ADD), NewBase)
545-
.addReg(Base, getKillRegState(I != 0 || IsBaseKill))
546-
.addReg(VL, getKillRegState(I == NF - 2));
547-
Base = NewBase;
488+
PreHandledNum = RegNumHandled;
489+
RegEncoding += RegNumHandled;
490+
I += RegNumHandled;
548491
}
549492
II->eraseFromParent();
550493
}
@@ -635,9 +578,7 @@ bool RISCVRegisterInfo::eliminateFrameIndex(MachineBasicBlock::iterator II,
635578
}
636579

637580
// Handle spill/fill of synthetic register classes for segment operations to
638-
// ensure correctness in the edge case one gets spilled. There are many
639-
// possible optimizations here, but given the extreme rarity of such spills,
640-
// we prefer simplicity of implementation for now.
581+
// ensure correctness in the edge case one gets spilled.
641582
switch (MI.getOpcode()) {
642583
case RISCV::PseudoVSPILL2_M1:
643584
case RISCV::PseudoVSPILL2_M2:
@@ -650,7 +591,7 @@ bool RISCVRegisterInfo::eliminateFrameIndex(MachineBasicBlock::iterator II,
650591
case RISCV::PseudoVSPILL6_M1:
651592
case RISCV::PseudoVSPILL7_M1:
652593
case RISCV::PseudoVSPILL8_M1:
653-
lowerVSPILL(II);
594+
lowerSegmentSpillReload(II, /*IsSpill=*/true);
654595
return true;
655596
case RISCV::PseudoVRELOAD2_M1:
656597
case RISCV::PseudoVRELOAD2_M2:
@@ -663,7 +604,7 @@ bool RISCVRegisterInfo::eliminateFrameIndex(MachineBasicBlock::iterator II,
663604
case RISCV::PseudoVRELOAD6_M1:
664605
case RISCV::PseudoVRELOAD7_M1:
665606
case RISCV::PseudoVRELOAD8_M1:
666-
lowerVRELOAD(II);
607+
lowerSegmentSpillReload(II, /*IsSpill=*/false);
667608
return true;
668609
}
669610

@@ -1052,3 +993,12 @@ bool RISCVRegisterInfo::getRegAllocationHints(
1052993

1053994
return BaseImplRetVal;
1054995
}
996+
997+
Register
998+
RISCVRegisterInfo::findVRegWithEncoding(const TargetRegisterClass &RegClass,
999+
uint16_t Encoding) const {
1000+
MCRegister Reg = RISCV::V0 + Encoding;
1001+
if (RISCVRI::getLMul(RegClass.TSFlags) == RISCVVType::LMUL_1)
1002+
return Reg;
1003+
return getMatchingSuperReg(Reg, RISCV::sub_vrm1_0, &RegClass);
1004+
}

llvm/lib/Target/RISCV/RISCVRegisterInfo.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,8 @@ struct RISCVRegisterInfo : public RISCVGenRegisterInfo {
107107
int64_t getFrameIndexInstrOffset(const MachineInstr *MI,
108108
int Idx) const override;
109109

110-
void lowerVSPILL(MachineBasicBlock::iterator II) const;
111-
void lowerVRELOAD(MachineBasicBlock::iterator II) const;
110+
void lowerSegmentSpillReload(MachineBasicBlock::iterator II,
111+
bool IsSpill) const;
112112

113113
Register getFrameRegister(const MachineFunction &MF) const override;
114114

@@ -144,6 +144,9 @@ struct RISCVRegisterInfo : public RISCVGenRegisterInfo {
144144
const MachineFunction &MF, const VirtRegMap *VRM,
145145
const LiveRegMatrix *Matrix) const override;
146146

147+
Register findVRegWithEncoding(const TargetRegisterClass &RegClass,
148+
uint16_t Encoding) const;
149+
147150
static bool isVRRegClass(const TargetRegisterClass *RC) {
148151
return RISCVRI::isVRegClass(RC->TSFlags) &&
149152
RISCVRI::getNF(RC->TSFlags) == 1;

llvm/test/CodeGen/RISCV/early-clobber-tied-def-subreg-liveness.ll

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -40,15 +40,7 @@ define void @_Z3foov() {
4040
; CHECK-NEXT: addi a0, a0, %lo(.L__const._Z3foov.var_45)
4141
; CHECK-NEXT: vle16.v v12, (a0)
4242
; CHECK-NEXT: addi a0, sp, 16
43-
; CHECK-NEXT: csrr a1, vlenb
44-
; CHECK-NEXT: slli a1, a1, 1
45-
; CHECK-NEXT: vs2r.v v8, (a0) # vscale x 16-byte Folded Spill
46-
; CHECK-NEXT: add a0, a0, a1
47-
; CHECK-NEXT: vs2r.v v10, (a0) # vscale x 16-byte Folded Spill
48-
; CHECK-NEXT: add a0, a0, a1
49-
; CHECK-NEXT: vs2r.v v12, (a0) # vscale x 16-byte Folded Spill
50-
; CHECK-NEXT: add a0, a0, a1
51-
; CHECK-NEXT: vs2r.v v14, (a0) # vscale x 16-byte Folded Spill
43+
; CHECK-NEXT: vs8r.v v8, (a0) # vscale x 64-byte Folded Spill
5244
; CHECK-NEXT: lui a0, %hi(.L__const._Z3foov.var_40)
5345
; CHECK-NEXT: addi a0, a0, %lo(.L__const._Z3foov.var_40)
5446
; CHECK-NEXT: #APP
@@ -59,15 +51,7 @@ define void @_Z3foov() {
5951
; CHECK-NEXT: addi a0, a0, 928
6052
; CHECK-NEXT: vmsbc.vx v0, v8, a0
6153
; CHECK-NEXT: addi a0, sp, 16
62-
; CHECK-NEXT: csrr a1, vlenb
63-
; CHECK-NEXT: slli a1, a1, 1
64-
; CHECK-NEXT: vl2r.v v8, (a0) # vscale x 16-byte Folded Reload
65-
; CHECK-NEXT: add a0, a0, a1
66-
; CHECK-NEXT: vl2r.v v10, (a0) # vscale x 16-byte Folded Reload
67-
; CHECK-NEXT: add a0, a0, a1
68-
; CHECK-NEXT: vl2r.v v12, (a0) # vscale x 16-byte Folded Reload
69-
; CHECK-NEXT: add a0, a0, a1
70-
; CHECK-NEXT: vl2r.v v14, (a0) # vscale x 16-byte Folded Reload
54+
; CHECK-NEXT: vl8r.v v8, (a0) # vscale x 64-byte Folded Reload
7155
; CHECK-NEXT: csrr a0, vlenb
7256
; CHECK-NEXT: slli a0, a0, 3
7357
; CHECK-NEXT: add a0, sp, a0

0 commit comments

Comments
 (0)