56 changes: 56 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -545,6 +545,7 @@ namespace {
SDValue visitMGATHER(SDNode *N);
SDValue visitMSCATTER(SDNode *N);
SDValue visitMHISTOGRAM(SDNode *N);
SDValue visitPARTIAL_REDUCE_MLA(SDNode *N);
SDValue visitVPGATHER(SDNode *N);
SDValue visitVPSCATTER(SDNode *N);
SDValue visitVP_STRIDED_LOAD(SDNode *N);
@@ -1972,6 +1973,9 @@ SDValue DAGCombiner::visit(SDNode *N) {
case ISD::MSCATTER: return visitMSCATTER(N);
case ISD::MSTORE: return visitMSTORE(N);
case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM: return visitMHISTOGRAM(N);
case ISD::PARTIAL_REDUCE_SMLA:
case ISD::PARTIAL_REDUCE_UMLA:
return visitPARTIAL_REDUCE_MLA(N);
case ISD::VECTOR_COMPRESS: return visitVECTOR_COMPRESS(N);
case ISD::LIFETIME_END: return visitLIFETIME_END(N);
case ISD::FP_TO_FP16: return visitFP_TO_FP16(N);
@@ -12497,6 +12501,58 @@ SDValue DAGCombiner::visitMHISTOGRAM(SDNode *N) {
return SDValue();
}

// Fold PARTIAL_REDUCE_*MLA(Acc, MUL(ZEXT(LHSExtOp), ZEXT(RHSExtOp)), Splat(1))
// into PARTIAL_REDUCE_UMLA(Acc, LHSExtOp, RHSExtOp), and
// PARTIAL_REDUCE_*MLA(Acc, MUL(SEXT(LHSExtOp), SEXT(RHSExtOp)), Splat(1))
// into PARTIAL_REDUCE_SMLA(Acc, LHSExtOp, RHSExtOp).
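// For example (an illustrative sketch; the concrete types are not taken from
// any particular test):
//
//   PARTIAL_REDUCE_*MLA(Acc:v4i32,
//                       MUL(ZEXT(A:v16i8 to v16i32), ZEXT(B:v16i8 to v16i32)),
//                       Splat(1))
//   --> PARTIAL_REDUCE_UMLA(Acc:v4i32, A:v16i8, B:v16i8)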
SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
SDLoc DL(N);

SDValue Acc = N->getOperand(0);
SDValue Op1 = N->getOperand(1);
SDValue Op2 = N->getOperand(2);

APInt ConstantOne;
if (Op1->getOpcode() != ISD::MUL ||
!ISD::isConstantSplatVector(Op2.getNode(), ConstantOne) ||
!ConstantOne.isOne())
return SDValue();

SDValue LHS = Op1->getOperand(0);
SDValue RHS = Op1->getOperand(1);
unsigned LHSOpcode = LHS->getOpcode();
unsigned RHSOpcode = RHS->getOpcode();
if (!ISD::isExtOpcode(LHSOpcode) || !ISD::isExtOpcode(RHSOpcode))
return SDValue();

// For a two-stage extend the signedness of both extends must be the same, so
// that the node can be folded into a purely signed or purely unsigned node.
// The source types of the two extends must also match.
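// For instance, MUL(SEXT(a), ZEXT(b)) is left alone (an illustrative case):
// neither PARTIAL_REDUCE_SMLA nor PARTIAL_REDUCE_UMLA expresses one
// sign-extended and one zero-extended operand.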
SDValue LHSExtOp = LHS->getOperand(0);
SDValue RHSExtOp = RHS->getOperand(0);
EVT LHSExtOpVT = LHSExtOp.getValueType();
if (LHSExtOpVT != RHSExtOp.getValueType() || LHSOpcode != RHSOpcode)
return SDValue();

// FIXME: Add a check to only perform the DAG combine if there is lowering
// provided by the target

bool ExtIsSigned = LHSOpcode == ISD::SIGN_EXTEND;

bool NodeIsSigned = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA;
EVT AccElemVT = Acc.getValueType().getVectorElementType();
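// If the signedness of the extends disagrees with that of the node, the fold
// is only valid when the node performs no further extension itself, i.e. the
// multiply already produces elements of the accumulator's element type.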
if (ExtIsSigned != NodeIsSigned &&
Op1.getValueType().getVectorElementType() != AccElemVT)
return SDValue();

unsigned NewOpcode =
ExtIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp,
RHSExtOp);
}

SDValue DAGCombiner::visitVP_STRIDED_LOAD(SDNode *N) {
auto *SLD = cast<VPStridedLoadSDNode>(N);
EVT EltVT = SLD->getValueType(0).getVectorElementType();
139 changes: 82 additions & 57 deletions llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
@@ -12,13 +12,15 @@ define <4 x i32> @udot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
;
; CHECK-NODOT-LABEL: udot:
; CHECK-NODOT: // %bb.0:
; CHECK-NODOT-NEXT: umull v3.8h, v2.8b, v1.8b
; CHECK-NODOT-NEXT: umull2 v1.8h, v2.16b, v1.16b
; CHECK-NODOT-NEXT: ushll v2.4s, v1.4h, #0
; CHECK-NODOT-NEXT: uaddw v0.4s, v0.4s, v3.4h
; CHECK-NODOT-NEXT: uaddw2 v2.4s, v2.4s, v3.8h
; CHECK-NODOT-NEXT: uaddw2 v0.4s, v0.4s, v1.8h
; CHECK-NODOT-NEXT: add v0.4s, v2.4s, v0.4s
; CHECK-NODOT-NEXT: ushll v3.8h, v1.8b, #0
; CHECK-NODOT-NEXT: ushll v4.8h, v2.8b, #0
; CHECK-NODOT-NEXT: ushll2 v1.8h, v1.16b, #0
; CHECK-NODOT-NEXT: ushll2 v2.8h, v2.16b, #0
; CHECK-NODOT-NEXT: umlal v0.4s, v4.4h, v3.4h
; CHECK-NODOT-NEXT: umull v5.4s, v2.4h, v1.4h
; CHECK-NODOT-NEXT: umlal2 v0.4s, v2.8h, v1.8h
; CHECK-NODOT-NEXT: umlal2 v5.4s, v4.8h, v3.8h
; CHECK-NODOT-NEXT: add v0.4s, v5.4s, v0.4s
; CHECK-NODOT-NEXT: ret
%u.wide = zext <16 x i8> %u to <16 x i32>
%s.wide = zext <16 x i8> %s to <16 x i32>
@@ -35,17 +37,19 @@ define <2 x i32> @udot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) {
;
; CHECK-NODOT-LABEL: udot_narrow:
; CHECK-NODOT: // %bb.0:
; CHECK-NODOT-NEXT: umull v1.8h, v2.8b, v1.8b
; CHECK-NODOT-NEXT: ushll v1.8h, v1.8b, #0
; CHECK-NODOT-NEXT: ushll v2.8h, v2.8b, #0
; CHECK-NODOT-NEXT: // kill: def $d0 killed $d0 def $q0
; CHECK-NODOT-NEXT: ushll v2.4s, v1.4h, #0
; CHECK-NODOT-NEXT: ushll2 v3.4s, v1.8h, #0
; CHECK-NODOT-NEXT: ext v4.16b, v1.16b, v1.16b, #8
; CHECK-NODOT-NEXT: uaddw v0.4s, v0.4s, v1.4h
; CHECK-NODOT-NEXT: umull v3.4s, v2.4h, v1.4h
; CHECK-NODOT-NEXT: umull2 v4.4s, v2.8h, v1.8h
; CHECK-NODOT-NEXT: ext v5.16b, v1.16b, v1.16b, #8
; CHECK-NODOT-NEXT: ext v6.16b, v2.16b, v2.16b, #8
; CHECK-NODOT-NEXT: umlal v0.4s, v2.4h, v1.4h
; CHECK-NODOT-NEXT: ext v3.16b, v3.16b, v3.16b, #8
; CHECK-NODOT-NEXT: ext v2.16b, v2.16b, v2.16b, #8
; CHECK-NODOT-NEXT: add v0.2s, v3.2s, v0.2s
; CHECK-NODOT-NEXT: uaddw v1.4s, v2.4s, v4.4h
; CHECK-NODOT-NEXT: ext v1.16b, v4.16b, v4.16b, #8
; CHECK-NODOT-NEXT: umlal v3.4s, v6.4h, v5.4h
; CHECK-NODOT-NEXT: add v0.2s, v1.2s, v0.2s
; CHECK-NODOT-NEXT: add v0.2s, v3.2s, v0.2s
; CHECK-NODOT-NEXT: ret
%u.wide = zext <8 x i8> %u to <8 x i32>
%s.wide = zext <8 x i8> %s to <8 x i32>
@@ -62,13 +66,15 @@ define <4 x i32> @sdot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
;
; CHECK-NODOT-LABEL: sdot:
; CHECK-NODOT: // %bb.0:
; CHECK-NODOT-NEXT: smull v3.8h, v2.8b, v1.8b
; CHECK-NODOT-NEXT: smull2 v1.8h, v2.16b, v1.16b
; CHECK-NODOT-NEXT: sshll v2.4s, v1.4h, #0
; CHECK-NODOT-NEXT: saddw v0.4s, v0.4s, v3.4h
; CHECK-NODOT-NEXT: saddw2 v2.4s, v2.4s, v3.8h
; CHECK-NODOT-NEXT: saddw2 v0.4s, v0.4s, v1.8h
; CHECK-NODOT-NEXT: add v0.4s, v2.4s, v0.4s
; CHECK-NODOT-NEXT: sshll v3.8h, v1.8b, #0
; CHECK-NODOT-NEXT: sshll v4.8h, v2.8b, #0
; CHECK-NODOT-NEXT: sshll2 v1.8h, v1.16b, #0
; CHECK-NODOT-NEXT: sshll2 v2.8h, v2.16b, #0
; CHECK-NODOT-NEXT: smlal v0.4s, v4.4h, v3.4h
; CHECK-NODOT-NEXT: smull v5.4s, v2.4h, v1.4h
; CHECK-NODOT-NEXT: smlal2 v0.4s, v2.8h, v1.8h
; CHECK-NODOT-NEXT: smlal2 v5.4s, v4.8h, v3.8h
; CHECK-NODOT-NEXT: add v0.4s, v5.4s, v0.4s
; CHECK-NODOT-NEXT: ret
%u.wide = sext <16 x i8> %u to <16 x i32>
%s.wide = sext <16 x i8> %s to <16 x i32>
@@ -85,17 +91,19 @@ define <2 x i32> @sdot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) {
;
; CHECK-NODOT-LABEL: sdot_narrow:
; CHECK-NODOT: // %bb.0:
; CHECK-NODOT-NEXT: smull v1.8h, v2.8b, v1.8b
; CHECK-NODOT-NEXT: sshll v1.8h, v1.8b, #0
; CHECK-NODOT-NEXT: sshll v2.8h, v2.8b, #0
; CHECK-NODOT-NEXT: // kill: def $d0 killed $d0 def $q0
; CHECK-NODOT-NEXT: sshll v2.4s, v1.4h, #0
; CHECK-NODOT-NEXT: sshll2 v3.4s, v1.8h, #0
; CHECK-NODOT-NEXT: ext v4.16b, v1.16b, v1.16b, #8
; CHECK-NODOT-NEXT: saddw v0.4s, v0.4s, v1.4h
; CHECK-NODOT-NEXT: smull v3.4s, v2.4h, v1.4h
; CHECK-NODOT-NEXT: smull2 v4.4s, v2.8h, v1.8h
; CHECK-NODOT-NEXT: ext v5.16b, v1.16b, v1.16b, #8
; CHECK-NODOT-NEXT: ext v6.16b, v2.16b, v2.16b, #8
; CHECK-NODOT-NEXT: smlal v0.4s, v2.4h, v1.4h
; CHECK-NODOT-NEXT: ext v3.16b, v3.16b, v3.16b, #8
; CHECK-NODOT-NEXT: ext v2.16b, v2.16b, v2.16b, #8
; CHECK-NODOT-NEXT: add v0.2s, v3.2s, v0.2s
; CHECK-NODOT-NEXT: saddw v1.4s, v2.4s, v4.4h
; CHECK-NODOT-NEXT: ext v1.16b, v4.16b, v4.16b, #8
; CHECK-NODOT-NEXT: smlal v3.4s, v6.4h, v5.4h
; CHECK-NODOT-NEXT: add v0.2s, v1.2s, v0.2s
; CHECK-NODOT-NEXT: add v0.2s, v3.2s, v0.2s
; CHECK-NODOT-NEXT: ret
%u.wide = sext <8 x i8> %u to <8 x i32>
%s.wide = sext <8 x i8> %s to <8 x i32>
@@ -223,19 +231,27 @@ define <4 x i64> @udot_8to64(<4 x i64> %acc, <16 x i8> %a, <16 x i8> %b) {
;
; CHECK-NODOT-LABEL: udot_8to64:
; CHECK-NODOT: // %bb.0: // %entry
; CHECK-NODOT-NEXT: umull v4.8h, v2.8b, v3.8b
; CHECK-NODOT-NEXT: umull2 v2.8h, v2.16b, v3.16b
; CHECK-NODOT-NEXT: ushll v3.4s, v4.4h, #0
; CHECK-NODOT-NEXT: ushll v5.4s, v2.4h, #0
; CHECK-NODOT-NEXT: ushll v4.8h, v3.8b, #0
; CHECK-NODOT-NEXT: ushll v5.8h, v2.8b, #0
; CHECK-NODOT-NEXT: ushll2 v3.8h, v3.16b, #0
; CHECK-NODOT-NEXT: ushll2 v2.8h, v2.16b, #0
; CHECK-NODOT-NEXT: ushll v6.4s, v4.4h, #0
; CHECK-NODOT-NEXT: ushll v7.4s, v5.4h, #0
; CHECK-NODOT-NEXT: ushll2 v4.4s, v4.8h, #0
; CHECK-NODOT-NEXT: ushll2 v2.4s, v2.8h, #0
; CHECK-NODOT-NEXT: uaddw2 v1.2d, v1.2d, v3.4s
; CHECK-NODOT-NEXT: uaddw v0.2d, v0.2d, v3.2s
; CHECK-NODOT-NEXT: uaddl2 v3.2d, v4.4s, v5.4s
; CHECK-NODOT-NEXT: uaddl v4.2d, v4.2s, v5.2s
; CHECK-NODOT-NEXT: uaddw2 v1.2d, v1.2d, v2.4s
; CHECK-NODOT-NEXT: uaddw v0.2d, v0.2d, v2.2s
; CHECK-NODOT-NEXT: add v1.2d, v3.2d, v1.2d
; CHECK-NODOT-NEXT: ushll2 v5.4s, v5.8h, #0
; CHECK-NODOT-NEXT: ushll2 v16.4s, v3.8h, #0
; CHECK-NODOT-NEXT: ushll2 v17.4s, v2.8h, #0
; CHECK-NODOT-NEXT: ushll v3.4s, v3.4h, #0
; CHECK-NODOT-NEXT: ushll v2.4s, v2.4h, #0
; CHECK-NODOT-NEXT: umlal2 v1.2d, v7.4s, v6.4s
; CHECK-NODOT-NEXT: umlal v0.2d, v7.2s, v6.2s
; CHECK-NODOT-NEXT: umull2 v18.2d, v5.4s, v4.4s
; CHECK-NODOT-NEXT: umull v4.2d, v5.2s, v4.2s
; CHECK-NODOT-NEXT: umlal2 v1.2d, v17.4s, v16.4s
; CHECK-NODOT-NEXT: umlal v0.2d, v17.2s, v16.2s
; CHECK-NODOT-NEXT: umlal2 v18.2d, v2.4s, v3.4s
; CHECK-NODOT-NEXT: umlal v4.2d, v2.2s, v3.2s
; CHECK-NODOT-NEXT: add v1.2d, v18.2d, v1.2d
; CHECK-NODOT-NEXT: add v0.2d, v4.2d, v0.2d
; CHECK-NODOT-NEXT: ret
entry:
@@ -258,19 +274,27 @@ define <4 x i64> @sdot_8to64(<4 x i64> %acc, <16 x i8> %a, <16 x i8> %b){
;
; CHECK-NODOT-LABEL: sdot_8to64:
; CHECK-NODOT: // %bb.0: // %entry
; CHECK-NODOT-NEXT: smull v4.8h, v2.8b, v3.8b
; CHECK-NODOT-NEXT: smull2 v2.8h, v2.16b, v3.16b
; CHECK-NODOT-NEXT: sshll v3.4s, v4.4h, #0
; CHECK-NODOT-NEXT: sshll v5.4s, v2.4h, #0
; CHECK-NODOT-NEXT: sshll v4.8h, v3.8b, #0
Reviewer comment (Collaborator): I know these are regressions, but they'll be
addressed by follow-up patches that further improve this code-gen.
; CHECK-NODOT-NEXT: sshll v5.8h, v2.8b, #0
; CHECK-NODOT-NEXT: sshll2 v3.8h, v3.16b, #0
; CHECK-NODOT-NEXT: sshll2 v2.8h, v2.16b, #0
; CHECK-NODOT-NEXT: sshll v6.4s, v4.4h, #0
; CHECK-NODOT-NEXT: sshll v7.4s, v5.4h, #0
; CHECK-NODOT-NEXT: sshll2 v4.4s, v4.8h, #0
; CHECK-NODOT-NEXT: sshll2 v2.4s, v2.8h, #0
; CHECK-NODOT-NEXT: saddw2 v1.2d, v1.2d, v3.4s
; CHECK-NODOT-NEXT: saddw v0.2d, v0.2d, v3.2s
; CHECK-NODOT-NEXT: saddl2 v3.2d, v4.4s, v5.4s
; CHECK-NODOT-NEXT: saddl v4.2d, v4.2s, v5.2s
; CHECK-NODOT-NEXT: saddw2 v1.2d, v1.2d, v2.4s
; CHECK-NODOT-NEXT: saddw v0.2d, v0.2d, v2.2s
; CHECK-NODOT-NEXT: add v1.2d, v3.2d, v1.2d
; CHECK-NODOT-NEXT: sshll2 v5.4s, v5.8h, #0
; CHECK-NODOT-NEXT: sshll2 v16.4s, v3.8h, #0
; CHECK-NODOT-NEXT: sshll2 v17.4s, v2.8h, #0
; CHECK-NODOT-NEXT: sshll v3.4s, v3.4h, #0
; CHECK-NODOT-NEXT: sshll v2.4s, v2.4h, #0
; CHECK-NODOT-NEXT: smlal2 v1.2d, v7.4s, v6.4s
; CHECK-NODOT-NEXT: smlal v0.2d, v7.2s, v6.2s
; CHECK-NODOT-NEXT: smull2 v18.2d, v5.4s, v4.4s
; CHECK-NODOT-NEXT: smull v4.2d, v5.2s, v4.2s
; CHECK-NODOT-NEXT: smlal2 v1.2d, v17.4s, v16.4s
; CHECK-NODOT-NEXT: smlal v0.2d, v17.2s, v16.2s
; CHECK-NODOT-NEXT: smlal2 v18.2d, v2.4s, v3.4s
; CHECK-NODOT-NEXT: smlal v4.2d, v2.2s, v3.2s
; CHECK-NODOT-NEXT: add v1.2d, v18.2d, v1.2d
; CHECK-NODOT-NEXT: add v0.2d, v4.2d, v0.2d
; CHECK-NODOT-NEXT: ret
entry:
@@ -531,9 +555,10 @@ define <4 x i64> @sdot_no_bin_op_8to64(<4 x i64> %acc, <16 x i8> %a){
define <4 x i32> @not_udot(<4 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
; CHECK-LABEL: not_udot:
; CHECK: // %bb.0:
; CHECK-NEXT: umull v1.8h, v2.8b, v1.8b
; CHECK-NEXT: uaddw v0.4s, v0.4s, v1.4h
; CHECK-NEXT: uaddw2 v0.4s, v0.4s, v1.8h
; CHECK-NEXT: ushll v1.8h, v1.8b, #0
; CHECK-NEXT: ushll v2.8h, v2.8b, #0
; CHECK-NEXT: umlal v0.4s, v2.4h, v1.4h
; CHECK-NEXT: umlal2 v0.4s, v2.8h, v1.8h
; CHECK-NEXT: ret
%u.wide = zext <8 x i8> %u to <8 x i32>
%s.wide = zext <8 x i8> %s to <8 x i32>