@@ -33,6 +33,7 @@ namespace {
3333class RISCVVLOptimizer : public MachineFunctionPass {
3434 const MachineRegisterInfo *MRI;
3535 const MachineDominatorTree *MDT;
36+ const TargetInstrInfo *TII;
3637
3738public:
3839 static char ID;
@@ -50,12 +51,15 @@ class RISCVVLOptimizer : public MachineFunctionPass {
5051 StringRef getPassName () const override { return PASS_NAME; }
5152
5253private:
53- std::optional<MachineOperand> getMinimumVLForUser (MachineOperand &UserOp);
54- /// Returns the largest common VL MachineOperand that may be used to optimize
55- /// MI. Returns std::nullopt if it failed to find a suitable VL.
56- std::optional<MachineOperand> checkUsers (MachineInstr &MI);
54+ MachineOperand getMinimumVLForUser (MachineOperand &UserOp);
55+ /// Computes the VL of \p MI that is actually used by its users.
56+ MachineOperand computeDemandedVL (const MachineInstr &MI);
5757 bool tryReduceVL (MachineInstr &MI);
5858 bool isCandidate (const MachineInstr &MI) const ;
59+
60+ /// For a given instruction, records what elements of it are demanded by
61+ /// downstream users.
62+ DenseMap<const MachineInstr *, MachineOperand> DemandedVLs;
5963};
6064
6165} // end anonymous namespace
@@ -1173,15 +1177,14 @@ bool RISCVVLOptimizer::isCandidate(const MachineInstr &MI) const {
11731177 return true ;
11741178}
11751179
1176- std::optional<MachineOperand>
1177- RISCVVLOptimizer::getMinimumVLForUser (MachineOperand &UserOp) {
1180+ MachineOperand RISCVVLOptimizer::getMinimumVLForUser (MachineOperand &UserOp) {
11781181 const MachineInstr &UserMI = *UserOp.getParent ();
11791182 const MCInstrDesc &Desc = UserMI.getDesc ();
11801183
11811184 if (!RISCVII::hasVLOp (Desc.TSFlags ) || !RISCVII::hasSEWOp (Desc.TSFlags )) {
11821185 LLVM_DEBUG (dbgs () << " Abort due to lack of VL, assume that"
11831186 " use VLMAX\n " );
1184- return std:: nullopt ;
1187+ return MachineOperand::CreateImm (RISCV::VLMaxSentinel) ;
11851188 }
11861189
11871190 // Instructions like reductions may use a vector register as a scalar
@@ -1201,46 +1204,59 @@ RISCVVLOptimizer::getMinimumVLForUser(MachineOperand &UserOp) {
12011204 // Looking for an immediate or a register VL that isn't X0.
12021205 assert ((!VLOp.isReg () || VLOp.getReg () != RISCV::X0) &&
12031206 " Did not expect X0 VL" );
1207+
1208+ // If we know the demanded VL of UserMI, then we can reduce the VL it
1209+ // requires.
1210+ if (DemandedVLs.contains (&UserMI)) {
1211+ // We can only shrink the demanded VL if the elementwise result doesn't
1212+ // depend on VL (i.e. not vredsum/viota etc.)
1213+ // Also conservatively restrict to supported instructions for now.
1214+ // TODO: Can we remove the isSupportedInstr check?
1215+ if (!RISCVII::elementsDependOnVL (
1216+ TII->get (RISCV::getRVVMCOpcode (UserMI.getOpcode ())).TSFlags ) &&
1217+ isSupportedInstr (UserMI)) {
1218+ const MachineOperand &DemandedVL = DemandedVLs.at (&UserMI);
1219+ if (RISCV::isVLKnownLE (DemandedVL, VLOp))
1220+ return DemandedVL;
1221+ }
1222+ }
1223+
12041224 return VLOp;
12051225}
12061226
1207- std::optional<MachineOperand> RISCVVLOptimizer::checkUsers (MachineInstr &MI) {
1208- // FIXME: Avoid visiting each user for each time we visit something on the
1209- // worklist, combined with an extra visit from the outer loop. Restructure
1210- // along lines of an instcombine style worklist which integrates the outer
1211- // pass.
1212- std::optional<MachineOperand> CommonVL;
1227+ MachineOperand RISCVVLOptimizer::computeDemandedVL (const MachineInstr &MI) {
1228+ const MachineOperand &VLMAX = MachineOperand::CreateImm (RISCV::VLMaxSentinel);
1229+ MachineOperand DemandedVL = MachineOperand::CreateImm (0 );
1230+
12131231 for (auto &UserOp : MRI->use_operands (MI.getOperand (0 ).getReg ())) {
12141232 const MachineInstr &UserMI = *UserOp.getParent ();
12151233 LLVM_DEBUG (dbgs () << " Checking user: " << UserMI << " \n " );
12161234 if (mayReadPastVL (UserMI)) {
12171235 LLVM_DEBUG (dbgs () << " Abort because used by unsafe instruction\n " );
1218- return std:: nullopt ;
1236+ return VLMAX ;
12191237 }
12201238
12211239 // If used as a passthru, elements past VL will be read.
12221240 if (UserOp.isTied ()) {
12231241 LLVM_DEBUG (dbgs () << " Abort because user used as tied operand\n " );
1224- return std:: nullopt ;
1242+ return VLMAX ;
12251243 }
12261244
1227- auto VLOp = getMinimumVLForUser (UserOp);
1228- if (!VLOp)
1229- return std::nullopt ;
1245+ const MachineOperand &VLOp = getMinimumVLForUser (UserOp);
12301246
12311247 // Use the largest VL among all the users. If we cannot determine this
12321248 // statically, then we cannot optimize the VL.
1233- if (!CommonVL || RISCV::isVLKnownLE (*CommonVL, * VLOp)) {
1234- CommonVL = * VLOp;
1235- LLVM_DEBUG (dbgs () << " User VL is: " << VLOp << " \n " );
1236- } else if (!RISCV::isVLKnownLE (* VLOp, *CommonVL )) {
1249+ if (RISCV::isVLKnownLE (DemandedVL, VLOp)) {
1250+ DemandedVL = VLOp;
1251+ LLVM_DEBUG (dbgs () << " Demanded VL is: " << VLOp << " \n " );
1252+ } else if (!RISCV::isVLKnownLE (VLOp, DemandedVL )) {
12371253 LLVM_DEBUG (dbgs () << " Abort because cannot determine a common VL\n " );
1238- return std:: nullopt ;
1254+ return VLMAX ;
12391255 }
12401256
12411257 if (!RISCVII::hasSEWOp (UserMI.getDesc ().TSFlags )) {
12421258 LLVM_DEBUG (dbgs () << " Abort due to lack of SEW operand\n " );
1243- return std:: nullopt ;
1259+ return VLMAX ;
12441260 }
12451261
12461262 std::optional<OperandInfo> ConsumerInfo = getOperandInfo (UserOp, MRI);
@@ -1250,7 +1266,7 @@ std::optional<MachineOperand> RISCVVLOptimizer::checkUsers(MachineInstr &MI) {
12501266 LLVM_DEBUG (dbgs () << " Abort due to unknown operand information.\n " );
12511267 LLVM_DEBUG (dbgs () << " ConsumerInfo is: " << ConsumerInfo << " \n " );
12521268 LLVM_DEBUG (dbgs () << " ProducerInfo is: " << ProducerInfo << " \n " );
1253- return std:: nullopt ;
1269+ return VLMAX ;
12541270 }
12551271
12561272 // If the operand is used as a scalar operand, then the EEW must be
@@ -1265,11 +1281,11 @@ std::optional<MachineOperand> RISCVVLOptimizer::checkUsers(MachineInstr &MI) {
12651281 << " Abort due to incompatible information for EMUL or EEW.\n " );
12661282 LLVM_DEBUG (dbgs () << " ConsumerInfo is: " << ConsumerInfo << " \n " );
12671283 LLVM_DEBUG (dbgs () << " ProducerInfo is: " << ProducerInfo << " \n " );
1268- return std:: nullopt ;
1284+ return VLMAX ;
12691285 }
12701286 }
12711287
1272- return CommonVL ;
1288+ return DemandedVL ;
12731289}
12741290
12751291bool RISCVVLOptimizer::tryReduceVL (MachineInstr &MI) {
@@ -1285,40 +1301,40 @@ bool RISCVVLOptimizer::tryReduceVL(MachineInstr &MI) {
12851301 return false ;
12861302 }
12871303
1288- auto CommonVL = checkUsers (MI) ;
1304+ auto CommonVL = DemandedVLs[&MI] ;
12891305 if (!CommonVL)
12901306 return false ;
12911307
1292- assert ((CommonVL-> isImm () || CommonVL-> getReg ().isVirtual ()) &&
1308+ assert ((CommonVL. isImm () || CommonVL. getReg ().isVirtual ()) &&
12931309 " Expected VL to be an Imm or virtual Reg" );
12941310
12951311 if (!RISCV::isVLKnownLE (*CommonVL, VLOp)) {
12961312 LLVM_DEBUG (dbgs () << " Abort due to CommonVL not <= VLOp.\n " );
12971313 return false ;
12981314 }
12991315
1300- if (CommonVL-> isIdenticalTo (VLOp)) {
1316+ if (CommonVL. isIdenticalTo (VLOp)) {
13011317 LLVM_DEBUG (
1302- dbgs () << " Abort due to CommonVL == VLOp, no point in reducing.\n " );
1318+ dbgs ()
1319+ << " Abort due to DemandedVL == VLOp, no point in reducing.\n " );
13031320 return false ;
13041321 }
13051322
1306- if (CommonVL-> isImm ()) {
1323+ if (CommonVL. isImm ()) {
13071324 LLVM_DEBUG (dbgs () << " Reduce VL from " << VLOp << " to "
1308- << CommonVL-> getImm () << " for " << MI << " \n " );
1309- VLOp.ChangeToImmediate (CommonVL-> getImm ());
1325+ << CommonVL. getImm () << " for " << MI << " \n " );
1326+ VLOp.ChangeToImmediate (CommonVL. getImm ());
13101327 return true ;
13111328 }
1312- const MachineInstr *VLMI = MRI->getVRegDef (CommonVL-> getReg ());
1329+ const MachineInstr *VLMI = MRI->getVRegDef (CommonVL. getReg ());
13131330 if (!MDT->dominates (VLMI, &MI))
13141331 return false ;
1315- LLVM_DEBUG (
1316- dbgs () << " Reduce VL from " << VLOp << " to "
1317- << printReg (CommonVL->getReg (), MRI->getTargetRegisterInfo ())
1318- << " for " << MI << " \n " );
1332+ LLVM_DEBUG (dbgs () << " Reduce VL from " << VLOp << " to "
1333+ << printReg (CommonVL.getReg (), MRI->getTargetRegisterInfo ())
1334+ << " for " << MI << " \n " );
13191335
13201336 // All our checks passed. We can reduce VL.
1321- VLOp.ChangeToRegister (CommonVL-> getReg (), false );
1337+ VLOp.ChangeToRegister (CommonVL. getReg (), false );
13221338 return true ;
13231339}
13241340
@@ -1333,52 +1349,33 @@ bool RISCVVLOptimizer::runOnMachineFunction(MachineFunction &MF) {
13331349 if (!ST.hasVInstructions ())
13341350 return false ;
13351351
1336- SetVector<MachineInstr *> Worklist;
1337- auto PushOperands = [this , &Worklist](MachineInstr &MI,
1338- bool IgnoreSameBlock) {
1339- for (auto &Op : MI.operands ()) {
1340- if (!Op.isReg () || !Op.isUse () || !Op.getReg ().isVirtual () ||
1341- !isVectorRegClass (Op.getReg (), MRI))
1342- continue ;
1352+ TII = ST.getInstrInfo ();
13431353
1344- MachineInstr *DefMI = MRI->getVRegDef (Op.getReg ());
1345- if (!isCandidate (*DefMI))
1346- continue ;
1347-
1348- if (IgnoreSameBlock && DefMI->getParent () == MI.getParent ())
1349- continue ;
1350-
1351- Worklist.insert (DefMI);
1352- }
1353- };
1354-
1355- // Do a first pass eagerly rewriting in roughly reverse instruction
1356- // order, populate the worklist with any instructions we might need to
1357- // revisit. We avoid adding definitions to the worklist if they're
1358- // in the same block - we're about to visit them anyways.
13591354 bool MadeChange = false ;
13601355 for (MachineBasicBlock &MBB : MF) {
13611356 // Avoid unreachable blocks as they have degenerate dominance
13621357 if (!MDT->isReachableFromEntry (&MBB))
13631358 continue ;
13641359
1365- for (auto &MI : reverse (MBB)) {
1360+ // For each instruction that defines a vector, compute what VL its
1361+ // downstream users demand.
1362+ for (const auto &MI : reverse (MBB)) {
1363+ if (!isCandidate (MI))
1364+ continue ;
1365+ DemandedVLs.insert ({&MI, computeDemandedVL (MI)});
1366+ }
1367+
1368+ // Then go through and see if we can reduce the VL of any instructions to
1369+ // only what's demanded.
1370+ for (auto &MI : MBB) {
13661371 if (!isCandidate (MI))
13671372 continue ;
13681373 if (!tryReduceVL (MI))
13691374 continue ;
13701375 MadeChange = true ;
1371- PushOperands (MI, /* IgnoreSameBlock*/ true );
13721376 }
1373- }
13741377
1375- while (!Worklist.empty ()) {
1376- assert (MadeChange);
1377- MachineInstr &MI = *Worklist.pop_back_val ();
1378- assert (isCandidate (MI));
1379- if (!tryReduceVL (MI))
1380- continue ;
1381- PushOperands (MI, /* IgnoreSameBlock*/ false );
1378+ DemandedVLs.clear ();
13821379 }
13831380
13841381 return MadeChange;
0 commit comments