diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp index b8539a5d1add1..3989a966edfd3 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp +++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp @@ -4102,3 +4102,17 @@ unsigned RISCV::getDestLog2EEW(const MCInstrDesc &Desc, unsigned Log2SEW) { assert(Scaled >= 3 && Scaled <= 6); return Scaled; } + +/// Given two VL operands, do we know that LHS <= RHS? +bool RISCV::isVLKnownLE(const MachineOperand &LHS, const MachineOperand &RHS) { + if (LHS.isReg() && RHS.isReg() && LHS.getReg().isVirtual() && + LHS.getReg() == RHS.getReg()) + return true; + if (RHS.isImm() && RHS.getImm() == RISCV::VLMaxSentinel) + return true; + if (LHS.isImm() && LHS.getImm() == RISCV::VLMaxSentinel) + return false; + if (!LHS.isImm() || !RHS.isImm()) + return false; + return LHS.getImm() <= RHS.getImm(); +} diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.h b/llvm/lib/Target/RISCV/RISCVInstrInfo.h index 457db9b9860d0..c3aa367486627 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfo.h +++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.h @@ -346,6 +346,9 @@ unsigned getDestLog2EEW(const MCInstrDesc &Desc, unsigned Log2SEW); // Special immediate for AVL operand of V pseudo instructions to indicate VLMax. static constexpr int64_t VLMaxSentinel = -1LL; +/// Given two VL operands, do we know that LHS <= RHS? 
+bool isVLKnownLE(const MachineOperand &LHS, const MachineOperand &RHS); + // Mask assignments for floating-point static constexpr unsigned FPMASK_Negative_Infinity = 0x001; static constexpr unsigned FPMASK_Negative_Normal = 0x002; diff --git a/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp b/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp index 6053899987db9..ee494c4681511 100644 --- a/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp +++ b/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp @@ -51,7 +51,7 @@ class RISCVVLOptimizer : public MachineFunctionPass { StringRef getPassName() const override { return PASS_NAME; } private: - bool checkUsers(std::optional<Register> &CommonVL, MachineInstr &MI); + bool checkUsers(const MachineOperand *&CommonVL, MachineInstr &MI); bool tryReduceVL(MachineInstr &MI); bool isCandidate(const MachineInstr &MI) const; }; @@ -658,10 +658,34 @@ bool RISCVVLOptimizer::isCandidate(const MachineInstr &MI) const { if (MI.getNumDefs() != 1) return false; + // If we're not using VLMAX, then we need to be careful whether we are using + // TA/TU when there is a non-undef Passthru. But when we are using VLMAX, it + // does not matter whether we are using TA/TU with a non-undef Passthru, since + // there are no tail elements to be preserved. unsigned VLOpNum = RISCVII::getVLOpNum(Desc); const MachineOperand &VLOp = MI.getOperand(VLOpNum); - if (!VLOp.isImm() || VLOp.getImm() != RISCV::VLMaxSentinel) + if (VLOp.isReg() || VLOp.getImm() != RISCV::VLMaxSentinel) { + // If MI has a non-undef passthru, we will not try to optimize it since + // that requires us to preserve tail elements according to TA/TU. + // Otherwise, the MI has an undef Passthru, so it doesn't matter whether we + // are using TA/TU. 
+ bool HasPassthru = RISCVII::isFirstDefTiedToFirstUse(Desc); + unsigned PassthruOpIdx = MI.getNumExplicitDefs(); + if (HasPassthru && + MI.getOperand(PassthruOpIdx).getReg() != RISCV::NoRegister) { + LLVM_DEBUG( + dbgs() << " Not a candidate because it uses non-undef passthru" " with non-VLMAX VL\n"); + return false; + } + } + + // If the VL is 1, then there is no need to reduce it. This is an + // optimization, not needed to preserve correctness. + if (VLOp.isImm() && VLOp.getImm() == 1) { + LLVM_DEBUG(dbgs() << " Not a candidate because VL is already 1\n"); return false; + } // Some instructions that produce vectors have semantics that make it more // difficult to determine whether the VL can be reduced. For example, some @@ -684,7 +708,7 @@ bool RISCVVLOptimizer::isCandidate(const MachineInstr &MI) const { return true; } -bool RISCVVLOptimizer::checkUsers(std::optional<Register> &CommonVL, +bool RISCVVLOptimizer::checkUsers(const MachineOperand *&CommonVL, MachineInstr &MI) { // FIXME: Avoid visiting each user for each time we visit something on the // worklist, combined with an extra visit from the outer loop. Restructure @@ -730,16 +754,17 @@ bool RISCVVLOptimizer::checkUsers(std::optional<Register> &CommonVL, unsigned VLOpNum = RISCVII::getVLOpNum(Desc); const MachineOperand &VLOp = UserMI.getOperand(VLOpNum); - // Looking for a register VL that isn't X0. - if (!VLOp.isReg() || VLOp.getReg() == RISCV::X0) { - LLVM_DEBUG(dbgs() << " Abort due to user uses X0 as VL.\n"); - CanReduceVL = false; - break; - } + + // Looking for an immediate or a register VL that isn't X0. + assert((!VLOp.isReg() || + VLOp.getReg() != RISCV::X0) && "Did not expect X0 VL"); if (!CommonVL) { - CommonVL = VLOp.getReg(); - } else if (*CommonVL != VLOp.getReg()) { + CommonVL = &VLOp; + LLVM_DEBUG(dbgs() << " User VL is: " << VLOp << "\n"); + } else if (!CommonVL->isIdenticalTo(VLOp)) { + // FIXME: This check requires all users to have the same VL. 
We can relax + // this and get the largest VL amongst all users. LLVM_DEBUG(dbgs() << " Abort because users have different VL\n"); CanReduceVL = false; break; @@ -776,7 +801,7 @@ bool RISCVVLOptimizer::tryReduceVL(MachineInstr &OrigMI) { MachineInstr &MI = *Worklist.pop_back_val(); LLVM_DEBUG(dbgs() << "Trying to reduce VL for " << MI << "\n"); - std::optional<Register> CommonVL; + const MachineOperand *CommonVL = nullptr; bool CanReduceVL = true; if (isVectorRegClass(MI.getOperand(0).getReg(), MRI)) CanReduceVL = checkUsers(CommonVL, MI); @@ -784,21 +809,34 @@ if (!CanReduceVL || !CommonVL) continue; - if (!CommonVL->isVirtual()) { - LLVM_DEBUG( - dbgs() << " Abort due to new VL is not virtual register.\n"); + assert((CommonVL->isImm() || CommonVL->getReg().isVirtual()) && + "Expected VL to be an Imm or virtual Reg"); + + unsigned VLOpNum = RISCVII::getVLOpNum(MI.getDesc()); + MachineOperand &VLOp = MI.getOperand(VLOpNum); + + if (!RISCV::isVLKnownLE(*CommonVL, VLOp)) { + LLVM_DEBUG(dbgs() << " Abort due to CommonVL not <= VLOp.\n"); continue; } - const MachineInstr *VLMI = MRI->getVRegDef(*CommonVL); - if (!MDT->dominates(VLMI, &MI)) - continue; + if (CommonVL->isImm()) { + LLVM_DEBUG(dbgs() << " Reduce VL from " << VLOp << " to " + << CommonVL->getImm() << " for " << MI << "\n"); + VLOp.ChangeToImmediate(CommonVL->getImm()); + } else { + const MachineInstr *VLMI = MRI->getVRegDef(CommonVL->getReg()); + if (!MDT->dominates(VLMI, &MI)) + continue; + LLVM_DEBUG( + dbgs() << " Reduce VL from " << VLOp << " to " + << printReg(CommonVL->getReg(), MRI->getTargetRegisterInfo()) + << " for " << MI << "\n"); + + // All our checks passed. We can reduce VL. + VLOp.ChangeToRegister(CommonVL->getReg(), false); + } - // All our checks passed. We can reduce VL. 
- LLVM_DEBUG(dbgs() << " Reducing VL for: " << MI << "\n"); - unsigned VLOpNum = RISCVII::getVLOpNum(MI.getDesc()); - MachineOperand &VLOp = MI.getOperand(VLOpNum); - VLOp.ChangeToRegister(*CommonVL, false); MadeChange = true; // Now add all inputs to this instruction to the worklist. diff --git a/llvm/lib/Target/RISCV/RISCVVectorPeephole.cpp b/llvm/lib/Target/RISCV/RISCVVectorPeephole.cpp index b883c50beadc0..a57bc5a3007d0 100644 --- a/llvm/lib/Target/RISCV/RISCVVectorPeephole.cpp +++ b/llvm/lib/Target/RISCV/RISCVVectorPeephole.cpp @@ -86,20 +86,6 @@ char RISCVVectorPeephole::ID = 0; INITIALIZE_PASS(RISCVVectorPeephole, DEBUG_TYPE, "RISC-V Fold Masks", false, false) -/// Given two VL operands, do we know that LHS <= RHS? -static bool isVLKnownLE(const MachineOperand &LHS, const MachineOperand &RHS) { - if (LHS.isReg() && RHS.isReg() && LHS.getReg().isVirtual() && - LHS.getReg() == RHS.getReg()) - return true; - if (RHS.isImm() && RHS.getImm() == RISCV::VLMaxSentinel) - return true; - if (LHS.isImm() && LHS.getImm() == RISCV::VLMaxSentinel) - return false; - if (!LHS.isImm() || !RHS.isImm()) - return false; - return LHS.getImm() <= RHS.getImm(); -} - /// Given \p User that has an input operand with EEW=SEW, which uses the dest /// operand of \p Src with an unknown EEW, return true if their EEWs match. 
bool RISCVVectorPeephole::hasSameEEW(const MachineInstr &User, @@ -191,7 +177,7 @@ bool RISCVVectorPeephole::tryToReduceVL(MachineInstr &MI) const { return false; MachineOperand &SrcVL = Src->getOperand(RISCVII::getVLOpNum(Src->getDesc())); - if (VL.isIdenticalTo(SrcVL) || !isVLKnownLE(VL, SrcVL)) + if (VL.isIdenticalTo(SrcVL) || !RISCV::isVLKnownLE(VL, SrcVL)) return false; if (!ensureDominates(VL, *Src)) @@ -580,7 +566,7 @@ bool RISCVVectorPeephole::foldUndefPassthruVMV_V_V(MachineInstr &MI) { MachineOperand &SrcPolicy = Src->getOperand(RISCVII::getVecPolicyOpNum(Src->getDesc())); - if (isVLKnownLE(MIVL, SrcVL)) + if (RISCV::isVLKnownLE(MIVL, SrcVL)) SrcPolicy.setImm(SrcPolicy.getImm() | RISCVII::TAIL_AGNOSTIC); } @@ -631,7 +617,7 @@ bool RISCVVectorPeephole::foldVMV_V_V(MachineInstr &MI) { // so we don't need to handle a smaller source VL here. However, the // user's VL may be larger MachineOperand &SrcVL = Src->getOperand(RISCVII::getVLOpNum(Src->getDesc())); - if (!isVLKnownLE(SrcVL, MI.getOperand(3))) + if (!RISCV::isVLKnownLE(SrcVL, MI.getOperand(3))) return false; // If the new passthru doesn't dominate Src, try to move Src so it does. @@ -650,7 +636,7 @@ bool RISCVVectorPeephole::foldVMV_V_V(MachineInstr &MI) { // If MI was tail agnostic and the VL didn't increase, preserve it. 
int64_t Policy = RISCVII::TAIL_UNDISTURBED_MASK_UNDISTURBED; if ((MI.getOperand(5).getImm() & RISCVII::TAIL_AGNOSTIC) && - isVLKnownLE(MI.getOperand(3), SrcVL)) + RISCV::isVLKnownLE(MI.getOperand(3), SrcVL)) Policy |= RISCVII::TAIL_AGNOSTIC; Src->getOperand(RISCVII::getVecPolicyOpNum(Src->getDesc())).setImm(Policy); diff --git a/llvm/test/CodeGen/RISCV/rvv/vl-opt.ll b/llvm/test/CodeGen/RISCV/rvv/vl-opt.ll index 0b3e67ec89556..1a1472fcfc66f 100644 --- a/llvm/test/CodeGen/RISCV/rvv/vl-opt.ll +++ b/llvm/test/CodeGen/RISCV/rvv/vl-opt.ll @@ -11,19 +11,46 @@ declare <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i32>, iXLen) define <vscale x 4 x i32> @different_imm_vl_with_ta(<vscale x 4 x i32> %passthru, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, iXLen %vl1, iXLen %vl2) { -; CHECK-LABEL: different_imm_vl_with_ta: -; CHECK: # %bb.0: -; CHECK-NEXT: vsetivli zero, 5, e32, m2, ta, ma -; CHECK-NEXT: vadd.vv v8, v10, v12 -; CHECK-NEXT: vsetivli zero, 4, e32, m2, ta, ma -; CHECK-NEXT: vadd.vv v8, v8, v10 -; CHECK-NEXT: ret +; NOVLOPT-LABEL: different_imm_vl_with_ta: +; NOVLOPT: # %bb.0: +; NOVLOPT-NEXT: vsetivli zero, 5, e32, m2, ta, ma +; NOVLOPT-NEXT: vadd.vv v8, v10, v12 +; NOVLOPT-NEXT: vsetivli zero, 4, e32, m2, ta, ma +; NOVLOPT-NEXT: vadd.vv v8, v8, v10 +; NOVLOPT-NEXT: ret +; +; VLOPT-LABEL: different_imm_vl_with_ta: +; VLOPT: # %bb.0: +; VLOPT-NEXT: vsetivli zero, 4, e32, m2, ta, ma +; VLOPT-NEXT: vadd.vv v8, v10, v12 +; VLOPT-NEXT: vadd.vv v8, v8, v10 +; VLOPT-NEXT: ret %v = call <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32> poison, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, iXLen 5) %w = call <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32> poison, <vscale x 4 x i32> %v, <vscale x 4 x i32> %a, iXLen 4) ret <vscale x 4 x i32> %w } -; No benificial to propagate VL since VL is larger in the use side. 
+define <vscale x 4 x i32> @vlmax_and_imm_vl_with_ta(<vscale x 4 x i32> %passthru, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, iXLen %vl1, iXLen %vl2) { +; NOVLOPT-LABEL: vlmax_and_imm_vl_with_ta: +; NOVLOPT: # %bb.0: +; NOVLOPT-NEXT: vsetvli a0, zero, e32, m2, ta, ma +; NOVLOPT-NEXT: vadd.vv v8, v10, v12 +; NOVLOPT-NEXT: vsetivli zero, 4, e32, m2, ta, ma +; NOVLOPT-NEXT: vadd.vv v8, v8, v10 +; NOVLOPT-NEXT: ret +; +; VLOPT-LABEL: vlmax_and_imm_vl_with_ta: +; VLOPT: # %bb.0: +; VLOPT-NEXT: vsetivli zero, 4, e32, m2, ta, ma +; VLOPT-NEXT: vadd.vv v8, v10, v12 +; VLOPT-NEXT: vadd.vv v8, v8, v10 +; VLOPT-NEXT: ret + %v = call <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32> poison, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, iXLen -1) + %w = call <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32> poison, <vscale x 4 x i32> %v, <vscale x 4 x i32> %a, iXLen 4) + ret <vscale x 4 x i32> %w +} + +; Not beneficial to propagate VL since VL is larger in the use side. define <vscale x 4 x i32> @different_imm_vl_with_ta_larger_vl(<vscale x 4 x i32> %passthru, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, iXLen %vl1, iXLen %vl2) { ; CHECK-LABEL: different_imm_vl_with_ta_larger_vl: ; CHECK: # %bb.0: @@ -50,8 +77,7 @@ define <vscale x 4 x i32> @different_imm_reg_vl_with_ta(<vscale x 4 x i32> %pass ret <vscale x 4 x i32> %w } - -; No benificial to propagate VL since VL is already one. +; Not beneficial to propagate VL since VL is already one. define <vscale x 4 x i32> @different_imm_vl_with_ta_1(<vscale x 4 x i32> %passthru, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, iXLen %vl1, iXLen %vl2) { ; CHECK-LABEL: different_imm_vl_with_ta_1: ; CHECK: # %bb.0: @@ -110,7 +136,3 @@ define <vscale x 4 x i32> @different_imm_vl_with_tu(<vscale x 4 x i32> %passthru %w = call <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32> %passthru, <vscale x 4 x i32> %v, <vscale x 4 x i32> %a,iXLen 4) ret <vscale x 4 x i32> %w } - -;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line: -; NOVLOPT: {{.*}} -; VLOPT: {{.*}}