@@ -33,6 +33,7 @@ namespace {
3333class RISCVVLOptimizer : public MachineFunctionPass {
3434 const MachineRegisterInfo *MRI;
3535 const MachineDominatorTree *MDT;
36+ const TargetInstrInfo *TII;
3637
3738public:
3839 static char ID;
@@ -50,12 +51,15 @@ class RISCVVLOptimizer : public MachineFunctionPass {
5051 StringRef getPassName () const override { return PASS_NAME; }
5152
5253private:
53- std::optional<MachineOperand> getMinimumVLForUser (MachineOperand &UserOp);
54- // / Returns the largest common VL MachineOperand that may be used to optimize
55- // / MI. Returns std::nullopt if it failed to find a suitable VL.
56- std::optional<MachineOperand> checkUsers (MachineInstr &MI);
54+ MachineOperand getMinimumVLForUser (MachineOperand &UserOp);
55+ // / Computes the VL of \p MI that is actually used by its users.
56+ MachineOperand computeDemandedVL (const MachineInstr &MI);
5757 bool tryReduceVL (MachineInstr &MI);
5858 bool isCandidate (const MachineInstr &MI) const ;
59+
60+ // / For a given instruction, records what elements of it are demanded by
61+ // / downstream users.
62+ DenseMap<const MachineInstr *, MachineOperand> DemandedVLs;
5963};
6064
6165} // end anonymous namespace
@@ -1202,15 +1206,14 @@ bool RISCVVLOptimizer::isCandidate(const MachineInstr &MI) const {
12021206 return true ;
12031207}
12041208
1205- std::optional<MachineOperand>
1206- RISCVVLOptimizer::getMinimumVLForUser (MachineOperand &UserOp) {
1209+ MachineOperand RISCVVLOptimizer::getMinimumVLForUser (MachineOperand &UserOp) {
12071210 const MachineInstr &UserMI = *UserOp.getParent ();
12081211 const MCInstrDesc &Desc = UserMI.getDesc ();
12091212
12101213 if (!RISCVII::hasVLOp (Desc.TSFlags ) || !RISCVII::hasSEWOp (Desc.TSFlags )) {
12111214 LLVM_DEBUG (dbgs () << " Abort due to lack of VL, assume that"
12121215 " use VLMAX\n " );
1213- return std:: nullopt ;
1216+ return MachineOperand::CreateImm (RISCV::VLMaxSentinel) ;
12141217 }
12151218
12161219 // Instructions like reductions may use a vector register as a scalar
@@ -1230,46 +1233,59 @@ RISCVVLOptimizer::getMinimumVLForUser(MachineOperand &UserOp) {
12301233 // Looking for an immediate or a register VL that isn't X0.
12311234 assert ((!VLOp.isReg () || VLOp.getReg () != RISCV::X0) &&
12321235 " Did not expect X0 VL" );
1236+
1237+ // If we know the demanded VL of UserMI, then we can reduce the VL it
1238+ // requires.
1239+ if (DemandedVLs.contains (&UserMI)) {
1240+ // We can only shrink the demanded VL if the elementwise result doesn't
1241+ // depend on VL (i.e. not vredsum/viota etc.)
1242+ // Also conservatively restrict to supported instructions for now.
1243+ // TODO: Can we remove the isSupportedInstr check?
1244+ if (!RISCVII::elementsDependOnVL (
1245+ TII->get (RISCV::getRVVMCOpcode (UserMI.getOpcode ())).TSFlags ) &&
1246+ isSupportedInstr (UserMI)) {
1247+ const MachineOperand &DemandedVL = DemandedVLs.at (&UserMI);
1248+ if (RISCV::isVLKnownLE (DemandedVL, VLOp))
1249+ return DemandedVL;
1250+ }
1251+ }
1252+
12331253 return VLOp;
12341254}
12351255
1236- std::optional<MachineOperand> RISCVVLOptimizer::checkUsers (MachineInstr &MI) {
1237- // FIXME: Avoid visiting each user for each time we visit something on the
1238- // worklist, combined with an extra visit from the outer loop. Restructure
1239- // along lines of an instcombine style worklist which integrates the outer
1240- // pass.
1241- std::optional<MachineOperand> CommonVL;
1256+ MachineOperand RISCVVLOptimizer::computeDemandedVL (const MachineInstr &MI) {
1257+ const MachineOperand &VLMAX = MachineOperand::CreateImm (RISCV::VLMaxSentinel);
1258+ MachineOperand DemandedVL = MachineOperand::CreateImm (0 );
1259+
12421260 for (auto &UserOp : MRI->use_operands (MI.getOperand (0 ).getReg ())) {
12431261 const MachineInstr &UserMI = *UserOp.getParent ();
12441262 LLVM_DEBUG (dbgs () << " Checking user: " << UserMI << " \n " );
12451263 if (mayReadPastVL (UserMI)) {
12461264 LLVM_DEBUG (dbgs () << " Abort because used by unsafe instruction\n " );
1247- return std:: nullopt ;
1265+ return VLMAX ;
12481266 }
12491267
12501268 // Tied operands might pass through.
12511269 if (UserOp.isTied ()) {
12521270 LLVM_DEBUG (dbgs () << " Abort because user used as tied operand\n " );
1253- return std:: nullopt ;
1271+ return VLMAX ;
12541272 }
12551273
1256- auto VLOp = getMinimumVLForUser (UserOp);
1257- if (!VLOp)
1258- return std::nullopt ;
1274+ const MachineOperand &VLOp = getMinimumVLForUser (UserOp);
12591275
12601276 // Use the largest VL among all the users. If we cannot determine this
12611277 // statically, then we cannot optimize the VL.
1262- if (!CommonVL || RISCV::isVLKnownLE (*CommonVL, * VLOp)) {
1263- CommonVL = * VLOp;
1264- LLVM_DEBUG (dbgs () << " User VL is: " << VLOp << " \n " );
1265- } else if (!RISCV::isVLKnownLE (* VLOp, *CommonVL )) {
1278+ if (RISCV::isVLKnownLE (DemandedVL, VLOp)) {
1279+ DemandedVL = VLOp;
1280+ LLVM_DEBUG (dbgs () << " Demanded VL is: " << VLOp << " \n " );
1281+ } else if (!RISCV::isVLKnownLE (VLOp, DemandedVL )) {
12661282 LLVM_DEBUG (dbgs () << " Abort because cannot determine a common VL\n " );
1267- return std:: nullopt ;
1283+ return VLMAX ;
12681284 }
12691285
12701286 if (!RISCVII::hasSEWOp (UserMI.getDesc ().TSFlags )) {
12711287 LLVM_DEBUG (dbgs () << " Abort due to lack of SEW operand\n " );
1272- return std:: nullopt ;
1288+ return VLMAX ;
12731289 }
12741290
12751291 std::optional<OperandInfo> ConsumerInfo = getOperandInfo (UserOp, MRI);
@@ -1279,7 +1295,7 @@ std::optional<MachineOperand> RISCVVLOptimizer::checkUsers(MachineInstr &MI) {
12791295 LLVM_DEBUG (dbgs () << " Abort due to unknown operand information.\n " );
12801296 LLVM_DEBUG (dbgs () << " ConsumerInfo is: " << ConsumerInfo << " \n " );
12811297 LLVM_DEBUG (dbgs () << " ProducerInfo is: " << ProducerInfo << " \n " );
1282- return std:: nullopt ;
1298+ return VLMAX ;
12831299 }
12841300
12851301 // If the operand is used as a scalar operand, then the EEW must be
@@ -1294,53 +1310,51 @@ std::optional<MachineOperand> RISCVVLOptimizer::checkUsers(MachineInstr &MI) {
12941310 << " Abort due to incompatible information for EMUL or EEW.\n " );
12951311 LLVM_DEBUG (dbgs () << " ConsumerInfo is: " << ConsumerInfo << " \n " );
12961312 LLVM_DEBUG (dbgs () << " ProducerInfo is: " << ProducerInfo << " \n " );
1297- return std:: nullopt ;
1313+ return VLMAX ;
12981314 }
12991315 }
13001316
1301- return CommonVL ;
1317+ return DemandedVL ;
13021318}
13031319
13041320bool RISCVVLOptimizer::tryReduceVL (MachineInstr &MI) {
13051321 LLVM_DEBUG (dbgs () << " Trying to reduce VL for " << MI << " \n " );
13061322
1307- auto CommonVL = checkUsers (MI);
1308- if (!CommonVL)
1309- return false ;
1323+ const MachineOperand &CommonVL = DemandedVLs.at (&MI);
13101324
1311- assert ((CommonVL-> isImm () || CommonVL-> getReg ().isVirtual ()) &&
1325+ assert ((CommonVL. isImm () || CommonVL. getReg ().isVirtual ()) &&
13121326 " Expected VL to be an Imm or virtual Reg" );
13131327
13141328 unsigned VLOpNum = RISCVII::getVLOpNum (MI.getDesc ());
13151329 MachineOperand &VLOp = MI.getOperand (VLOpNum);
13161330
1317- if (!RISCV::isVLKnownLE (* CommonVL, VLOp)) {
1318- LLVM_DEBUG (dbgs () << " Abort due to CommonVL not <= VLOp.\n " );
1331+ if (!RISCV::isVLKnownLE (CommonVL, VLOp)) {
1332+ LLVM_DEBUG (dbgs () << " Abort due to DemandedVL not <= VLOp.\n " );
13191333 return false ;
13201334 }
13211335
1322- if (CommonVL-> isIdenticalTo (VLOp)) {
1336+ if (CommonVL. isIdenticalTo (VLOp)) {
13231337 LLVM_DEBUG (
1324- dbgs () << " Abort due to CommonVL == VLOp, no point in reducing.\n " );
1338+ dbgs ()
1339+ << " Abort due to DemandedVL == VLOp, no point in reducing.\n " );
13251340 return false ;
13261341 }
13271342
1328- if (CommonVL-> isImm ()) {
1343+ if (CommonVL. isImm ()) {
13291344 LLVM_DEBUG (dbgs () << " Reduce VL from " << VLOp << " to "
1330- << CommonVL-> getImm () << " for " << MI << " \n " );
1331- VLOp.ChangeToImmediate (CommonVL-> getImm ());
1345+ << CommonVL. getImm () << " for " << MI << " \n " );
1346+ VLOp.ChangeToImmediate (CommonVL. getImm ());
13321347 return true ;
13331348 }
1334- const MachineInstr *VLMI = MRI->getVRegDef (CommonVL-> getReg ());
1349+ const MachineInstr *VLMI = MRI->getVRegDef (CommonVL. getReg ());
13351350 if (!MDT->dominates (VLMI, &MI))
13361351 return false ;
1337- LLVM_DEBUG (
1338- dbgs () << " Reduce VL from " << VLOp << " to "
1339- << printReg (CommonVL->getReg (), MRI->getTargetRegisterInfo ())
1340- << " for " << MI << " \n " );
1352+ LLVM_DEBUG (dbgs () << " Reduce VL from " << VLOp << " to "
1353+ << printReg (CommonVL.getReg (), MRI->getTargetRegisterInfo ())
1354+ << " for " << MI << " \n " );
13411355
13421356 // All our checks passed. We can reduce VL.
1343- VLOp.ChangeToRegister (CommonVL-> getReg (), false );
1357+ VLOp.ChangeToRegister (CommonVL. getReg (), false );
13441358 return true ;
13451359}
13461360
@@ -1355,52 +1369,33 @@ bool RISCVVLOptimizer::runOnMachineFunction(MachineFunction &MF) {
13551369 if (!ST.hasVInstructions ())
13561370 return false ;
13571371
1358- SetVector<MachineInstr *> Worklist;
1359- auto PushOperands = [this , &Worklist](MachineInstr &MI,
1360- bool IgnoreSameBlock) {
1361- for (auto &Op : MI.operands ()) {
1362- if (!Op.isReg () || !Op.isUse () || !Op.getReg ().isVirtual () ||
1363- !isVectorRegClass (Op.getReg (), MRI))
1364- continue ;
1365-
1366- MachineInstr *DefMI = MRI->getVRegDef (Op.getReg ());
1367- if (!isCandidate (*DefMI))
1368- continue ;
1369-
1370- if (IgnoreSameBlock && DefMI->getParent () == MI.getParent ())
1371- continue ;
1372-
1373- Worklist.insert (DefMI);
1374- }
1375- };
1372+ TII = ST.getInstrInfo ();
13761373
1377- // Do a first pass eagerly rewriting in roughly reverse instruction
1378- // order, populate the worklist with any instructions we might need to
1379- // revisit. We avoid adding definitions to the worklist if they're
1380- // in the same block - we're about to visit them anyways.
13811374 bool MadeChange = false ;
13821375 for (MachineBasicBlock &MBB : MF) {
13831376 // Avoid unreachable blocks as they have degenerate dominance
13841377 if (!MDT->isReachableFromEntry (&MBB))
13851378 continue ;
13861379
1387- for (auto &MI : make_range (MBB.rbegin (), MBB.rend ())) {
1380+ // For each instruction that defines a vector, compute what VL its
1381+ // downstream users demand.
1382+ for (const auto &MI : reverse (MBB)) {
1383+ if (!isCandidate (MI))
1384+ continue ;
1385+ DemandedVLs.insert ({&MI, computeDemandedVL (MI)});
1386+ }
1387+
1388+ // Then go through and see if we can reduce the VL of any instructions to
1389+ // only what's demanded.
1390+ for (auto &MI : MBB) {
13881391 if (!isCandidate (MI))
13891392 continue ;
13901393 if (!tryReduceVL (MI))
13911394 continue ;
13921395 MadeChange = true ;
1393- PushOperands (MI, /* IgnoreSameBlock*/ true );
13941396 }
1395- }
13961397
1397- while (!Worklist.empty ()) {
1398- assert (MadeChange);
1399- MachineInstr &MI = *Worklist.pop_back_val ();
1400- assert (isCandidate (MI));
1401- if (!tryReduceVL (MI))
1402- continue ;
1403- PushOperands (MI, /* IgnoreSameBlock*/ false );
1398+ DemandedVLs.clear ();
14041399 }
14051400
14061401 return MadeChange;
0 commit comments