@@ -52,7 +52,9 @@ class RISCVVLOptimizer : public MachineFunctionPass {
5252
5353private:
5454 MachineOperand getMinimumVLForUser (MachineOperand &UserOp);
55- /// Computes the VL of \p MI that is actually used by its users.
55+ /// Computes the minimum demanded VL of \p MI, i.e. the minimum VL that's used
56+ /// by its users downstream.
57+ /// Returns 0 if MI has no users.
5658 MachineOperand computeDemandedVL (const MachineInstr &MI);
5759 bool tryReduceVL (MachineInstr &MI);
5860 bool isCandidate (const MachineInstr &MI) const ;
@@ -1208,13 +1210,10 @@ MachineOperand RISCVVLOptimizer::getMinimumVLForUser(MachineOperand &UserOp) {
12081210 // If we know the demanded VL of UserMI, then we can reduce the VL it
12091211 // requires.
12101212 if (DemandedVLs.contains (&UserMI)) {
1211- // We can only shrink the demanded VL if the elementwise result doesn't
1212- // depend on VL (i.e. not vredsum/viota etc.)
1213- // Also conservatively restrict to supported instructions for now.
1214- // TODO: Can we remove the isSupportedInstr check?
1213+ // We can only shrink the VL used if the elementwise result doesn't depend
1214+ // on VL (i.e. not vredsum/viota etc.)
12151215 if (!RISCVII::elementsDependOnVL (
1216- TII->get (RISCV::getRVVMCOpcode (UserMI.getOpcode ())).TSFlags ) &&
1217- isSupportedInstr (UserMI)) {
1216+ TII->get (RISCV::getRVVMCOpcode (UserMI.getOpcode ())).TSFlags )) {
12181217 const MachineOperand &DemandedVL = DemandedVLs.at (&UserMI);
12191218 if (RISCV::isVLKnownLE (DemandedVL, VLOp))
12201219 return DemandedVL;
@@ -1244,13 +1243,14 @@ MachineOperand RISCVVLOptimizer::computeDemandedVL(const MachineInstr &MI) {
12441243
12451244 const MachineOperand &VLOp = getMinimumVLForUser (UserOp);
12461245
1247- // Use the largest VL among all the users. If we cannot determine this
1248- // statically, then we cannot optimize the VL.
1246+ // The minimum demanded VL is the largest VL read amongst all the users. If
1247+ // we cannot determine this statically, then we cannot optimize the VL.
12491248 if (RISCV::isVLKnownLE (DemandedVL, VLOp)) {
12501249 DemandedVL = VLOp;
12511250 LLVM_DEBUG (dbgs () << " Demanded VL is: " << VLOp << " \n " );
12521251 } else if (!RISCV::isVLKnownLE (VLOp, DemandedVL)) {
1253- LLVM_DEBUG (dbgs () << " Abort because cannot determine a common VL\n " );
1252+ LLVM_DEBUG (
1253+ dbgs () << " Abort because cannot determine the demanded VL\n " );
12541254 return VLMAX;
12551255 }
12561256
@@ -1291,41 +1291,42 @@ MachineOperand RISCVVLOptimizer::computeDemandedVL(const MachineInstr &MI) {
12911291bool RISCVVLOptimizer::tryReduceVL (MachineInstr &MI) {
12921292 LLVM_DEBUG (dbgs () << " Trying to reduce VL for " << MI << " \n " );
12931293
1294- const MachineOperand &CommonVL = DemandedVLs.at (&MI);
1294+ const MachineOperand &DemandedVL = DemandedVLs.at (&MI);
12951295
1296- assert ((CommonVL .isImm () || CommonVL .getReg ().isVirtual ()) &&
1296+ assert ((DemandedVL .isImm () || DemandedVL .getReg ().isVirtual ()) &&
12971297 " Expected VL to be an Imm or virtual Reg" );
12981298
12991299 unsigned VLOpNum = RISCVII::getVLOpNum (MI.getDesc ());
13001300 MachineOperand &VLOp = MI.getOperand (VLOpNum);
13011301
1302- if (!RISCV::isVLKnownLE (CommonVL , VLOp)) {
1302+ if (!RISCV::isVLKnownLE (DemandedVL , VLOp)) {
13031303 LLVM_DEBUG (dbgs () << " Abort due to DemandedVL not <= VLOp.\n " );
13041304 return false ;
13051305 }
13061306
1307- if (CommonVL .isIdenticalTo (VLOp)) {
1307+ if (DemandedVL .isIdenticalTo (VLOp)) {
13081308 LLVM_DEBUG (
13091309 dbgs ()
13101310 << " Abort due to DemandedVL == VLOp, no point in reducing.\n " );
13111311 return false ;
13121312 }
13131313
1314- if (CommonVL .isImm ()) {
1314+ if (DemandedVL .isImm ()) {
13151315 LLVM_DEBUG (dbgs () << " Reduce VL from " << VLOp << " to "
1316- << CommonVL .getImm () << " for " << MI << " \n " );
1317- VLOp.ChangeToImmediate (CommonVL .getImm ());
1316+ << DemandedVL .getImm () << " for " << MI << " \n " );
1317+ VLOp.ChangeToImmediate (DemandedVL .getImm ());
13181318 return true ;
13191319 }
1320- const MachineInstr *VLMI = MRI->getVRegDef (CommonVL .getReg ());
1320+ const MachineInstr *VLMI = MRI->getVRegDef (DemandedVL .getReg ());
13211321 if (!MDT->dominates (VLMI, &MI))
13221322 return false ;
1323- LLVM_DEBUG (dbgs () << " Reduce VL from " << VLOp << " to "
1324- << printReg (CommonVL.getReg (), MRI->getTargetRegisterInfo ())
1325- << " for " << MI << " \n " );
1323+ LLVM_DEBUG (
1324+ dbgs () << " Reduce VL from " << VLOp << " to "
1325+ << printReg (DemandedVL.getReg (), MRI->getTargetRegisterInfo ())
1326+ << " for " << MI << " \n " );
13261327
13271328 // All our checks passed. We can reduce VL.
1328- VLOp.ChangeToRegister (CommonVL .getReg (), false );
1329+ VLOp.ChangeToRegister (DemandedVL .getReg (), false );
13291330 return true ;
13301331}
13311332
0 commit comments