diff --git a/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp b/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp index 9c910c6b59438..96a73d9720a43 100644 --- a/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp +++ b/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp @@ -168,24 +168,22 @@ getEMULEqualsEEWDivSEWTimesLMUL(unsigned Log2EEW, const MachineInstr &MI) { } // end namespace RISCVVType } // end namespace llvm -/// Dest has EEW=SEW and EMUL=LMUL. Source EEW=SEW/Factor (i.e. F2 => EEW/2). -/// Source has EMUL=(EEW/SEW)*LMUL. LMUL and SEW comes from TSFlags of MI. -static OperandInfo getIntegerExtensionOperandInfo(unsigned Factor, - const MachineInstr &MI, - const MachineOperand &MO) { - RISCVII::VLMUL MIVLMul = RISCVII::getLMul(MI.getDesc().TSFlags); +/// Dest has EEW=SEW. Source EEW=SEW/Factor (i.e. F2 => EEW/2). +/// SEW comes from TSFlags of MI. +static unsigned getIntegerExtensionOperandEEW(unsigned Factor, + const MachineInstr &MI, + const MachineOperand &MO) { unsigned MILog2SEW = MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm(); if (MO.getOperandNo() == 0) - return OperandInfo(MIVLMul, MILog2SEW); + return MILog2SEW; unsigned MISEW = 1 << MILog2SEW; unsigned EEW = MISEW / Factor; unsigned Log2EEW = Log2_32(EEW); - return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(Log2EEW, MI), - Log2EEW); + return Log2EEW; } /// Check whether MO is a mask operand of MI. @@ -199,18 +197,15 @@ static bool isMaskOperand(const MachineInstr &MI, const MachineOperand &MO, return Desc.operands()[MO.getOperandNo()].RegClass == RISCV::VMV0RegClassID; } -/// Return the OperandInfo for MO. -static OperandInfo getOperandInfo(const MachineOperand &MO, - const MachineRegisterInfo *MRI) { +static std::optional +getOperandLog2EEW(const MachineOperand &MO, const MachineRegisterInfo *MRI) { const MachineInstr &MI = *MO.getParent(); const RISCVVPseudosTable::PseudoInfo *RVV = RISCVVPseudosTable::getPseudoInfo(MI.getOpcode()); assert(RVV && "Could not find MI in PseudoTable"); - // MI has a VLMUL and SEW associated with it. The RVV specification defines - // the LMUL and SEW of each operand and definition in relation to MI.VLMUL and - // MI.SEW. - RISCVII::VLMUL MIVLMul = RISCVII::getLMul(MI.getDesc().TSFlags); + // MI has a SEW associated with it. The RVV specification defines + // the EEW of each operand and definition in relation to MI.SEW. unsigned MILog2SEW = MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm(); @@ -221,13 +216,13 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, // since they must preserve the entire register content. if (HasPassthru && MO.getOperandNo() == MI.getNumExplicitDefs() && (MO.getReg() != RISCV::NoRegister)) - return {}; + return std::nullopt; bool IsMODef = MO.getOperandNo() == 0; - // All mask operands have EEW=1, EMUL=(EEW/SEW)*LMUL + // All mask operands have EEW=1 if (isMaskOperand(MI, MO, MRI)) - return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(0, MI), 0); + return 0; // switch against BaseInstr to reduce number of cases that need to be // considered. @@ -244,66 +239,65 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, // Vector Loads and Stores // Vector Unit-Stride Instructions // Vector Strided Instructions - /// Dest EEW encoded in the instruction and EMUL=(EEW/SEW)*LMUL + /// Dest EEW encoded in the instruction case RISCV::VLM_V: case RISCV::VSM_V: - return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(0, MI), 0); + return 0; case RISCV::VLE8_V: case RISCV::VSE8_V: case RISCV::VLSE8_V: case RISCV::VSSE8_V: - return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(3, MI), 3); + return 3; case RISCV::VLE16_V: case RISCV::VSE16_V: case RISCV::VLSE16_V: case RISCV::VSSE16_V: - return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(4, MI), 4); + return 4; case RISCV::VLE32_V: case RISCV::VSE32_V: case RISCV::VLSE32_V: case RISCV::VSSE32_V: - return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(5, MI), 5); + return 5; case RISCV::VLE64_V: case RISCV::VSE64_V: case RISCV::VLSE64_V: case RISCV::VSSE64_V: - return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(6, MI), 6); + return 6; // Vector Indexed Instructions // vs(o|u)xei.v - // Dest/Data (operand 0) EEW=SEW, EMUL=LMUL. Source EEW= and - // EMUL=(EEW/SEW)*LMUL. + // Dest/Data (operand 0) EEW=SEW. Source EEW=. case RISCV::VLUXEI8_V: case RISCV::VLOXEI8_V: case RISCV::VSUXEI8_V: case RISCV::VSOXEI8_V: { if (MO.getOperandNo() == 0) - return OperandInfo(MIVLMul, MILog2SEW); - return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(3, MI), 3); + return MILog2SEW; + return 3; } case RISCV::VLUXEI16_V: case RISCV::VLOXEI16_V: case RISCV::VSUXEI16_V: case RISCV::VSOXEI16_V: { if (MO.getOperandNo() == 0) - return OperandInfo(MIVLMul, MILog2SEW); - return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(4, MI), 4); + return MILog2SEW; + return 4; } case RISCV::VLUXEI32_V: case RISCV::VLOXEI32_V: case RISCV::VSUXEI32_V: case RISCV::VSOXEI32_V: { if (MO.getOperandNo() == 0) - return OperandInfo(MIVLMul, MILog2SEW); - return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(5, MI), 5); + return MILog2SEW; + return 5; } case RISCV::VLUXEI64_V: case RISCV::VLOXEI64_V: case RISCV::VSUXEI64_V: case RISCV::VSOXEI64_V: { if (MO.getOperandNo() == 0) - return OperandInfo(MIVLMul, MILog2SEW); - return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(6, MI), 6); + return MILog2SEW; + return 6; } // Vector Integer Arithmetic Instructions @@ -317,7 +311,7 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, case RISCV::VRSUB_VX: // Vector Bitwise Logical Instructions // Vector Single-Width Shift Instructions - // EEW=SEW. EMUL=LMUL. + // EEW=SEW. case RISCV::VAND_VI: case RISCV::VAND_VV: case RISCV::VAND_VX: @@ -337,7 +331,7 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, case RISCV::VSRA_VV: case RISCV::VSRA_VX: // Vector Integer Min/Max Instructions - // EEW=SEW. EMUL=LMUL. + // EEW=SEW. case RISCV::VMINU_VV: case RISCV::VMINU_VX: case RISCV::VMIN_VV: @@ -347,7 +341,7 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, case RISCV::VMAX_VV: case RISCV::VMAX_VX: // Vector Single-Width Integer Multiply Instructions - // Source and Dest EEW=SEW and EMUL=LMUL. + // Source and Dest EEW=SEW. case RISCV::VMUL_VV: case RISCV::VMUL_VX: case RISCV::VMULH_VV: @@ -357,7 +351,7 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, case RISCV::VMULHSU_VV: case RISCV::VMULHSU_VX: // Vector Integer Divide Instructions - // EEW=SEW. EMUL=LMUL. + // EEW=SEW. case RISCV::VDIVU_VV: case RISCV::VDIVU_VX: case RISCV::VDIV_VV: @@ -367,7 +361,7 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, case RISCV::VREM_VV: case RISCV::VREM_VX: // Vector Single-Width Integer Multiply-Add Instructions - // EEW=SEW. EMUL=LMUL. + // EEW=SEW. case RISCV::VMACC_VV: case RISCV::VMACC_VX: case RISCV::VNMSAC_VV: @@ -378,8 +372,8 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, case RISCV::VNMSUB_VX: // Vector Integer Merge Instructions // Vector Integer Add-with-Carry / Subtract-with-Borrow Instructions - // EEW=SEW and EMUL=LMUL, except the mask operand has EEW=1 and EMUL= - // (EEW/SEW)*LMUL. Mask operand is handled before this switch. + // EEW=SEW, except the mask operand has EEW=1. Mask operand is handled + // before this switch. case RISCV::VMERGE_VIM: case RISCV::VMERGE_VVM: case RISCV::VMERGE_VXM: @@ -392,7 +386,7 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, // Vector Fixed-Point Arithmetic Instructions // Vector Single-Width Saturating Add and Subtract // Vector Single-Width Averaging Add and Subtract - // EEW=SEW. EMUL=LMUL. + // EEW=SEW. case RISCV::VMV_V_I: case RISCV::VMV_V_V: case RISCV::VMV_V_X: @@ -415,12 +409,12 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, case RISCV::VASUB_VV: case RISCV::VASUB_VX: // Vector Single-Width Fractional Multiply with Rounding and Saturation - // EEW=SEW. EMUL=LMUL. The instruction produces 2*SEW product internally but + // EEW=SEW. The instruction produces 2*SEW product internally but // saturates to fit into SEW bits. case RISCV::VSMUL_VV: case RISCV::VSMUL_VX: // Vector Single-Width Scaling Shift Instructions - // EEW=SEW. EMUL=LMUL. + // EEW=SEW. case RISCV::VSSRL_VI: case RISCV::VSSRL_VV: case RISCV::VSSRL_VX: @@ -430,13 +424,13 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, // Vector Permutation Instructions // Integer Scalar Move Instructions // Floating-Point Scalar Move Instructions - // EMUL=LMUL. EEW=SEW. + // EEW=SEW. case RISCV::VMV_X_S: case RISCV::VMV_S_X: case RISCV::VFMV_F_S: case RISCV::VFMV_S_F: // Vector Slide Instructions - // EMUL=LMUL. EEW=SEW. + // EEW=SEW. case RISCV::VSLIDEUP_VI: case RISCV::VSLIDEUP_VX: case RISCV::VSLIDEDOWN_VI: @@ -446,12 +440,12 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, case RISCV::VSLIDE1DOWN_VX: case RISCV::VFSLIDE1DOWN_VF: // Vector Register Gather Instructions - // EMUL=LMUL. EEW=SEW. For mask operand, EMUL=1 and EEW=1. + // EEW=SEW. For mask operand, EEW=1. case RISCV::VRGATHER_VI: case RISCV::VRGATHER_VV: case RISCV::VRGATHER_VX: // Vector Compress Instruction - // EMUL=LMUL. EEW=SEW. + // EEW=SEW. case RISCV::VCOMPRESS_VM: // Vector Element Index Instruction case RISCV::VID_V: @@ -498,10 +492,10 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, case RISCV::VFCVT_F_X_V: // Vector Floating-Point Merge Instruction case RISCV::VFMERGE_VFM: - return OperandInfo(MIVLMul, MILog2SEW); + return MILog2SEW; // Vector Widening Integer Add/Subtract - // Def uses EEW=2*SEW and EMUL=2*LMUL. Operands use EEW=SEW and EMUL=LMUL. + // Def uses EEW=2*SEW . Operands use EEW=SEW. case RISCV::VWADDU_VV: case RISCV::VWADDU_VX: case RISCV::VWSUBU_VV: @@ -512,7 +506,7 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, case RISCV::VWSUB_VX: case RISCV::VWSLL_VI: // Vector Widening Integer Multiply Instructions - // Source and Destination EMUL=LMUL. Destination EEW=2*SEW. Source EEW=SEW. + // Destination EEW=2*SEW. Source EEW=SEW. case RISCV::VWMUL_VV: case RISCV::VWMUL_VX: case RISCV::VWMULSU_VV: @@ -520,7 +514,7 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, case RISCV::VWMULU_VV: case RISCV::VWMULU_VX: // Vector Widening Integer Multiply-Add Instructions - // Destination EEW=2*SEW and EMUL=2*LMUL. Source EEW=SEW and EMUL=LMUL. + // Destination EEW=2*SEW. Source EEW=SEW. // A SEW-bit*SEW-bit multiply of the sources forms a 2*SEW-bit value, which // is then added to the 2*SEW-bit Dest. These instructions never have a // passthru operand. @@ -541,7 +535,7 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, case RISCV::VFWNMSAC_VF: case RISCV::VFWNMSAC_VV: // Vector Widening Floating-Point Add/Subtract Instructions - // Dest EEW=2*SEW and EMUL=2*LMUL. Source EEW=SEW and EMUL=LMUL. + // Dest EEW=2*SEW. Source EEW=SEW. case RISCV::VFWADD_VV: case RISCV::VFWADD_VF: case RISCV::VFWSUB_VV: @@ -558,11 +552,10 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, case RISCV::VFWCVT_F_X_V: case RISCV::VFWCVT_F_F_V: { unsigned Log2EEW = IsMODef ? MILog2SEW + 1 : MILog2SEW; - return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(Log2EEW, MI), - Log2EEW); + return Log2EEW; } - // Def and Op1 uses EEW=2*SEW and EMUL=2*LMUL. Op2 uses EEW=SEW and EMUL=LMUL + // Def and Op1 uses EEW=2*SEW. Op2 uses EEW=SEW. case RISCV::VWADDU_WV: case RISCV::VWADDU_WX: case RISCV::VWSUBU_WV: @@ -579,24 +572,22 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, bool IsOp1 = HasPassthru ? MO.getOperandNo() == 2 : MO.getOperandNo() == 1; bool TwoTimes = IsMODef || IsOp1; unsigned Log2EEW = TwoTimes ? MILog2SEW + 1 : MILog2SEW; - return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(Log2EEW, MI), - Log2EEW); + return Log2EEW; } // Vector Integer Extension case RISCV::VZEXT_VF2: case RISCV::VSEXT_VF2: - return getIntegerExtensionOperandInfo(2, MI, MO); + return getIntegerExtensionOperandEEW(2, MI, MO); case RISCV::VZEXT_VF4: case RISCV::VSEXT_VF4: - return getIntegerExtensionOperandInfo(4, MI, MO); + return getIntegerExtensionOperandEEW(4, MI, MO); case RISCV::VZEXT_VF8: case RISCV::VSEXT_VF8: - return getIntegerExtensionOperandInfo(8, MI, MO); + return getIntegerExtensionOperandEEW(8, MI, MO); // Vector Narrowing Integer Right Shift Instructions - // Destination EEW=SEW and EMUL=LMUL, Op 1 has EEW=2*SEW EMUL=2*LMUL. Op2 has - // EEW=SEW EMUL=LMUL. + // Destination EEW=SEW, Op 1 has EEW=2*SEW. Op2 has EEW=SEW case RISCV::VNSRL_WX: case RISCV::VNSRL_WI: case RISCV::VNSRL_WV: @@ -604,7 +595,7 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, case RISCV::VNSRA_WV: case RISCV::VNSRA_WX: // Vector Narrowing Fixed-Point Clip Instructions - // Destination and Op1 EEW=SEW and EMUL=LMUL. Op2 EEW=2*SEW and EMUL=2*LMUL + // Destination and Op1 EEW=SEW. Op2 EEW=2*SEW. case RISCV::VNCLIPU_WI: case RISCV::VNCLIPU_WV: case RISCV::VNCLIPU_WX: @@ -623,8 +614,7 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, bool IsOp1 = HasPassthru ? MO.getOperandNo() == 2 : MO.getOperandNo() == 1; bool TwoTimes = IsOp1; unsigned Log2EEW = TwoTimes ? MILog2SEW + 1 : MILog2SEW; - return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(Log2EEW, MI), - Log2EEW); + return Log2EEW; } // Vector Mask Instructions @@ -632,7 +622,7 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, // vmsbf.m set-before-first mask bit // vmsif.m set-including-first mask bit // vmsof.m set-only-first mask bit - // EEW=1 and EMUL=(EEW/SEW)*LMUL + // EEW=1 // We handle the cases when operand is a v0 mask operand above the switch, // but these instructions may use non-v0 mask operands and need to be handled // specifically. @@ -647,20 +637,20 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, case RISCV::VMSBF_M: case RISCV::VMSIF_M: case RISCV::VMSOF_M: { - return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(0, MI), 0); + return 0; } // Vector Iota Instruction - // EEW=SEW and EMUL=LMUL, except the mask operand has EEW=1 and EMUL= - // (EEW/SEW)*LMUL. Mask operand is not handled before this switch. + // EEW=SEW, except the mask operand has EEW=1. Mask operand is not handled + // before this switch. case RISCV::VIOTA_M: { if (IsMODef || MO.getOperandNo() == 1) - return OperandInfo(MIVLMul, MILog2SEW); - return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(0, MI), 0); + return MILog2SEW; + return 0; } // Vector Integer Compare Instructions - // Dest EEW=1 and EMUL=(EEW/SEW)*LMUL. Source EEW=SEW and EMUL=LMUL. + // Dest EEW=1. Source EEW=SEW. case RISCV::VMSEQ_VI: case RISCV::VMSEQ_VV: case RISCV::VMSEQ_VX: @@ -682,21 +672,20 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, case RISCV::VMSGT_VI: case RISCV::VMSGT_VX: // Vector Integer Add-with-Carry / Subtract-with-Borrow Instructions - // Dest EEW=1 and EMUL=(EEW/SEW)*LMUL. Source EEW=SEW and EMUL=LMUL. Mask - // source operand handled above this switch. + // Dest EEW=1. Source EEW=SEW. Mask source operand handled above this switch. case RISCV::VMADC_VIM: case RISCV::VMADC_VVM: case RISCV::VMADC_VXM: case RISCV::VMSBC_VVM: case RISCV::VMSBC_VXM: - // Dest EEW=1 and EMUL=(EEW/SEW)*LMUL. Source EEW=SEW and EMUL=LMUL. + // Dest EEW=1. Source EEW=SEW. case RISCV::VMADC_VV: case RISCV::VMADC_VI: case RISCV::VMADC_VX: case RISCV::VMSBC_VV: case RISCV::VMSBC_VX: // 13.13. Vector Floating-Point Compare Instructions - // Dest EEW=1 and EMUL=(EEW/SEW)*LMUL. Source EEW=SEW EMUL=LMUL. + // Dest EEW=1. Source EEW=SEW case RISCV::VMFEQ_VF: case RISCV::VMFEQ_VV: case RISCV::VMFNE_VF: @@ -708,14 +697,12 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, case RISCV::VMFGT_VF: case RISCV::VMFGE_VF: { if (IsMODef) - return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(0, MI), 0); - return OperandInfo(MIVLMul, MILog2SEW); + return 0; + return MILog2SEW; } // Vector Reduction Operations // Vector Single-Width Integer Reduction Instructions - // The Dest and VS1 only read element 0 of the vector register. Return just - // the EEW for these. VS2 has EEW=SEW and EMUL=LMUL. case RISCV::VREDAND_VS: case RISCV::VREDMAX_VS: case RISCV::VREDMAXU_VS: @@ -724,9 +711,7 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, case RISCV::VREDOR_VS: case RISCV::VREDSUM_VS: case RISCV::VREDXOR_VS: { - if (MO.getOperandNo() == 2) - return OperandInfo(MIVLMul, MILog2SEW); - return OperandInfo(MILog2SEW); + return MILog2SEW; } default: @@ -734,6 +719,40 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, } } +static OperandInfo getOperandInfo(const MachineOperand &MO, + const MachineRegisterInfo *MRI) { + const MachineInstr &MI = *MO.getParent(); + const RISCVVPseudosTable::PseudoInfo *RVV = + RISCVVPseudosTable::getPseudoInfo(MI.getOpcode()); + assert(RVV && "Could not find MI in PseudoTable"); + + std::optional Log2EEW = getOperandLog2EEW(MO, MRI); + if (!Log2EEW) + return {}; + + switch (RVV->BaseInstr) { + // Vector Reduction Operations + // Vector Single-Width Integer Reduction Instructions + // The Dest and VS1 only read element 0 of the vector register. Return just + // the EEW for these. + case RISCV::VREDAND_VS: + case RISCV::VREDMAX_VS: + case RISCV::VREDMAXU_VS: + case RISCV::VREDMIN_VS: + case RISCV::VREDMINU_VS: + case RISCV::VREDOR_VS: + case RISCV::VREDSUM_VS: + case RISCV::VREDXOR_VS: + if (MO.getOperandNo() != 2) + return OperandInfo(*Log2EEW); + break; + }; + + // All others have EMUL=EEW/SEW*LMUL + return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(*Log2EEW, MI), + *Log2EEW); +} + /// Return true if this optimization should consider MI for VL reduction. This /// white-list approach simplifies this optimization for instructions that may /// have more complex semantics with relation to how it uses VL.