diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp index 850d6244affa5..66be30c4b42ba 100644 --- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp +++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp @@ -1548,7 +1548,7 @@ RISCVTTIImpl::getArithmeticReductionCost(unsigned Opcode, VectorType *Ty, getRISCVInstructionCost(RISCV::VCPOP_M, LT.second, CostKind) + getCmpSelInstrCost(Instruction::ICmp, ElementTy, ElementTy, CmpInst::ICMP_EQ, CostKind); - } else if (ISD == ISD::XOR) { + } else if (ISD == ISD::XOR || ISD == ISD::ADD) { // Example sequences: // vsetvli a0, zero, e8, mf8, ta, ma // vmxor.mm v8, v0, v8 ; needed every time type is split @@ -1558,13 +1558,14 @@ RISCVTTIImpl::getArithmeticReductionCost(unsigned Opcode, VectorType *Ty, getRISCVInstructionCost(RISCV::VMXOR_MM, LT.second, CostKind) + getRISCVInstructionCost(RISCV::VCPOP_M, LT.second, CostKind) + 1; } else { + assert(ISD == ISD::OR); // Example sequences: // vsetvli a0, zero, e8, mf8, ta, ma - // vmxor.mm v8, v9, v8 ; needed every time type is split + // vmor.mm v8, v9, v8 ; needed every time type is split // vcpop.m a0, v0 // snez a0, a0 return (LT.first - 1) * - getRISCVInstructionCost(RISCV::VMXOR_MM, LT.second, CostKind) + + getRISCVInstructionCost(RISCV::VMOR_MM, LT.second, CostKind) + getRISCVInstructionCost(RISCV::VCPOP_M, LT.second, CostKind) + getCmpSelInstrCost(Instruction::ICmp, ElementTy, ElementTy, CmpInst::ICMP_NE, CostKind);