@@ -2042,6 +2042,7 @@ SDValue DAGCombiner::visit(SDNode *N) {
20422042 case ISD::PARTIAL_REDUCE_SMLA:
20432043 case ISD::PARTIAL_REDUCE_UMLA:
20442044 case ISD::PARTIAL_REDUCE_SUMLA:
2045+ case ISD::PARTIAL_REDUCE_FMLA:
20452046 return visitPARTIAL_REDUCE_MLA(N);
20462047 case ISD::VECTOR_COMPRESS: return visitVECTOR_COMPRESS(N);
20472048 case ISD::LIFETIME_END: return visitLIFETIME_END(N);
@@ -13006,6 +13007,9 @@ SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
1300613007//
1300713008// partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
1300813009// -> partial_reduce_*mla(acc, x, C)
13010+ //
13011+ // partial_reduce_fmla(acc, fmul(fpext(a), fpext(b)), splat(1.0))
13012+ // -> partial_reduce_fmla(acc, a, b)
1300913013SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
1301013014 SDLoc DL(N);
1301113015 auto *Context = DAG.getContext();
@@ -13014,7 +13018,7 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
1301413018 SDValue Op2 = N->getOperand(2);
1301513019
1301613020 unsigned Opc = Op1->getOpcode();
13017- if (Opc != ISD::MUL && Opc != ISD::SHL)
13021+ if (Opc != ISD::MUL && Opc != ISD::FMUL && Opc != ISD:: SHL)
1301813022 return SDValue();
1301913023
1302013024 SDValue LHS = Op1->getOperand(0);
@@ -13033,20 +13037,24 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
1303313037 Opc = ISD::MUL;
1303413038 }
1303513039
13036- APInt C;
13037- if (Opc != ISD::MUL || !ISD::isConstantSplatVector(Op2.getNode(), C) ||
13038- !C.isOne())
13040+ if (!(Opc == ISD::MUL && llvm::isOneOrOneSplat(Op2)) &&
13041+ !(Opc == ISD::FMUL && llvm::isOneOrOneSplatFP(Op2)))
1303913042 return SDValue();
1304013043
13044+ auto IsIntOrFPExtOpcode = [](unsigned int Opcode) {
13045+ return (ISD::isExtOpcode(Opcode) || Opcode == ISD::FP_EXTEND);
13046+ };
13047+
1304113048 unsigned LHSOpcode = LHS->getOpcode();
13042- if (!ISD::isExtOpcode (LHSOpcode))
13049+ if (!IsIntOrFPExtOpcode (LHSOpcode))
1304313050 return SDValue();
1304413051
1304513052 SDValue LHSExtOp = LHS->getOperand(0);
1304613053 EVT LHSExtOpVT = LHSExtOp.getValueType();
1304713054
1304813055 // partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
1304913056 // -> partial_reduce_*mla(acc, x, C)
13057+ APInt C;
1305013058 if (ISD::isConstantSplatVector(RHS.getNode(), C)) {
1305113059 // TODO: Make use of partial_reduce_sumla here
1305213060 APInt CTrunc = C.trunc(LHSExtOpVT.getScalarSizeInBits());
@@ -13071,7 +13079,7 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
1307113079 }
1307213080
1307313081 unsigned RHSOpcode = RHS->getOpcode();
13074- if (!ISD::isExtOpcode (RHSOpcode))
13082+ if (!IsIntOrFPExtOpcode (RHSOpcode))
1307513083 return SDValue();
1307613084
1307713085 SDValue RHSExtOp = RHS->getOperand(0);
@@ -13088,6 +13096,8 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
1308813096 else if (LHSOpcode == ISD::ZERO_EXTEND && RHSOpcode == ISD::SIGN_EXTEND) {
1308913097 NewOpc = ISD::PARTIAL_REDUCE_SUMLA;
1309013098 std::swap(LHSExtOp, RHSExtOp);
13099+ } else if (LHSOpcode == ISD::FP_EXTEND && RHSOpcode == ISD::FP_EXTEND) {
13100+ NewOpc = ISD::PARTIAL_REDUCE_FMLA;
1309113101 } else
1309213102 return SDValue();
1309313103 // For a 2-stage extend the signedness of both of the extends must match
@@ -13115,30 +13125,33 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
1311513125// -> partial.reduce.smla(acc, op, splat(trunc(1)))
1311613126// partial.reduce.sumla(acc, sext(op), splat(1))
1311713127// -> partial.reduce.smla(acc, op, splat(trunc(1)))
13128+ // partial.reduce.fmla(acc, fpext(op), splat(1.0))
13129+ // -> partial.reduce.fmla(acc, op, splat(1.0))
1311813130SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
1311913131 SDLoc DL(N);
1312013132 SDValue Acc = N->getOperand(0);
1312113133 SDValue Op1 = N->getOperand(1);
1312213134 SDValue Op2 = N->getOperand(2);
1312313135
13124- APInt ConstantOne;
13125- if (!ISD::isConstantSplatVector(Op2.getNode(), ConstantOne) ||
13126- !ConstantOne.isOne())
13136+ if (!llvm::isOneOrOneSplat(Op2) && !llvm::isOneOrOneSplatFP(Op2))
1312713137 return SDValue();
1312813138
1312913139 unsigned Op1Opcode = Op1.getOpcode();
13130- if (!ISD::isExtOpcode(Op1Opcode))
13140+ if (!ISD::isExtOpcode(Op1Opcode) && Op1Opcode != ISD::FP_EXTEND )
1313113141 return SDValue();
1313213142
13133- bool Op1IsSigned = Op1Opcode == ISD::SIGN_EXTEND;
13143+ bool Op1IsSigned =
13144+ Op1Opcode == ISD::SIGN_EXTEND || Op1Opcode == ISD::FP_EXTEND;
1313413145 bool NodeIsSigned = N->getOpcode() != ISD::PARTIAL_REDUCE_UMLA;
1313513146 EVT AccElemVT = Acc.getValueType().getVectorElementType();
1313613147 if (Op1IsSigned != NodeIsSigned &&
1313713148 Op1.getValueType().getVectorElementType() != AccElemVT)
1313813149 return SDValue();
1313913150
13140- unsigned NewOpcode =
13141- Op1IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
13151+ unsigned NewOpcode = N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA
13152+ ? ISD::PARTIAL_REDUCE_FMLA
13153+ : Op1IsSigned ? ISD::PARTIAL_REDUCE_SMLA
13154+ : ISD::PARTIAL_REDUCE_UMLA;
1314213155
1314313156 SDValue UnextOp1 = Op1.getOperand(0);
1314413157 EVT UnextOp1VT = UnextOp1.getValueType();
@@ -13148,8 +13161,12 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
1314813161 TLI.getTypeToTransformTo(*Context, UnextOp1VT)))
1314913162 return SDValue();
1315013163
13164+ SDValue Constant = N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA
13165+ ? DAG.getConstantFP(1, DL, UnextOp1VT)
13166+ : DAG.getConstant(1, DL, UnextOp1VT);
13167+
1315113168 return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, UnextOp1,
13152- DAG.getConstant(1, DL, UnextOp1VT) );
13169+ Constant );
1315313170}
1315413171
1315513172SDValue DAGCombiner::visitVP_STRIDED_LOAD(SDNode *N) {
0 commit comments