@@ -618,6 +618,8 @@ namespace {
618618 SDValue CombineConsecutiveLoads(SDNode *N, EVT VT);
619619 SDValue foldBitcastedFPLogic(SDNode *N, SelectionDAG &DAG,
620620 const TargetLowering &TLI);
621+ SDValue foldPartialReduceMLAMulOp(SDNode *N);
622+ SDValue foldPartialReduceMLANoMulOp(SDNode *N);
621623
622624 SDValue CombineExtLoad(SDNode *N);
623625 SDValue CombineZExtLogicopShiftLoad(SDNode *N);
@@ -12601,12 +12603,20 @@ SDValue DAGCombiner::visitMHISTOGRAM(SDNode *N) {
1260112603 return SDValue();
1260212604}
1260312605
12606+ SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
12607+ if (SDValue Res = foldPartialReduceMLAMulOp(N))
12608+ return Res;
12609+ if (SDValue Res = foldPartialReduceMLANoMulOp(N))
12610+ return Res;
12611+ return SDValue();
12612+ }
12613+
1260412614// partial_reduce_*mla(acc, mul(ext(a), ext(b)), splat(1))
1260512615// -> partial_reduce_*mla(acc, a, b)
1260612616//
1260712617// partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
1260812618// -> partial_reduce_*mla(acc, x, C)
12609- SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA (SDNode *N) {
12619+ SDValue DAGCombiner::foldPartialReduceMLAMulOp (SDNode *N) {
1261012620 SDLoc DL(N);
1261112621 auto *Context = DAG.getContext();
1261212622 SDValue Acc = N->getOperand(0);
@@ -12672,6 +12682,48 @@ SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
1267212682 RHSExtOp);
1267312683}
1267412684
12685+ // Makes PARTIAL_REDUCE_*MLA(Acc, ZEXT(UnextOp1), Splat(1)) into
12686+ // PARTIAL_REDUCE_UMLA(Acc, Op, TRUNC(Splat(1)))
12687+ // Makes PARTIAL_REDUCE_*MLA(Acc, SEXT(UnextOp1), Splat(1)) into
12688+ // PARTIAL_REDUCE_SMLA(Acc, Op, TRUNC(Splat(1)))
12689+ SDValue DAGCombiner::foldPartialReduceMLANoMulOp(SDNode *N) {
12690+ SDLoc DL(N);
12691+ SDValue Acc = N->getOperand(0);
12692+ SDValue Op1 = N->getOperand(1);
12693+ SDValue Op2 = N->getOperand(2);
12694+
12695+ APInt ConstantOne;
12696+ if (!ISD::isConstantSplatVector(Op2.getNode(), ConstantOne) ||
12697+ !ConstantOne.isOne())
12698+ return SDValue();
12699+
12700+ unsigned Op1Opcode = Op1.getOpcode();
12701+ if (!ISD::isExtOpcode(Op1Opcode))
12702+ return SDValue();
12703+
12704+ SDValue UnextOp1 = Op1.getOperand(0);
12705+ EVT UnextOp1VT = UnextOp1.getValueType();
12706+
12707+ if (!TLI.isPartialReduceMLALegalOrCustom(N->getValueType(0), UnextOp1VT))
12708+ return SDValue();
12709+
12710+ SDValue TruncOp2 = DAG.getNode(ISD::TRUNCATE, DL, UnextOp1VT, Op2);
12711+
12712+ bool Op1IsSigned = Op1Opcode == ISD::SIGN_EXTEND;
12713+
12714+ bool NodeIsSigned = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA;
12715+ EVT AccElemVT = Acc.getValueType().getVectorElementType();
12716+ if (Op1IsSigned != NodeIsSigned &&
12717+ (Op1.getValueType().getVectorElementType() != AccElemVT ||
12718+ Op2.getValueType().getVectorElementType() != AccElemVT))
12719+ return SDValue();
12720+
12721+ unsigned NewOpcode =
12722+ Op1IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
12723+ return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, UnextOp1,
12724+ TruncOp2);
12725+ }
12726+
1267512727SDValue DAGCombiner::visitVP_STRIDED_LOAD(SDNode *N) {
1267612728 auto *SLD = cast<VPStridedLoadSDNode>(N);
1267712729 EVT EltVT = SLD->getValueType(0).getVectorElementType();
0 commit comments