@@ -18372,31 +18372,6 @@ static SDValue performBUILD_VECTORCombine(SDNode *N, SelectionDAG &DAG,
1837218372 DAG.getBuildVector(VT, DL, RHSOps));
1837318373}
1837418374
18375- static SDValue lowerVQDOT(unsigned Opc, SDValue Op0, SDValue Op1,
18376- const SDLoc &DL, SelectionDAG &DAG,
18377- const RISCVSubtarget &Subtarget) {
18378- assert(RISCVISD::VQDOT_VL == Opc || RISCVISD::VQDOTU_VL == Opc ||
18379- RISCVISD::VQDOTSU_VL == Opc);
18380- MVT VT = Op0.getSimpleValueType();
18381- assert(VT == Op1.getSimpleValueType() &&
18382- VT.getVectorElementType() == MVT::i32);
18383-
18384- SDValue Passthru = DAG.getConstant(0, DL, VT);
18385- MVT ContainerVT = VT;
18386- if (VT.isFixedLengthVector()) {
18387- ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);
18388- Passthru = convertToScalableVector(ContainerVT, Passthru, DAG, Subtarget);
18389- Op0 = convertToScalableVector(ContainerVT, Op0, DAG, Subtarget);
18390- Op1 = convertToScalableVector(ContainerVT, Op1, DAG, Subtarget);
18391- }
18392- auto [Mask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
18393- SDValue LocalAccum = DAG.getNode(Opc, DL, ContainerVT,
18394- {Op0, Op1, Passthru, Mask, VL});
18395- if (VT.isFixedLengthVector())
18396- return convertFromScalableVector(VT, LocalAccum, DAG, Subtarget);
18397- return LocalAccum;
18398- }
18399-
1840018375static MVT getQDOTXResultType(MVT OpVT) {
1840118376 ElementCount OpEC = OpVT.getVectorElementCount();
1840218377 assert(OpEC.isKnownMultipleOf(4) && OpVT.getVectorElementType() == MVT::i8);
@@ -18455,61 +18430,62 @@ static SDValue foldReduceOperandViaVQDOT(SDValue InVec, const SDLoc &DL,
1845518430 }
1845618431 }
1845718432
18458- // reduce ( zext a) <--> reduce (mul zext a. zext 1)
18459- // reduce ( sext a) <--> reduce (mul sext a. sext 1)
18433+ // zext a <--> partial_reduce_umla 0, a, 1
18434+ // sext a <--> partial_reduce_smla 0, a, 1
1846018435 if (InVec.getOpcode() == ISD::ZERO_EXTEND ||
1846118436 InVec.getOpcode() == ISD::SIGN_EXTEND) {
1846218437 SDValue A = InVec.getOperand(0);
18463- if ( A.getValueType().getVectorElementType() != MVT::i8 ||
18464- !TLI.isTypeLegal(A.getValueType() ))
18438+ EVT OpVT = A.getValueType();
18439+ if (OpVT.getVectorElementType() != MVT::i8 || !TLI.isTypeLegal(OpVT ))
1846518440 return SDValue();
1846618441
1846718442 MVT ResVT = getQDOTXResultType(A.getSimpleValueType());
18468- A = DAG.getBitcast(ResVT, A);
18469- SDValue B = DAG.getConstant(0x01010101, DL, ResVT);
18470-
18443+ SDValue B = DAG.getConstant(0x1, DL, OpVT);
1847118444 bool IsSigned = InVec.getOpcode() == ISD::SIGN_EXTEND;
18472- unsigned Opc = IsSigned ? RISCVISD::VQDOT_VL : RISCVISD::VQDOTU_VL;
18473- return lowerVQDOT(Opc, A, B, DL, DAG, Subtarget);
18445+ unsigned Opc =
18446+ IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
18447+ return DAG.getNode(Opc, DL, ResVT, {DAG.getConstant(0, DL, ResVT), A, B});
1847418448 }
1847518449
18476- // mul (sext, sext) -> vqdot
18477- // mul (zext, zext) -> vqdotu
18478- // mul (sext, zext) -> vqdotsu
18479- // mul (zext, sext) -> vqdotsu (swapped)
18480- // TODO: Improve .vx handling - we end up with a sub-vector insert
18481- // which confuses the splat pattern matching. Also, match vqdotus.vx
18450+ // mul (sext a, sext b) -> partial_reduce_smla 0, a, b
18451+ // mul (zext a, zext b) -> partial_reduce_umla 0, a, b
18452+ // mul (sext a, zext b) -> partial_reduce_ssmla 0, a, b
18453+ // mul (zext a, sext b) -> partial_reduce_smla 0, b, a (swapped)
1848218454 if (InVec.getOpcode() != ISD::MUL)
1848318455 return SDValue();
1848418456
1848518457 SDValue A = InVec.getOperand(0);
1848618458 SDValue B = InVec.getOperand(1);
18487- unsigned Opc = 0;
18488- if (A.getOpcode() == B.getOpcode()) {
18489- if (A.getOpcode() == ISD::SIGN_EXTEND)
18490- Opc = RISCVISD::VQDOT_VL;
18491- else if (A.getOpcode() == ISD::ZERO_EXTEND)
18492- Opc = RISCVISD::VQDOTU_VL;
18493- else
18494- return SDValue();
18495- } else {
18496- if (B.getOpcode() != ISD::ZERO_EXTEND)
18497- std::swap(A, B);
18498- if (A.getOpcode() != ISD::SIGN_EXTEND || B.getOpcode() != ISD::ZERO_EXTEND)
18499- return SDValue();
18500- Opc = RISCVISD::VQDOTSU_VL;
18501- }
18502- assert(Opc);
1850318459
18504- if (A.getOperand(0).getValueType().getVectorElementType() != MVT::i8 ||
18505- A.getOperand(0).getValueType() != B.getOperand(0).getValueType() ||
18460+ if (!ISD::isExtOpcode(A.getOpcode()))
18461+ return SDValue();
18462+
18463+ EVT OpVT = A.getOperand(0).getValueType();
18464+ if (OpVT.getVectorElementType() != MVT::i8 ||
18465+ OpVT != B.getOperand(0).getValueType() ||
1850618466 !TLI.isTypeLegal(A.getValueType()))
1850718467 return SDValue();
1850818468
18509- MVT ResVT = getQDOTXResultType(A.getOperand(0).getSimpleValueType());
18510- A = DAG.getBitcast(ResVT, A.getOperand(0));
18511- B = DAG.getBitcast(ResVT, B.getOperand(0));
18512- return lowerVQDOT(Opc, A, B, DL, DAG, Subtarget);
18469+ unsigned Opc;
18470+ if (A.getOpcode() == ISD::SIGN_EXTEND && B.getOpcode() == ISD::SIGN_EXTEND)
18471+ Opc = ISD::PARTIAL_REDUCE_SMLA;
18472+ else if (A.getOpcode() == ISD::ZERO_EXTEND &&
18473+ B.getOpcode() == ISD::ZERO_EXTEND)
18474+ Opc = ISD::PARTIAL_REDUCE_UMLA;
18475+ else if (A.getOpcode() == ISD::SIGN_EXTEND &&
18476+ B.getOpcode() == ISD::ZERO_EXTEND)
18477+ Opc = ISD::PARTIAL_REDUCE_SUMLA;
18478+ else if (A.getOpcode() == ISD::ZERO_EXTEND &&
18479+ B.getOpcode() == ISD::SIGN_EXTEND) {
18480+ Opc = ISD::PARTIAL_REDUCE_SUMLA;
18481+ std::swap(A, B);
18482+ } else
18483+ return SDValue();
18484+
18485+ MVT ResVT = getQDOTXResultType(OpVT.getSimpleVT());
18486+ return DAG.getNode(
18487+ Opc, DL, ResVT,
18488+ {DAG.getConstant(0, DL, ResVT), A.getOperand(0), B.getOperand(0)});
1851318489}
1851418490
1851518491static SDValue performVECREDUCECombine(SDNode *N, SelectionDAG &DAG,
0 commit comments