@@ -51,11 +51,10 @@ class RISCVVLOptimizer : public MachineFunctionPass {
5151 StringRef getPassName () const override { return PASS_NAME; }
5252
5353private:
54- MachineOperand getMinimumVLForUser (MachineOperand &UserOp);
55- // / Computes the minimum demanded VL of \p MI, i.e. the minimum VL that's used
56- // / by its users downstream.
57- // / Returns 0 if MI has no users.
58- MachineOperand computeDemandedVL (const MachineInstr &MI);
54+ std::optional<MachineOperand> getMinimumVLForUser (MachineOperand &UserOp);
55+ // / Returns the largest common VL MachineOperand that may be used to optimize
56+ // / MI. Returns std::nullopt if it failed to find a suitable VL.
57+ std::optional<MachineOperand> checkUsers (MachineInstr &MI);
5958 bool tryReduceVL (MachineInstr &MI);
6059 bool isCandidate (const MachineInstr &MI) const ;
6160
@@ -1179,14 +1178,15 @@ bool RISCVVLOptimizer::isCandidate(const MachineInstr &MI) const {
11791178 return true ;
11801179}
11811180
1182- MachineOperand RISCVVLOptimizer::getMinimumVLForUser (MachineOperand &UserOp) {
1181+ std::optional<MachineOperand>
1182+ RISCVVLOptimizer::getMinimumVLForUser (MachineOperand &UserOp) {
11831183 const MachineInstr &UserMI = *UserOp.getParent ();
11841184 const MCInstrDesc &Desc = UserMI.getDesc ();
11851185
11861186 if (!RISCVII::hasVLOp (Desc.TSFlags ) || !RISCVII::hasSEWOp (Desc.TSFlags )) {
11871187 LLVM_DEBUG (dbgs () << " Abort due to lack of VL, assume that"
11881188 " use VLMAX\n " );
1189- return MachineOperand::CreateImm (RISCV::VLMaxSentinel) ;
1189+ return std:: nullopt ;
11901190 }
11911191
11921192 // Instructions like reductions may use a vector register as a scalar
@@ -1223,40 +1223,39 @@ MachineOperand RISCVVLOptimizer::getMinimumVLForUser(MachineOperand &UserOp) {
12231223 return VLOp;
12241224}
12251225
1226- MachineOperand RISCVVLOptimizer::computeDemandedVL (const MachineInstr &MI) {
1227- const MachineOperand &VLMAX = MachineOperand::CreateImm (RISCV::VLMaxSentinel);
1228- MachineOperand DemandedVL = MachineOperand::CreateImm (0 );
1229-
1226+ std::optional<MachineOperand> RISCVVLOptimizer::checkUsers (MachineInstr &MI) {
1227+ std::optional<MachineOperand> CommonVL;
12301228 for (auto &UserOp : MRI->use_operands (MI.getOperand (0 ).getReg ())) {
12311229 const MachineInstr &UserMI = *UserOp.getParent ();
12321230 LLVM_DEBUG (dbgs () << " Checking user: " << UserMI << " \n " );
12331231 if (mayReadPastVL (UserMI)) {
12341232 LLVM_DEBUG (dbgs () << " Abort because used by unsafe instruction\n " );
1235- return VLMAX ;
1233+ return std:: nullopt ;
12361234 }
12371235
12381236 // If used as a passthru, elements past VL will be read.
12391237 if (UserOp.isTied ()) {
12401238 LLVM_DEBUG (dbgs () << " Abort because user used as tied operand\n " );
1241- return VLMAX ;
1239+ return std:: nullopt ;
12421240 }
12431241
1244- const MachineOperand &VLOp = getMinimumVLForUser (UserOp);
1245-
1246- // The minimum demanded VL is the largest VL read amongst all the users. If
1247- // we cannot determine this statically, then we cannot optimize the VL.
1248- if (RISCV::isVLKnownLE (DemandedVL, VLOp)) {
1249- DemandedVL = VLOp;
1250- LLVM_DEBUG (dbgs () << " Demanded VL is: " << VLOp << " \n " );
1251- } else if (!RISCV::isVLKnownLE (VLOp, DemandedVL)) {
1252- LLVM_DEBUG (
1253- dbgs () << " Abort because cannot determine the demanded VL\n " );
1254- return VLMAX;
1242+ auto VLOp = getMinimumVLForUser (UserOp);
1243+ if (!VLOp)
1244+ return std::nullopt ;
1245+
1246+ // Use the largest VL among all the users. If we cannot determine this
1247+ // statically, then we cannot optimize the VL.
1248+ if (!CommonVL || RISCV::isVLKnownLE (*CommonVL, *VLOp)) {
1249+ CommonVL = *VLOp;
1250+ LLVM_DEBUG (dbgs () << " User VL is: " << VLOp << " \n " );
1251+ } else if (!RISCV::isVLKnownLE (*VLOp, *CommonVL)) {
1252+ LLVM_DEBUG (dbgs () << " Abort because cannot determine a common VL\n " );
1253+ return std::nullopt ;
12551254 }
12561255
12571256 if (!RISCVII::hasSEWOp (UserMI.getDesc ().TSFlags )) {
12581257 LLVM_DEBUG (dbgs () << " Abort due to lack of SEW operand\n " );
1259- return VLMAX ;
1258+ return std:: nullopt ;
12601259 }
12611260
12621261 std::optional<OperandInfo> ConsumerInfo = getOperandInfo (UserOp, MRI);
@@ -1266,7 +1265,7 @@ MachineOperand RISCVVLOptimizer::computeDemandedVL(const MachineInstr &MI) {
12661265 LLVM_DEBUG (dbgs () << " Abort due to unknown operand information.\n " );
12671266 LLVM_DEBUG (dbgs () << " ConsumerInfo is: " << ConsumerInfo << " \n " );
12681267 LLVM_DEBUG (dbgs () << " ProducerInfo is: " << ProducerInfo << " \n " );
1269- return VLMAX ;
1268+ return std:: nullopt ;
12701269 }
12711270
12721271 // If the operand is used as a scalar operand, then the EEW must be
@@ -1281,11 +1280,11 @@ MachineOperand RISCVVLOptimizer::computeDemandedVL(const MachineInstr &MI) {
12811280 << " Abort due to incompatible information for EMUL or EEW.\n " );
12821281 LLVM_DEBUG (dbgs () << " ConsumerInfo is: " << ConsumerInfo << " \n " );
12831282 LLVM_DEBUG (dbgs () << " ProducerInfo is: " << ProducerInfo << " \n " );
1284- return VLMAX ;
1283+ return std:: nullopt ;
12851284 }
12861285 }
12871286
1288- return DemandedVL ;
1287+ return CommonVL ;
12891288}
12901289
12911290bool RISCVVLOptimizer::tryReduceVL (MachineInstr &MI) {
@@ -1305,37 +1304,36 @@ bool RISCVVLOptimizer::tryReduceVL(MachineInstr &MI) {
13051304 if (!CommonVL)
13061305 return false ;
13071306
1308- assert ((DemandedVL. isImm () || DemandedVL. getReg ().isVirtual ()) &&
1307+ assert ((CommonVL-> isImm () || CommonVL-> getReg ().isVirtual ()) &&
13091308 " Expected VL to be an Imm or virtual Reg" );
13101309
13111310 if (!RISCV::isVLKnownLE (*CommonVL, VLOp)) {
13121311 LLVM_DEBUG (dbgs () << " Abort due to CommonVL not <= VLOp.\n " );
13131312 return false ;
13141313 }
13151314
1316- if (DemandedVL. isIdenticalTo (VLOp)) {
1315+ if (CommonVL-> isIdenticalTo (VLOp)) {
13171316 LLVM_DEBUG (
1318- dbgs ()
1319- << " Abort due to DemandedVL == VLOp, no point in reducing.\n " );
1317+ dbgs () << " Abort due to CommonVL == VLOp, no point in reducing.\n " );
13201318 return false ;
13211319 }
13221320
1323- if (DemandedVL. isImm ()) {
1321+ if (CommonVL-> isImm ()) {
13241322 LLVM_DEBUG (dbgs () << " Reduce VL from " << VLOp << " to "
1325- << DemandedVL. getImm () << " for " << MI << " \n " );
1326- VLOp.ChangeToImmediate (DemandedVL. getImm ());
1323+ << CommonVL-> getImm () << " for " << MI << " \n " );
1324+ VLOp.ChangeToImmediate (CommonVL-> getImm ());
13271325 return true ;
13281326 }
1329- const MachineInstr *VLMI = MRI->getVRegDef (DemandedVL. getReg ());
1327+ const MachineInstr *VLMI = MRI->getVRegDef (CommonVL-> getReg ());
13301328 if (!MDT->dominates (VLMI, &MI))
13311329 return false ;
13321330 LLVM_DEBUG (
13331331 dbgs () << " Reduce VL from " << VLOp << " to "
1334- << printReg (DemandedVL. getReg (), MRI->getTargetRegisterInfo ())
1332+ << printReg (CommonVL-> getReg (), MRI->getTargetRegisterInfo ())
13351333 << " for " << MI << " \n " );
13361334
13371335 // All our checks passed. We can reduce VL.
1338- VLOp.ChangeToRegister (DemandedVL. getReg (), false );
1336+ VLOp.ChangeToRegister (CommonVL-> getReg (), false );
13391337 return true ;
13401338}
13411339
@@ -1360,10 +1358,11 @@ bool RISCVVLOptimizer::runOnMachineFunction(MachineFunction &MF) {
13601358
13611359 // For each instruction that defines a vector, compute what VL its
13621360 // downstream users demand.
1363- for (const auto &MI : reverse (MBB)) {
1361+ for (MachineInstr &MI : reverse (MBB)) {
13641362 if (!isCandidate (MI))
13651363 continue ;
1366- DemandedVLs.insert ({&MI, computeDemandedVL (MI)});
1364+ if (auto DemandedVL = checkUsers (MI))
1365+ DemandedVLs.insert ({&MI, *DemandedVL});
13671366 }
13681367
13691368 // Then go through and see if we can reduce the VL of any instructions to
0 commit comments