diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h index b818f4768c2c3..9c453f51e129d 100644 --- a/llvm/include/llvm/CodeGen/TargetLowering.h +++ b/llvm/include/llvm/CodeGen/TargetLowering.h @@ -1659,17 +1659,20 @@ class LLVM_ABI TargetLoweringBase { /// InputVT should be treated. Either it's legal, needs to be promoted to a /// larger size, needs to be expanded to some other code sequence, or the /// target has a custom expander for it. - LegalizeAction getPartialReduceMLAAction(EVT AccVT, EVT InputVT) const { - PartialReduceActionTypes TypePair = {AccVT.getSimpleVT().SimpleTy, - InputVT.getSimpleVT().SimpleTy}; - auto It = PartialReduceMLAActions.find(TypePair); + LegalizeAction getPartialReduceMLAAction(unsigned Opc, EVT AccVT, + EVT InputVT) const { + assert(Opc == ISD::PARTIAL_REDUCE_SMLA || Opc == ISD::PARTIAL_REDUCE_UMLA); + PartialReduceActionTypes Key = {Opc, AccVT.getSimpleVT().SimpleTy, + InputVT.getSimpleVT().SimpleTy}; + auto It = PartialReduceMLAActions.find(Key); return It != PartialReduceMLAActions.end() ? It->second : Expand; } /// Return true if a PARTIAL_REDUCE_U/SMLA node with the specified types is /// legal or custom for this target. - bool isPartialReduceMLALegalOrCustom(EVT AccVT, EVT InputVT) const { - LegalizeAction Action = getPartialReduceMLAAction(AccVT, InputVT); + bool isPartialReduceMLALegalOrCustom(unsigned Opc, EVT AccVT, + EVT InputVT) const { + LegalizeAction Action = getPartialReduceMLAAction(Opc, AccVT, InputVT); return Action == Legal || Action == Custom; } @@ -2754,12 +2757,18 @@ class LLVM_ABI TargetLoweringBase { /// type InputVT should be treated by the target. Either it's legal, needs to /// be promoted to a larger size, needs to be expanded to some other code /// sequence, or the target has a custom expander for it. 
- void setPartialReduceMLAAction(MVT AccVT, MVT InputVT, + void setPartialReduceMLAAction(unsigned Opc, MVT AccVT, MVT InputVT, LegalizeAction Action) { + assert(Opc == ISD::PARTIAL_REDUCE_SMLA || Opc == ISD::PARTIAL_REDUCE_UMLA); assert(AccVT.isValid() && InputVT.isValid() && "setPartialReduceMLAAction types aren't valid"); - PartialReduceActionTypes TypePair = {AccVT.SimpleTy, InputVT.SimpleTy}; - PartialReduceMLAActions[TypePair] = Action; + PartialReduceActionTypes Key = {Opc, AccVT.SimpleTy, InputVT.SimpleTy}; + PartialReduceMLAActions[Key] = Action; + } + void setPartialReduceMLAAction(ArrayRef<unsigned> Opcodes, MVT AccVT, + MVT InputVT, LegalizeAction Action) { + for (unsigned Opc : Opcodes) + setPartialReduceMLAAction(Opc, AccVT, InputVT, Action); } /// If Opc/OrigVT is specified as being promoted, the promotion code defaults @@ -3751,10 +3760,10 @@ class LLVM_ABI TargetLoweringBase { uint32_t CondCodeActions[ISD::SETCC_INVALID][(MVT::VALUETYPE_SIZE + 7) / 8]; using PartialReduceActionTypes = - std::pair<MVT::SimpleValueType, MVT::SimpleValueType>; - /// For each result type and input type for the ISD::PARTIAL_REDUCE_U/SMLA - /// nodes, keep a LegalizeAction which indicates how instruction selection - /// should deal with this operation. + std::tuple<unsigned, MVT::SimpleValueType, MVT::SimpleValueType>; + /// For each partial reduce opcode, result type and input type combination, + /// keep a LegalizeAction which indicates how instruction selection should + /// deal with this operation. DenseMap<PartialReduceActionTypes, LegalizeAction> PartialReduceMLAActions; ValueTypeActionImpl ValueTypeActions; diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index e05f85ea3bd8e..be2209a2f8faf 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -12673,17 +12673,17 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) { SDValue LHSExtOp = LHS->getOperand(0); EVT LHSExtOpVT = LHSExtOp.getValueType(); + bool ExtIsSigned = LHSOpcode == ISD::SIGN_EXTEND; + unsigned NewOpcode = + ExtIsSigned ? 
ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA; + // Only perform these combines if the target supports folding // the extends into the operation. if (!TLI.isPartialReduceMLALegalOrCustom( - TLI.getTypeToTransformTo(*Context, N->getValueType(0)), + NewOpcode, TLI.getTypeToTransformTo(*Context, N->getValueType(0)), TLI.getTypeToTransformTo(*Context, LHSExtOpVT))) return SDValue(); - bool ExtIsSigned = LHSOpcode == ISD::SIGN_EXTEND; - unsigned NewOpcode = - ExtIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA; - // partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1)) // -> partial_reduce_*mla(acc, x, C) if (ISD::isConstantSplatVector(RHS.getNode(), C)) { @@ -12737,14 +12737,6 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) { if (!ISD::isExtOpcode(Op1Opcode)) return SDValue(); - SDValue UnextOp1 = Op1.getOperand(0); - EVT UnextOp1VT = UnextOp1.getValueType(); - auto *Context = DAG.getContext(); - if (!TLI.isPartialReduceMLALegalOrCustom( - TLI.getTypeToTransformTo(*Context, N->getValueType(0)), - TLI.getTypeToTransformTo(*Context, UnextOp1VT))) - return SDValue(); - bool Op1IsSigned = Op1Opcode == ISD::SIGN_EXTEND; bool NodeIsSigned = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA; EVT AccElemVT = Acc.getValueType().getVectorElementType(); @@ -12754,6 +12746,15 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) { unsigned NewOpcode = Op1IsSigned ? 
ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA; + + SDValue UnextOp1 = Op1.getOperand(0); + EVT UnextOp1VT = UnextOp1.getValueType(); + auto *Context = DAG.getContext(); + if (!TLI.isPartialReduceMLALegalOrCustom( + NewOpcode, TLI.getTypeToTransformTo(*Context, N->getValueType(0)), + TLI.getTypeToTransformTo(*Context, UnextOp1VT))) + return SDValue(); + return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, UnextOp1, DAG.getConstant(1, DL, UnextOp1VT)); } diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp index affcd78ea61b0..910a40e5b5141 100644 --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp @@ -530,8 +530,9 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) { } case ISD::PARTIAL_REDUCE_UMLA: case ISD::PARTIAL_REDUCE_SMLA: - Action = TLI.getPartialReduceMLAAction(Node->getValueType(0), - Node->getOperand(1).getValueType()); + Action = + TLI.getPartialReduceMLAAction(Op.getOpcode(), Node->getValueType(0), + Node->getOperand(1).getValueType()); break; #define BEGIN_REGISTER_VP_SDNODE(VPID, LEGALPOS, ...) 
\ diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index a07afea963e20..f18d325148742 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -1458,9 +1458,12 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setOperationAction(ISD::FADD, VT, Custom); if (EnablePartialReduceNodes && Subtarget->hasDotProd()) { - setPartialReduceMLAAction(MVT::v4i32, MVT::v16i8, Legal); - setPartialReduceMLAAction(MVT::v2i32, MVT::v8i8, Legal); - setPartialReduceMLAAction(MVT::v2i64, MVT::v16i8, Custom); + static const unsigned MLAOps[] = {ISD::PARTIAL_REDUCE_SMLA, + ISD::PARTIAL_REDUCE_UMLA}; + + setPartialReduceMLAAction(MLAOps, MVT::v4i32, MVT::v16i8, Legal); + setPartialReduceMLAAction(MLAOps, MVT::v2i32, MVT::v8i8, Legal); + setPartialReduceMLAAction(MLAOps, MVT::v2i64, MVT::v16i8, Custom); } } else /* !isNeonAvailable */ { @@ -1881,16 +1884,18 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, if (EnablePartialReduceNodes && Subtarget->isSVEorStreamingSVEAvailable()) { // Mark known legal pairs as 'Legal' (these will expand to UDOT or SDOT). // Other pairs will default to 'Expand'. 
- setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv8i16, Legal); - setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i8, Legal); + static const unsigned MLAOps[] = {ISD::PARTIAL_REDUCE_SMLA, + ISD::PARTIAL_REDUCE_UMLA}; + setPartialReduceMLAAction(MLAOps, MVT::nxv2i64, MVT::nxv8i16, Legal); + setPartialReduceMLAAction(MLAOps, MVT::nxv4i32, MVT::nxv16i8, Legal); - setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv16i8, Custom); + setPartialReduceMLAAction(MLAOps, MVT::nxv2i64, MVT::nxv16i8, Custom); // Wide add types if (Subtarget->hasSVE2() || Subtarget->hasSME()) { - setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv4i32, Legal); - setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv8i16, Legal); - setPartialReduceMLAAction(MVT::nxv8i16, MVT::nxv16i8, Legal); + setPartialReduceMLAAction(MLAOps, MVT::nxv2i64, MVT::nxv4i32, Legal); + setPartialReduceMLAAction(MLAOps, MVT::nxv4i32, MVT::nxv8i16, Legal); + setPartialReduceMLAAction(MLAOps, MVT::nxv8i16, MVT::nxv16i8, Legal); } } diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index 43c81b97a0e05..567f4c5b47d30 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -1573,11 +1573,13 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, // zve32x is broken for partial_reduce_umla, but let's not make it worse. 
if (Subtarget.hasStdExtZvqdotq() && Subtarget.getELen() >= 64) { - setPartialReduceMLAAction(MVT::nxv1i32, MVT::nxv4i8, Custom); - setPartialReduceMLAAction(MVT::nxv2i32, MVT::nxv8i8, Custom); - setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i8, Custom); - setPartialReduceMLAAction(MVT::nxv8i32, MVT::nxv32i8, Custom); - setPartialReduceMLAAction(MVT::nxv16i32, MVT::nxv64i8, Custom); + static const unsigned MLAOps[] = {ISD::PARTIAL_REDUCE_SMLA, + ISD::PARTIAL_REDUCE_UMLA}; + setPartialReduceMLAAction(MLAOps, MVT::nxv1i32, MVT::nxv4i8, Custom); + setPartialReduceMLAAction(MLAOps, MVT::nxv2i32, MVT::nxv8i8, Custom); + setPartialReduceMLAAction(MLAOps, MVT::nxv4i32, MVT::nxv16i8, Custom); + setPartialReduceMLAAction(MLAOps, MVT::nxv8i32, MVT::nxv32i8, Custom); + setPartialReduceMLAAction(MLAOps, MVT::nxv16i32, MVT::nxv64i8, Custom); if (Subtarget.useRVVForFixedLengthVectors()) { for (MVT VT : MVT::integer_fixedlen_vector_valuetypes()) { @@ -1586,7 +1588,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, continue; ElementCount EC = VT.getVectorElementCount(); MVT ArgVT = MVT::getVectorVT(MVT::i8, EC.multiplyCoefficientBy(4)); - setPartialReduceMLAAction(VT, ArgVT, Custom); + setPartialReduceMLAAction(MLAOps, VT, ArgVT, Custom); } } }