Skip to content

Commit ea2b861

Browse files
committed
[RISCV] Handle recurrences in RISCVVLOptimizer
1 parent 9aba342 commit ea2b861

File tree

5 files changed

+92
-55
lines changed

5 files changed

+92
-55
lines changed

llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp

Lines changed: 85 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,27 @@ using namespace llvm;
3030

3131
namespace {
3232

33+
/// Wrapper around MachineOperand that defaults to immediate 0.
34+
struct DemandedVL {
35+
MachineOperand VL;
36+
DemandedVL() : VL(MachineOperand::CreateImm(0)) {}
37+
DemandedVL(MachineOperand VL) : VL(VL) {}
38+
static DemandedVL vlmax() {
39+
return DemandedVL(MachineOperand::CreateImm(RISCV::VLMaxSentinel));
40+
}
41+
bool operator!=(const DemandedVL &Other) const {
42+
return !VL.isIdenticalTo(Other.VL);
43+
}
44+
};
45+
46+
static DemandedVL max(const DemandedVL &LHS, const DemandedVL &RHS) {
47+
if (RISCV::isVLKnownLE(LHS.VL, RHS.VL))
48+
return RHS;
49+
if (RISCV::isVLKnownLE(RHS.VL, LHS.VL))
50+
return LHS;
51+
return DemandedVL::vlmax();
52+
}
53+
3354
class RISCVVLOptimizer : public MachineFunctionPass {
3455
const MachineRegisterInfo *MRI;
3556
const MachineDominatorTree *MDT;
@@ -51,17 +72,26 @@ class RISCVVLOptimizer : public MachineFunctionPass {
5172
StringRef getPassName() const override { return PASS_NAME; }
5273

5374
private:
54-
std::optional<MachineOperand>
55-
getMinimumVLForUser(const MachineOperand &UserOp) const;
56-
/// Returns the largest common VL MachineOperand that may be used to optimize
57-
/// MI. Returns std::nullopt if it failed to find a suitable VL.
58-
std::optional<MachineOperand> checkUsers(const MachineInstr &MI) const;
75+
DemandedVL getMinimumVLForUser(const MachineOperand &UserOp) const;
76+
/// Returns true if the users of \p MI have compatible EEWs and SEWs.
77+
bool checkUsers(const MachineInstr &MI) const;
5978
bool tryReduceVL(MachineInstr &MI) const;
6079
bool isCandidate(const MachineInstr &MI) const;
80+
void transfer(const MachineInstr &MI);
81+
82+
/// Returns all uses of vector virtual registers.
83+
auto vector_uses(const MachineInstr &MI) const {
84+
auto Pred = [this](const MachineOperand &MO) -> bool {
85+
return MO.isReg() && MO.getReg().isVirtual() &&
86+
RISCVRegisterInfo::isRVVRegClass(MRI->getRegClass(MO.getReg()));
87+
};
88+
return make_filter_range(MI.uses(), Pred);
89+
}
6190

6291
/// For a given instruction, records what elements of it are demanded by
6392
/// downstream users.
64-
DenseMap<const MachineInstr *, std::optional<MachineOperand>> DemandedVLs;
93+
DenseMap<const MachineInstr *, DemandedVL> DemandedVLs;
94+
SetVector<const MachineInstr *> Worklist;
6595
};
6696

6797
/// Represents the EMUL and EEW of a MachineOperand.
@@ -821,6 +851,9 @@ getOperandInfo(const MachineOperand &MO, const MachineRegisterInfo *MRI) {
821851
/// white-list approach simplifies this optimization for instructions that may
822852
/// have more complex semantics with relation to how it uses VL.
823853
static bool isSupportedInstr(const MachineInstr &MI) {
854+
if (MI.isPHI() || MI.isFullCopy())
855+
return true;
856+
824857
const RISCVVPseudosTable::PseudoInfo *RVV =
825858
RISCVVPseudosTable::getPseudoInfo(MI.getOpcode());
826859

@@ -1321,21 +1354,24 @@ bool RISCVVLOptimizer::isCandidate(const MachineInstr &MI) const {
13211354
return true;
13221355
}
13231356

1324-
std::optional<MachineOperand>
1357+
DemandedVL
13251358
RISCVVLOptimizer::getMinimumVLForUser(const MachineOperand &UserOp) const {
13261359
const MachineInstr &UserMI = *UserOp.getParent();
13271360
const MCInstrDesc &Desc = UserMI.getDesc();
13281361

1362+
if (UserMI.isPHI() || UserMI.isFullCopy())
1363+
return DemandedVLs.lookup(&UserMI);
1364+
13291365
if (!RISCVII::hasVLOp(Desc.TSFlags) || !RISCVII::hasSEWOp(Desc.TSFlags)) {
13301366
LLVM_DEBUG(dbgs() << " Abort due to lack of VL, assume that"
13311367
" use VLMAX\n");
1332-
return std::nullopt;
1368+
return DemandedVL::vlmax();
13331369
}
13341370

13351371
if (RISCVII::readsPastVL(
13361372
TII->get(RISCV::getRVVMCOpcode(UserMI.getOpcode())).TSFlags)) {
13371373
LLVM_DEBUG(dbgs() << " Abort because used by unsafe instruction\n");
1338-
return std::nullopt;
1374+
return DemandedVL::vlmax();
13391375
}
13401376

13411377
unsigned VLOpNum = RISCVII::getVLOpNum(Desc);
@@ -1349,11 +1385,10 @@ RISCVVLOptimizer::getMinimumVLForUser(const MachineOperand &UserOp) const {
13491385
if (UserOp.isTied()) {
13501386
assert(UserOp.getOperandNo() == UserMI.getNumExplicitDefs() &&
13511387
RISCVII::isFirstDefTiedToFirstUse(UserMI.getDesc()));
1352-
auto DemandedVL = DemandedVLs.lookup(&UserMI);
1353-
if (!DemandedVL || !RISCV::isVLKnownLE(*DemandedVL, VLOp)) {
1388+
if (!RISCV::isVLKnownLE(DemandedVLs.lookup(&UserMI).VL, VLOp)) {
13541389
LLVM_DEBUG(dbgs() << " Abort because user is passthru in "
13551390
"instruction with demanded tail\n");
1356-
return std::nullopt;
1391+
return DemandedVL::vlmax();
13571392
}
13581393
}
13591394

@@ -1366,18 +1401,16 @@ RISCVVLOptimizer::getMinimumVLForUser(const MachineOperand &UserOp) const {
13661401

13671402
// If we know the demanded VL of UserMI, then we can reduce the VL it
13681403
// requires.
1369-
if (auto DemandedVL = DemandedVLs.lookup(&UserMI)) {
1370-
assert(isCandidate(UserMI));
1371-
if (RISCV::isVLKnownLE(*DemandedVL, VLOp))
1372-
return DemandedVL;
1373-
}
1404+
if (RISCV::isVLKnownLE(DemandedVLs.lookup(&UserMI).VL, VLOp))
1405+
return DemandedVLs.lookup(&UserMI);
13741406

13751407
return VLOp;
13761408
}
13771409

1378-
std::optional<MachineOperand>
1379-
RISCVVLOptimizer::checkUsers(const MachineInstr &MI) const {
1380-
std::optional<MachineOperand> CommonVL;
1410+
bool RISCVVLOptimizer::checkUsers(const MachineInstr &MI) const {
1411+
if (MI.isPHI() || MI.isFullCopy())
1412+
return true;
1413+
13811414
SmallSetVector<MachineOperand *, 8> Worklist;
13821415
SmallPtrSet<const MachineInstr *, 4> PHISeen;
13831416
for (auto &UserOp : MRI->use_operands(MI.getOperand(0).getReg()))
@@ -1405,23 +1438,9 @@ RISCVVLOptimizer::checkUsers(const MachineInstr &MI) const {
14051438
continue;
14061439
}
14071440

1408-
auto VLOp = getMinimumVLForUser(UserOp);
1409-
if (!VLOp)
1410-
return std::nullopt;
1411-
1412-
// Use the largest VL among all the users. If we cannot determine this
1413-
// statically, then we cannot optimize the VL.
1414-
if (!CommonVL || RISCV::isVLKnownLE(*CommonVL, *VLOp)) {
1415-
CommonVL = *VLOp;
1416-
LLVM_DEBUG(dbgs() << " User VL is: " << VLOp << "\n");
1417-
} else if (!RISCV::isVLKnownLE(*VLOp, *CommonVL)) {
1418-
LLVM_DEBUG(dbgs() << " Abort because cannot determine a common VL\n");
1419-
return std::nullopt;
1420-
}
1421-
14221441
if (!RISCVII::hasSEWOp(UserMI.getDesc().TSFlags)) {
14231442
LLVM_DEBUG(dbgs() << " Abort due to lack of SEW operand\n");
1424-
return std::nullopt;
1443+
return false;
14251444
}
14261445

14271446
std::optional<OperandInfo> ConsumerInfo = getOperandInfo(UserOp, MRI);
@@ -1431,7 +1450,7 @@ RISCVVLOptimizer::checkUsers(const MachineInstr &MI) const {
14311450
LLVM_DEBUG(dbgs() << " Abort due to unknown operand information.\n");
14321451
LLVM_DEBUG(dbgs() << " ConsumerInfo is: " << ConsumerInfo << "\n");
14331452
LLVM_DEBUG(dbgs() << " ProducerInfo is: " << ProducerInfo << "\n");
1434-
return std::nullopt;
1453+
return false;
14351454
}
14361455

14371456
if (!OperandInfo::areCompatible(*ProducerInfo, *ConsumerInfo)) {
@@ -1440,11 +1459,11 @@ RISCVVLOptimizer::checkUsers(const MachineInstr &MI) const {
14401459
<< " Abort due to incompatible information for EMUL or EEW.\n");
14411460
LLVM_DEBUG(dbgs() << " ConsumerInfo is: " << ConsumerInfo << "\n");
14421461
LLVM_DEBUG(dbgs() << " ProducerInfo is: " << ProducerInfo << "\n");
1443-
return std::nullopt;
1462+
return false;
14441463
}
14451464
}
14461465

1447-
return CommonVL;
1466+
return true;
14481467
}
14491468

14501469
bool RISCVVLOptimizer::tryReduceVL(MachineInstr &MI) const {
@@ -1460,9 +1479,7 @@ bool RISCVVLOptimizer::tryReduceVL(MachineInstr &MI) const {
14601479
return false;
14611480
}
14621481

1463-
auto CommonVL = DemandedVLs.lookup(&MI);
1464-
if (!CommonVL)
1465-
return false;
1482+
auto *CommonVL = &DemandedVLs.at(&MI).VL;
14661483

14671484
assert((CommonVL->isImm() || CommonVL->getReg().isVirtual()) &&
14681485
"Expected VL to be an Imm or virtual Reg");
@@ -1497,6 +1514,24 @@ bool RISCVVLOptimizer::tryReduceVL(MachineInstr &MI) const {
14971514
return true;
14981515
}
14991516

1517+
static bool isPhysical(const MachineOperand &MO) {
1518+
return MO.isReg() && MO.getReg().isPhysical();
1519+
}
1520+
1521+
/// Look through \p MI's operands and propagate what it demands to its uses.
1522+
void RISCVVLOptimizer::transfer(const MachineInstr &MI) {
1523+
if (!isSupportedInstr(MI) || !checkUsers(MI) || any_of(MI.defs(), isPhysical))
1524+
DemandedVLs[&MI] = DemandedVL::vlmax();
1525+
1526+
for (const MachineOperand &MO : vector_uses(MI)) {
1527+
const MachineInstr *Def = MRI->getVRegDef(MO.getReg());
1528+
DemandedVL Prev = DemandedVLs[Def];
1529+
DemandedVLs[Def] = max(DemandedVLs[Def], getMinimumVLForUser(MO));
1530+
if (DemandedVLs[Def] != Prev)
1531+
Worklist.insert(Def);
1532+
}
1533+
}
1534+
15001535
bool RISCVVLOptimizer::runOnMachineFunction(MachineFunction &MF) {
15011536
if (skipFunction(MF.getFunction()))
15021537
return false;
@@ -1513,14 +1548,17 @@ bool RISCVVLOptimizer::runOnMachineFunction(MachineFunction &MF) {
15131548
assert(DemandedVLs.empty());
15141549

15151550
// For each instruction that defines a vector, compute what VL its
1516-
// downstream users demand.
1551+
// upstream uses demand.
15171552
for (MachineBasicBlock *MBB : post_order(&MF)) {
15181553
assert(MDT->isReachableFromEntry(MBB));
1519-
for (MachineInstr &MI : reverse(*MBB)) {
1520-
if (!isCandidate(MI))
1521-
continue;
1522-
DemandedVLs.insert({&MI, checkUsers(MI)});
1523-
}
1554+
for (MachineInstr &MI : reverse(*MBB))
1555+
Worklist.insert(&MI);
1556+
}
1557+
1558+
while (!Worklist.empty()) {
1559+
const MachineInstr *MI = Worklist.front();
1560+
Worklist.remove(MI);
1561+
transfer(*MI);
15241562
}
15251563

15261564
// Then go through and see if we can reduce the VL of any instructions to

llvm/test/CodeGen/RISCV/rvv/reproducer-pr146855.ll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,15 @@ target triple = "riscv64-unknown-linux-gnu"
66
define i32 @_ZN4Mesh12rezone_countESt6vectorIiSaIiEERiS3_(<vscale x 4 x i32> %wide.load, <vscale x 4 x i1> %0, <vscale x 4 x i1> %1, <vscale x 4 x i1> %2, <vscale x 4 x i1> %3) #0 {
77
; CHECK-LABEL: _ZN4Mesh12rezone_countESt6vectorIiSaIiEERiS3_:
88
; CHECK: # %bb.0: # %entry
9-
; CHECK-NEXT: vsetvli a0, zero, e32, m2, ta, ma
9+
; CHECK-NEXT: vsetivli zero, 0, e32, m2, ta, ma
1010
; CHECK-NEXT: vmv1r.v v8, v0
1111
; CHECK-NEXT: li a0, 0
1212
; CHECK-NEXT: vmv.v.i v10, 0
1313
; CHECK-NEXT: vmv.v.i v12, 0
1414
; CHECK-NEXT: vmv.v.i v14, 0
1515
; CHECK-NEXT: .LBB0_1: # %vector.body
1616
; CHECK-NEXT: # =>This Inner Loop Header: Depth=1
17-
; CHECK-NEXT: vsetvli a1, zero, e32, m2, ta, mu
17+
; CHECK-NEXT: vsetivli zero, 0, e32, m2, ta, mu
1818
; CHECK-NEXT: vmv1r.v v0, v8
1919
; CHECK-NEXT: slli a0, a0, 2
2020
; CHECK-NEXT: vmv2r.v v16, v10

llvm/test/CodeGen/RISCV/rvv/vl-opt.ll

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -202,15 +202,14 @@ define void @fadd_fcmp_select_copy(<vscale x 4 x float> %v, <vscale x 4 x i1> %c
202202
define void @recurrence(<vscale x 4 x i32> %v, ptr %p, iXLen %n, iXLen %vl) {
203203
; CHECK-LABEL: recurrence:
204204
; CHECK: # %bb.0: # %entry
205-
; CHECK-NEXT: vsetvli a3, zero, e32, m2, ta, ma
205+
; CHECK-NEXT: vsetvli zero, a2, e32, m2, ta, ma
206206
; CHECK-NEXT: vmv.v.i v10, 0
207207
; CHECK-NEXT: .LBB13_1: # %loop
208208
; CHECK-NEXT: # =>This Inner Loop Header: Depth=1
209209
; CHECK-NEXT: addi a1, a1, -1
210210
; CHECK-NEXT: vadd.vv v10, v10, v8
211211
; CHECK-NEXT: bnez a1, .LBB13_1
212212
; CHECK-NEXT: # %bb.2: # %exit
213-
; CHECK-NEXT: vsetvli zero, a2, e32, m2, ta, ma
214213
; CHECK-NEXT: vse32.v v10, (a0)
215214
; CHECK-NEXT: ret
216215
entry:

llvm/test/CodeGen/RISCV/rvv/vl-opt.mir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -613,14 +613,14 @@ body: |
613613
; CHECK-NEXT: liveins: $x8
614614
; CHECK-NEXT: {{ $}}
615615
; CHECK-NEXT: %avl:gprnox0 = COPY $x8
616-
; CHECK-NEXT: %start:vr = PseudoVMV_V_I_M1 $noreg, 0, -1, 3 /* e8 */, 3 /* ta, ma */
616+
; CHECK-NEXT: %start:vr = PseudoVMV_V_I_M1 $noreg, 0, %avl, 3 /* e8 */, 3 /* ta, ma */
617617
; CHECK-NEXT: PseudoBR %bb.1
618618
; CHECK-NEXT: {{ $}}
619619
; CHECK-NEXT: bb.1:
620620
; CHECK-NEXT: successors: %bb.1(0x40000000), %bb.2(0x40000000)
621621
; CHECK-NEXT: {{ $}}
622622
; CHECK-NEXT: %phi:vr = PHI %start, %bb.0, %inc, %bb.1
623-
; CHECK-NEXT: %inc:vr = PseudoVADD_VI_M1 $noreg, %phi, 1, -1, 3 /* e8 */, 3 /* ta, ma */
623+
; CHECK-NEXT: %inc:vr = PseudoVADD_VI_M1 $noreg, %phi, 1, %avl, 3 /* e8 */, 3 /* ta, ma */
624624
; CHECK-NEXT: BNE $noreg, $noreg, %bb.1
625625
; CHECK-NEXT: {{ $}}
626626
; CHECK-NEXT: bb.2:

llvm/test/CodeGen/RISCV/rvv/vlopt-same-vl.ll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,15 @@
1111
; which was responsible for speeding it up.
1212

1313
define <vscale x 4 x i32> @same_vl_imm(<vscale x 4 x i32> %passthru, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b) {
14-
; CHECK: User VL is: 4
14+
; CHECK: Trying to reduce VL for %{{.+}}:vrm2 = PseudoVADD_VV_M2
1515
; CHECK: Abort due to CommonVL == VLOp, no point in reducing.
1616
%v = call <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32> poison, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, i64 4)
1717
%w = call <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32> poison, <vscale x 4 x i32> %v, <vscale x 4 x i32> %a, i64 4)
1818
ret <vscale x 4 x i32> %w
1919
}
2020

2121
define <vscale x 4 x i32> @same_vl_reg(<vscale x 4 x i32> %passthru, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, i64 %vl) {
22-
; CHECK: User VL is: %3:gprnox0
22+
; CHECK: Trying to reduce VL for %{{.+}}:vrm2 = PseudoVADD_VV_M2
2323
; CHECK: Abort due to CommonVL == VLOp, no point in reducing.
2424
%v = call <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32> poison, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, i64 %vl)
2525
%w = call <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32> poison, <vscale x 4 x i32> %v, <vscale x 4 x i32> %a, i64 %vl)

0 commit comments

Comments
 (0)