@@ -52,7 +52,9 @@ class RISCVVLOptimizer : public MachineFunctionPass {
5252
5353private:
5454 MachineOperand getMinimumVLForUser (MachineOperand &UserOp);
55- // / Computes the VL of \p MI that is actually used by its users.
55+ // / Computes the minimum demanded VL of \p MI, i.e. the minimum VL that's used
56+ // / by its users downstream.
57+ // / Returns 0 if MI has no users.
5658 MachineOperand computeDemandedVL (const MachineInstr &MI);
5759 bool tryReduceVL (MachineInstr &MI);
5860 bool isCandidate (const MachineInstr &MI) const ;
@@ -1208,13 +1210,10 @@ MachineOperand RISCVVLOptimizer::getMinimumVLForUser(MachineOperand &UserOp) {
12081210 // If we know the demanded VL of UserMI, then we can reduce the VL it
12091211 // requires.
12101212 if (DemandedVLs.contains (&UserMI)) {
1211- // We can only shrink the demanded VL if the elementwise result doesn't
1212- // depend on VL (i.e. not vredsum/viota etc.)
1213- // Also conservatively restrict to supported instructions for now.
1214- // TODO: Can we remove the isSupportedInstr check?
1213+ // We can only shrink the VL used if the elementwise result doesn't depend
1214+ // on VL (i.e. not vredsum/viota etc.)
12151215 if (!RISCVII::elementsDependOnVL (
1216- TII->get (RISCV::getRVVMCOpcode (UserMI.getOpcode ())).TSFlags ) &&
1217- isSupportedInstr (UserMI)) {
1216+ TII->get (RISCV::getRVVMCOpcode (UserMI.getOpcode ())).TSFlags )) {
12181217 const MachineOperand &DemandedVL = DemandedVLs.at (&UserMI);
12191218 if (RISCV::isVLKnownLE (DemandedVL, VLOp))
12201219 return DemandedVL;
@@ -1244,13 +1243,14 @@ MachineOperand RISCVVLOptimizer::computeDemandedVL(const MachineInstr &MI) {
12441243
12451244 const MachineOperand &VLOp = getMinimumVLForUser (UserOp);
12461245
1247- // Use the largest VL among all the users. If we cannot determine this
1248- // statically, then we cannot optimize the VL.
1246+ // The minimum demanded VL is the largest VL read amongst all the users. If
1247+ // we cannot determine this statically, then we cannot optimize the VL.
12491248 if (RISCV::isVLKnownLE (DemandedVL, VLOp)) {
12501249 DemandedVL = VLOp;
12511250 LLVM_DEBUG (dbgs () << " Demanded VL is: " << VLOp << " \n " );
12521251 } else if (!RISCV::isVLKnownLE (VLOp, DemandedVL)) {
1253- LLVM_DEBUG (dbgs () << " Abort because cannot determine a common VL\n " );
1252+ LLVM_DEBUG (
1253+ dbgs () << " Abort because cannot determine the demanded VL\n " );
12541254 return VLMAX;
12551255 }
12561256
@@ -1305,36 +1305,37 @@ bool RISCVVLOptimizer::tryReduceVL(MachineInstr &MI) {
13051305 if (!CommonVL)
13061306 return false ;
13071307
1308- assert ((CommonVL .isImm () || CommonVL .getReg ().isVirtual ()) &&
1308+ assert ((DemandedVL .isImm () || DemandedVL .getReg ().isVirtual ()) &&
13091309 " Expected VL to be an Imm or virtual Reg" );
13101310
13111311 if (!RISCV::isVLKnownLE (*CommonVL, VLOp)) {
13121312 LLVM_DEBUG (dbgs () << " Abort due to CommonVL not <= VLOp.\n " );
13131313 return false ;
13141314 }
13151315
1316- if (CommonVL .isIdenticalTo (VLOp)) {
1316+ if (DemandedVL .isIdenticalTo (VLOp)) {
13171317 LLVM_DEBUG (
13181318 dbgs ()
13191319 << " Abort due to DemandedVL == VLOp, no point in reducing.\n " );
13201320 return false ;
13211321 }
13221322
1323- if (CommonVL .isImm ()) {
1323+ if (DemandedVL .isImm ()) {
13241324 LLVM_DEBUG (dbgs () << " Reduce VL from " << VLOp << " to "
1325- << CommonVL .getImm () << " for " << MI << " \n " );
1326- VLOp.ChangeToImmediate (CommonVL .getImm ());
1325+ << DemandedVL .getImm () << " for " << MI << " \n " );
1326+ VLOp.ChangeToImmediate (DemandedVL .getImm ());
13271327 return true ;
13281328 }
1329- const MachineInstr *VLMI = MRI->getVRegDef (CommonVL .getReg ());
1329+ const MachineInstr *VLMI = MRI->getVRegDef (DemandedVL .getReg ());
13301330 if (!MDT->dominates (VLMI, &MI))
13311331 return false ;
1332- LLVM_DEBUG (dbgs () << " Reduce VL from " << VLOp << " to "
1333- << printReg (CommonVL.getReg (), MRI->getTargetRegisterInfo ())
1334- << " for " << MI << " \n " );
1332+ LLVM_DEBUG (
1333+ dbgs () << " Reduce VL from " << VLOp << " to "
1334+ << printReg (DemandedVL.getReg (), MRI->getTargetRegisterInfo ())
1335+ << " for " << MI << " \n " );
13351336
13361337 // All our checks passed. We can reduce VL.
1337- VLOp.ChangeToRegister (CommonVL .getReg (), false );
1338+ VLOp.ChangeToRegister (DemandedVL .getReg (), false );
13381339 return true ;
13391340}
13401341
0 commit comments