@@ -33,6 +33,7 @@ namespace {
3333class RISCVVLOptimizer : public MachineFunctionPass {
3434 const MachineRegisterInfo *MRI;
3535 const MachineDominatorTree *MDT;
36+ const TargetInstrInfo *TII;
3637
3738public:
3839 static char ID;
@@ -50,12 +51,15 @@ class RISCVVLOptimizer : public MachineFunctionPass {
5051 StringRef getPassName () const override { return PASS_NAME; }
5152
5253private:
53- std::optional<MachineOperand> getMinimumVLForUser (MachineOperand &UserOp);
54- // / Returns the largest common VL MachineOperand that may be used to optimize
55- // / MI. Returns std::nullopt if it failed to find a suitable VL.
56- std::optional<MachineOperand> checkUsers (MachineInstr &MI);
54+ MachineOperand getMinimumVLForUser (MachineOperand &UserOp);
55+ // / Computes the VL of \p MI that is actually used by its users.
56+ MachineOperand computeDemandedVL (const MachineInstr &MI);
5757 bool tryReduceVL (MachineInstr &MI);
5858 bool isCandidate (const MachineInstr &MI) const ;
59+
60+ // / For a given instruction, records what elements of it are demanded by
61+ // / downstream users.
62+ DenseMap<const MachineInstr *, MachineOperand> DemandedVLs;
5963};
6064
6165} // end anonymous namespace
@@ -1173,15 +1177,14 @@ bool RISCVVLOptimizer::isCandidate(const MachineInstr &MI) const {
11731177 return true ;
11741178}
11751179
1176- std::optional<MachineOperand>
1177- RISCVVLOptimizer::getMinimumVLForUser (MachineOperand &UserOp) {
1180+ MachineOperand RISCVVLOptimizer::getMinimumVLForUser (MachineOperand &UserOp) {
11781181 const MachineInstr &UserMI = *UserOp.getParent ();
11791182 const MCInstrDesc &Desc = UserMI.getDesc ();
11801183
11811184 if (!RISCVII::hasVLOp (Desc.TSFlags ) || !RISCVII::hasSEWOp (Desc.TSFlags )) {
11821185 LLVM_DEBUG (dbgs () << " Abort due to lack of VL, assume that"
11831186 " use VLMAX\n " );
1184- return std:: nullopt ;
1187+ return MachineOperand::CreateImm (RISCV::VLMaxSentinel) ;
11851188 }
11861189
11871190 // Instructions like reductions may use a vector register as a scalar
@@ -1201,46 +1204,59 @@ RISCVVLOptimizer::getMinimumVLForUser(MachineOperand &UserOp) {
12011204 // Looking for an immediate or a register VL that isn't X0.
12021205 assert ((!VLOp.isReg () || VLOp.getReg () != RISCV::X0) &&
12031206 " Did not expect X0 VL" );
1207+
1208+ // If we know the demanded VL of UserMI, then we can reduce the VL it
1209+ // requires.
1210+ if (DemandedVLs.contains (&UserMI)) {
1211+ // We can only shrink the demanded VL if the elementwise result doesn't
1212+ // depend on VL (i.e. not vredsum/viota etc.)
1213+ // Also conservatively restrict to supported instructions for now.
1214+ // TODO: Can we remove the isSupportedInstr check?
1215+ if (!RISCVII::elementsDependOnVL (
1216+ TII->get (RISCV::getRVVMCOpcode (UserMI.getOpcode ())).TSFlags ) &&
1217+ isSupportedInstr (UserMI)) {
1218+ const MachineOperand &DemandedVL = DemandedVLs.at (&UserMI);
1219+ if (RISCV::isVLKnownLE (DemandedVL, VLOp))
1220+ return DemandedVL;
1221+ }
1222+ }
1223+
12041224 return VLOp;
12051225}
12061226
1207- std::optional<MachineOperand> RISCVVLOptimizer::checkUsers (MachineInstr &MI) {
1208- // FIXME: Avoid visiting each user for each time we visit something on the
1209- // worklist, combined with an extra visit from the outer loop. Restructure
1210- // along lines of an instcombine style worklist which integrates the outer
1211- // pass.
1212- std::optional<MachineOperand> CommonVL;
1227+ MachineOperand RISCVVLOptimizer::computeDemandedVL (const MachineInstr &MI) {
1228+ const MachineOperand &VLMAX = MachineOperand::CreateImm (RISCV::VLMaxSentinel);
1229+ MachineOperand DemandedVL = MachineOperand::CreateImm (0 );
1230+
12131231 for (auto &UserOp : MRI->use_operands (MI.getOperand (0 ).getReg ())) {
12141232 const MachineInstr &UserMI = *UserOp.getParent ();
12151233 LLVM_DEBUG (dbgs () << " Checking user: " << UserMI << " \n " );
12161234 if (mayReadPastVL (UserMI)) {
12171235 LLVM_DEBUG (dbgs () << " Abort because used by unsafe instruction\n " );
1218- return std:: nullopt ;
1236+ return VLMAX ;
12191237 }
12201238
12211239 // If used as a passthru, elements past VL will be read.
12221240 if (UserOp.isTied ()) {
12231241 LLVM_DEBUG (dbgs () << " Abort because user used as tied operand\n " );
1224- return std:: nullopt ;
1242+ return VLMAX ;
12251243 }
12261244
1227- auto VLOp = getMinimumVLForUser (UserOp);
1228- if (!VLOp)
1229- return std::nullopt ;
1245+ const MachineOperand &VLOp = getMinimumVLForUser (UserOp);
12301246
12311247 // Use the largest VL among all the users. If we cannot determine this
12321248 // statically, then we cannot optimize the VL.
1233- if (!CommonVL || RISCV::isVLKnownLE (*CommonVL, * VLOp)) {
1234- CommonVL = * VLOp;
1235- LLVM_DEBUG (dbgs () << " User VL is: " << VLOp << " \n " );
1236- } else if (!RISCV::isVLKnownLE (* VLOp, *CommonVL )) {
1249+ if (RISCV::isVLKnownLE (DemandedVL, VLOp)) {
1250+ DemandedVL = VLOp;
1251+ LLVM_DEBUG (dbgs () << " Demanded VL is: " << VLOp << " \n " );
1252+ } else if (!RISCV::isVLKnownLE (VLOp, DemandedVL )) {
12371253 LLVM_DEBUG (dbgs () << " Abort because cannot determine a common VL\n " );
1238- return std:: nullopt ;
1254+ return VLMAX ;
12391255 }
12401256
12411257 if (!RISCVII::hasSEWOp (UserMI.getDesc ().TSFlags )) {
12421258 LLVM_DEBUG (dbgs () << " Abort due to lack of SEW operand\n " );
1243- return std:: nullopt ;
1259+ return VLMAX ;
12441260 }
12451261
12461262 std::optional<OperandInfo> ConsumerInfo = getOperandInfo (UserOp, MRI);
@@ -1250,7 +1266,7 @@ std::optional<MachineOperand> RISCVVLOptimizer::checkUsers(MachineInstr &MI) {
12501266 LLVM_DEBUG (dbgs () << " Abort due to unknown operand information.\n " );
12511267 LLVM_DEBUG (dbgs () << " ConsumerInfo is: " << ConsumerInfo << " \n " );
12521268 LLVM_DEBUG (dbgs () << " ProducerInfo is: " << ProducerInfo << " \n " );
1253- return std:: nullopt ;
1269+ return VLMAX ;
12541270 }
12551271
12561272 // If the operand is used as a scalar operand, then the EEW must be
@@ -1265,53 +1281,51 @@ std::optional<MachineOperand> RISCVVLOptimizer::checkUsers(MachineInstr &MI) {
12651281 << " Abort due to incompatible information for EMUL or EEW.\n " );
12661282 LLVM_DEBUG (dbgs () << " ConsumerInfo is: " << ConsumerInfo << " \n " );
12671283 LLVM_DEBUG (dbgs () << " ProducerInfo is: " << ProducerInfo << " \n " );
1268- return std:: nullopt ;
1284+ return VLMAX ;
12691285 }
12701286 }
12711287
1272- return CommonVL ;
1288+ return DemandedVL ;
12731289}
12741290
// tryReduceVL: tries to shrink the VL operand of MI to the demanded VL that
// was recorded for MI in DemandedVLs. Returns true only when MI's VL operand
// was actually rewritten (to an immediate or to a register); every other path
// returns false and leaves MI untouched.
// NOTE(review): this span is a rendered diff. Lines marked '-' are the
// pre-change implementation (checkUsers() returning std::optional); lines
// marked '+' are the post-change implementation that reads DemandedVLs.
12751291bool RISCVVLOptimizer::tryReduceVL (MachineInstr &MI) {
12761292 LLVM_DEBUG (dbgs () << " Trying to reduce VL for " << MI << " \n " );
12771293
1278- auto CommonVL = checkUsers (MI);
1279- if (!CommonVL)
1280- return false ;
// Post-change: look up the precomputed demanded VL for MI. The lookup uses
// DemandedVLs.at(), so an entry for MI has to exist already;
// runOnMachineFunction inserts one for every candidate in the block before it
// calls this function.
1294+ const MachineOperand &CommonVL = DemandedVLs.at (&MI);
12811295
1282- assert ((CommonVL-> isImm () || CommonVL-> getReg ().isVirtual ()) &&
1296+ assert ((CommonVL. isImm () || CommonVL. getReg ().isVirtual ()) &&
12831297 " Expected VL to be an Imm or virtual Reg" );
12841298
// Fetch MI's own VL operand; this is the operand that gets rewritten below.
12851299 unsigned VLOpNum = RISCVII::getVLOpNum (MI.getDesc ());
12861300 MachineOperand &VLOp = MI.getOperand (VLOpNum);
12871301
// Bail out unless the demanded VL is known to be <= MI's current VL.
1288- if (!RISCV::isVLKnownLE (* CommonVL, VLOp)) {
1289- LLVM_DEBUG (dbgs () << " Abort due to CommonVL not <= VLOp.\n " );
1302+ if (!RISCV::isVLKnownLE (CommonVL, VLOp)) {
1303+ LLVM_DEBUG (dbgs () << " Abort due to DemandedVL not <= VLOp.\n " );
12901304 return false ;
12911305 }
12921306
// Nothing to do when the demanded VL is already identical to MI's VL operand.
1293- if (CommonVL-> isIdenticalTo (VLOp)) {
1307+ if (CommonVL. isIdenticalTo (VLOp)) {
12941308 LLVM_DEBUG (
1295- dbgs () << " Abort due to CommonVL == VLOp, no point in reducing.\n " );
1309+ dbgs ()
1310+ << " Abort due to DemandedVL == VLOp, no point in reducing.\n " );
12961311 return false ;
12971312 }
12981313
// Immediate case: rewrite MI's VL operand in place to the demanded immediate.
1299- if (CommonVL-> isImm ()) {
1314+ if (CommonVL. isImm ()) {
13001315 LLVM_DEBUG (dbgs () << " Reduce VL from " << VLOp << " to "
1301- << CommonVL-> getImm () << " for " << MI << " \n " );
1316+ << CommonVL. getImm () << " for " << MI << " \n " );
1302- VLOp.ChangeToImmediate (CommonVL-> getImm ());
1317+ VLOp.ChangeToImmediate (CommonVL. getImm ());
13031318 return true ;
13041319 }
// Register case: give up unless the instruction that defines the demanded VL
// register dominates MI.
1305- const MachineInstr *VLMI = MRI->getVRegDef (CommonVL-> getReg ());
1320+ const MachineInstr *VLMI = MRI->getVRegDef (CommonVL. getReg ());
13061321 if (!MDT->dominates (VLMI, &MI))
13071322 return false ;
1308- LLVM_DEBUG (
1309- dbgs () << " Reduce VL from " << VLOp << " to "
1310- << printReg (CommonVL->getReg (), MRI->getTargetRegisterInfo ())
1311- << " for " << MI << " \n " );
1323+ LLVM_DEBUG (dbgs () << " Reduce VL from " << VLOp << " to "
1324+ << printReg (CommonVL.getReg (), MRI->getTargetRegisterInfo ())
1325+ << " for " << MI << " \n " );
13121326
13131327 // All our checks passed. We can reduce VL.
1314- VLOp.ChangeToRegister (CommonVL-> getReg (), false );
1328+ VLOp.ChangeToRegister (CommonVL. getReg (), false );
13151329 return true ;
13161330}
13171331
@@ -1326,52 +1340,33 @@ bool RISCVVLOptimizer::runOnMachineFunction(MachineFunction &MF) {
13261340 if (!ST.hasVInstructions ())
13271341 return false ;
13281342
1329- SetVector<MachineInstr *> Worklist;
1330- auto PushOperands = [this , &Worklist](MachineInstr &MI,
1331- bool IgnoreSameBlock) {
1332- for (auto &Op : MI.operands ()) {
1333- if (!Op.isReg () || !Op.isUse () || !Op.getReg ().isVirtual () ||
1334- !isVectorRegClass (Op.getReg (), MRI))
1335- continue ;
1336-
1337- MachineInstr *DefMI = MRI->getVRegDef (Op.getReg ());
1338- if (!isCandidate (*DefMI))
1339- continue ;
1340-
1341- if (IgnoreSameBlock && DefMI->getParent () == MI.getParent ())
1342- continue ;
1343-
1344- Worklist.insert (DefMI);
1345- }
1346- };
1343+ TII = ST.getInstrInfo ();
13471344
1348- // Do a first pass eagerly rewriting in roughly reverse instruction
1349- // order, populate the worklist with any instructions we might need to
1350- // revisit. We avoid adding definitions to the worklist if they're
1351- // in the same block - we're about to visit them anyways.
13521345 bool MadeChange = false ;
13531346 for (MachineBasicBlock &MBB : MF) {
13541347 // Avoid unreachable blocks as they have degenerate dominance
13551348 if (!MDT->isReachableFromEntry (&MBB))
13561349 continue ;
13571350
1358- for (auto &MI : reverse (MBB)) {
1351+ // For each instruction that defines a vector, compute what VL its
1352+ // downstream users demand.
1353+ for (const auto &MI : reverse (MBB)) {
1354+ if (!isCandidate (MI))
1355+ continue ;
1356+ DemandedVLs.insert ({&MI, computeDemandedVL (MI)});
1357+ }
1358+
1359+ // Then go through and see if we can reduce the VL of any instructions to
1360+ // only what's demanded.
1361+ for (auto &MI : MBB) {
13591362 if (!isCandidate (MI))
13601363 continue ;
13611364 if (!tryReduceVL (MI))
13621365 continue ;
13631366 MadeChange = true ;
1364- PushOperands (MI, /* IgnoreSameBlock*/ true );
13651367 }
1366- }
13671368
1368- while (!Worklist.empty ()) {
1369- assert (MadeChange);
1370- MachineInstr &MI = *Worklist.pop_back_val ();
1371- assert (isCandidate (MI));
1372- if (!tryReduceVL (MI))
1373- continue ;
1374- PushOperands (MI, /* IgnoreSameBlock*/ false );
1369+ DemandedVLs.clear ();
13751370 }
13761371
13771372 return MadeChange;
0 commit comments