diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index ea1435c3934be..2edbda0418f34 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -12612,47 +12612,63 @@ SDValue DAGCombiner::visitMHISTOGRAM(SDNode *N) {
   return SDValue();
 }
 
-// Makes PARTIAL_REDUCE_*MLA(Acc, MUL(ZEXT(LHSExtOp), ZEXT(RHSExtOp)),
-// Splat(1)) into
-// PARTIAL_REDUCE_UMLA(Acc, LHSExtOp, RHSExtOp).
-// Makes PARTIAL_REDUCE_*MLA(Acc, MUL(SEXT(LHSExtOp), SEXT(RHSExtOp)),
-// Splat(1)) into
-// PARTIAL_REDUCE_SMLA(Acc, LHSExtOp, RHSExtOp).
+// partial_reduce_*mla(acc, mul(ext(a), ext(b)), splat(1))
+// -> partial_reduce_*mla(acc, a, b)
+//
+// partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
+// -> partial_reduce_*mla(acc, x, C)
 SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
   SDLoc DL(N);
-
+  auto *Context = DAG.getContext();
   SDValue Acc = N->getOperand(0);
   SDValue Op1 = N->getOperand(1);
   SDValue Op2 = N->getOperand(2);
 
-  APInt ConstantOne;
+  APInt C;
   if (Op1->getOpcode() != ISD::MUL ||
-      !ISD::isConstantSplatVector(Op2.getNode(), ConstantOne) ||
-      !ConstantOne.isOne())
+      !ISD::isConstantSplatVector(Op2.getNode(), C) || !C.isOne())
     return SDValue();
 
   SDValue LHS = Op1->getOperand(0);
   SDValue RHS = Op1->getOperand(1);
   unsigned LHSOpcode = LHS->getOpcode();
-  unsigned RHSOpcode = RHS->getOpcode();
-  if (!ISD::isExtOpcode(LHSOpcode) || !ISD::isExtOpcode(RHSOpcode))
+  if (!ISD::isExtOpcode(LHSOpcode))
     return SDValue();
 
   SDValue LHSExtOp = LHS->getOperand(0);
-  SDValue RHSExtOp = RHS->getOperand(0);
   EVT LHSExtOpVT = LHSExtOp.getValueType();
-  if (LHSExtOpVT != RHSExtOp.getValueType() || LHSOpcode != RHSOpcode)
-    return SDValue();
 
-  // Only perform the DAG combine if there is custom lowering provided by the
-  // target
-  auto *Context = DAG.getContext();
+  // Only perform these combines if the target supports folding
+  // the extends into the operation.
   if (!TLI.isPartialReduceMLALegalOrCustom(
           TLI.getTypeToTransformTo(*Context, N->getValueType(0)),
           TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
     return SDValue();
 
   bool ExtIsSigned = LHSOpcode == ISD::SIGN_EXTEND;
+  unsigned NewOpcode =
+      ExtIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
+
+  // partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
+  // -> partial_reduce_*mla(acc, x, C)
+  if (ISD::isConstantSplatVector(RHS.getNode(), C)) {
+    APInt CTrunc = C.trunc(LHSExtOpVT.getScalarSizeInBits());
+    unsigned LHSBits = LHS.getValueType().getScalarSizeInBits();
+    if ((LHSOpcode != ISD::ZERO_EXTEND || CTrunc.zext(LHSBits) != C) &&
+        (LHSOpcode != ISD::SIGN_EXTEND || CTrunc.sext(LHSBits) != C))
+      return SDValue();
+
+    return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp,
+                       DAG.getConstant(CTrunc, DL, LHSExtOpVT));
+  }
+
+  unsigned RHSOpcode = RHS->getOpcode();
+  if (!ISD::isExtOpcode(RHSOpcode))
+    return SDValue();
+
+  SDValue RHSExtOp = RHS->getOperand(0);
+  if (LHSExtOpVT != RHSExtOp.getValueType() || LHSOpcode != RHSOpcode)
+    return SDValue();
 
   // For a 2-stage extend the signedness of both of the extends must be the
   // same. This is so the node can be folded into only a signed or unsigned
@@ -12663,8 +12679,6 @@ SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
       Op1.getValueType().getVectorElementType() != AccElemVT)
     return SDValue();
 
-  unsigned NewOpcode =
-      ExtIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
   return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp,
                      RHSExtOp);
 }
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
index 039cac01008b8..5326bccbbc3d5 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
@@ -1139,7 +1139,6 @@ entry:
   ret %partial.reduce
 }
 
-
 define @partial_reduce_only_split_acc( %acc, %a, %b) {
 ; CHECK-LABEL: partial_reduce_only_split_acc:
 ; CHECK:       // %bb.0: // %entry
@@ -1178,3 +1177,145 @@ entry:
   %acc, %mult)
   ret %partial.reduce
 }
+
+define <vscale x 4 x i32> @sdot_imm(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a) {
+; CHECK-LABEL: sdot_imm:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    sunpklo z2.h, z1.b
+; CHECK-NEXT:    sunpkhi z1.h, z1.b
+; CHECK-NEXT:    sunpklo z3.s, z2.h
+; CHECK-NEXT:    sunpkhi z2.s, z2.h
+; CHECK-NEXT:    sub z0.s, z0.s, z3.s
+; CHECK-NEXT:    sunpklo z3.s, z1.h
+; CHECK-NEXT:    sunpkhi z1.s, z1.h
+; CHECK-NEXT:    sub z0.s, z0.s, z2.s
+; CHECK-NEXT:    sub z0.s, z0.s, z3.s
+; CHECK-NEXT:    sub z0.s, z0.s, z1.s
+; CHECK-NEXT:    ret
+;
+; CHECK-NEWLOWERING-LABEL: sdot_imm:
+; CHECK-NEWLOWERING:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT:    mov z2.b, #-1 // =0xffffffffffffffff
+; CHECK-NEWLOWERING-NEXT:    sdot z0.s, z1.b, z2.b
+; CHECK-NEWLOWERING-NEXT:    ret
+entry:
+  %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+  %mult = mul nuw nsw <vscale x 16 x i32> %a.wide, splat(i32 -1)
+  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %mult)
+  ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 4 x i32> @sdot_imm_does_not_fit(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a) {
+; CHECK-LABEL: sdot_imm_does_not_fit:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    sunpklo z2.h, z1.b
+; CHECK-NEXT:    sunpkhi z1.h, z1.b
+; CHECK-NEXT:    sunpklo z3.s, z2.h
+; CHECK-NEXT:    sunpkhi z2.s, z2.h
+; CHECK-NEXT:    sunpklo z4.s, z1.h
+; CHECK-NEXT:    sunpkhi z1.s, z1.h
+; CHECK-NEXT:    lsl z4.s, z4.s, #8
+; CHECK-NEXT:    lsl z2.s, z2.s, #8
+; CHECK-NEXT:    lsl z3.s, z3.s, #8
+; CHECK-NEXT:    lsl z1.s, z1.s, #8
+; CHECK-NEXT:    add z0.s, z0.s, z3.s
+; CHECK-NEXT:    add z2.s, z2.s, z4.s
+; CHECK-NEXT:    add z0.s, z0.s, z2.s
+; CHECK-NEXT:    add z0.s, z0.s, z1.s
+; CHECK-NEXT:    ret
+;
+; CHECK-NEWLOWERING-LABEL: sdot_imm_does_not_fit:
+; CHECK-NEWLOWERING:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT:    sunpklo z2.h, z1.b
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z1.h, z1.b
+; CHECK-NEWLOWERING-NEXT:    sunpklo z3.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    sunpklo z4.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    lsl z4.s, z4.s, #8
+; CHECK-NEWLOWERING-NEXT:    lsl z2.s, z2.s, #8
+; CHECK-NEWLOWERING-NEXT:    lsl z3.s, z3.s, #8
+; CHECK-NEWLOWERING-NEXT:    lsl z1.s, z1.s, #8
+; CHECK-NEWLOWERING-NEXT:    add z0.s, z0.s, z3.s
+; CHECK-NEWLOWERING-NEXT:    add z2.s, z2.s, z4.s
+; CHECK-NEWLOWERING-NEXT:    add z0.s, z0.s, z2.s
+; CHECK-NEWLOWERING-NEXT:    add z0.s, z0.s, z1.s
+; CHECK-NEWLOWERING-NEXT:    ret
+entry:
+  %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+  %mult = mul nuw nsw <vscale x 16 x i32> %a.wide, splat(i32 256)
+  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %mult)
+  ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 4 x i32> @udot_imm(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a) {
+; CHECK-LABEL: udot_imm:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    uunpklo z3.h, z1.b
+; CHECK-NEXT:    mov z2.s, #255 // =0xff
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    uunpkhi z1.h, z1.b
+; CHECK-NEXT:    uunpklo z4.s, z3.h
+; CHECK-NEXT:    uunpkhi z3.s, z3.h
+; CHECK-NEXT:    mla z0.s, p0/m, z4.s, z2.s
+; CHECK-NEXT:    uunpklo z4.s, z1.h
+; CHECK-NEXT:    uunpkhi z1.s, z1.h
+; CHECK-NEXT:    mla z0.s, p0/m, z3.s, z2.s
+; CHECK-NEXT:    mla z0.s, p0/m, z4.s, z2.s
+; CHECK-NEXT:    mla z0.s, p0/m, z1.s, z2.s
+; CHECK-NEXT:    ret
+;
+; CHECK-NEWLOWERING-LABEL: udot_imm:
+; CHECK-NEWLOWERING:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT:    mov z2.b, #-1 // =0xffffffffffffffff
+; CHECK-NEWLOWERING-NEXT:    udot z0.s, z1.b, z2.b
+; CHECK-NEWLOWERING-NEXT:    ret
+entry:
+  %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+  %mult = mul nuw nsw <vscale x 16 x i32> %a.wide, splat(i32 255)
+  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %mult)
+  ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 4 x i32> @udot_imm_does_not_fit(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a) {
+; CHECK-LABEL: udot_imm_does_not_fit:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    uunpklo z2.h, z1.b
+; CHECK-NEXT:    uunpkhi z1.h, z1.b
+; CHECK-NEXT:    uunpklo z3.s, z2.h
+; CHECK-NEXT:    uunpkhi z2.s, z2.h
+; CHECK-NEXT:    uunpklo z4.s, z1.h
+; CHECK-NEXT:    uunpkhi z1.s, z1.h
+; CHECK-NEXT:    lsl z4.s, z4.s, #8
+; CHECK-NEXT:    lsl z2.s, z2.s, #8
+; CHECK-NEXT:    lsl z3.s, z3.s, #8
+; CHECK-NEXT:    lsl z1.s, z1.s, #8
+; CHECK-NEXT:    add z0.s, z0.s, z3.s
+; CHECK-NEXT:    add z2.s, z2.s, z4.s
+; CHECK-NEXT:    add z0.s, z0.s, z2.s
+; CHECK-NEXT:    add z0.s, z0.s, z1.s
+; CHECK-NEXT:    ret
+;
+; CHECK-NEWLOWERING-LABEL: udot_imm_does_not_fit:
+; CHECK-NEWLOWERING:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT:    uunpklo z2.h, z1.b
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z1.h, z1.b
+; CHECK-NEWLOWERING-NEXT:    uunpklo z3.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    uunpklo z4.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    lsl z4.s, z4.s, #8
+; CHECK-NEWLOWERING-NEXT:    lsl z2.s, z2.s, #8
+; CHECK-NEWLOWERING-NEXT:    lsl z3.s, z3.s, #8
+; CHECK-NEWLOWERING-NEXT:    lsl z1.s, z1.s, #8
+; CHECK-NEWLOWERING-NEXT:    add z0.s, z0.s, z3.s
+; CHECK-NEWLOWERING-NEXT:    add z2.s, z2.s, z4.s
+; CHECK-NEWLOWERING-NEXT:    add z0.s, z0.s, z2.s
+; CHECK-NEWLOWERING-NEXT:    add z0.s, z0.s, z1.s
+; CHECK-NEWLOWERING-NEXT:    ret
+entry:
+  %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+  %mult = mul nuw nsw <vscale x 16 x i32> %a.wide, splat(i32 256)
+  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %mult)
+  ret <vscale x 4 x i32> %partial.reduce
+}
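
Note (not part of the patch): the constant-splat path of the combine only fires when the splat value survives a round trip through the narrow element type under the matching extension, which is why splat(i32 255) and splat(i32 -1) fold to an i8 splat above while splat(i32 256) does not. Below is a minimal standalone sketch of that round-trip check using LLVM's APInt, with the i8 -> i32 widths assumed from the tests; it is illustrative only and simply mirrors the CTrunc logic in visitPARTIAL_REDUCE_MLA.

    // Illustrative sketch only: the CTrunc round-trip check for an i8 source
    // extended to i32, as exercised by the sdot_imm/udot_imm tests.
    #include "llvm/ADT/APInt.h"
    #include <cassert>

    int main() {
      using llvm::APInt;
      // zext case: 255 round-trips through i8, so udot_imm can narrow the
      // constant to an i8 splat (printed as -1 in the CHECK lines).
      assert(APInt(32, 255).trunc(8).zext(32) == APInt(32, 255));
      // zext case: 256 truncates to 0, so udot_imm_does_not_fit is left alone.
      assert(APInt(32, 256).trunc(8).zext(32) != APInt(32, 256));
      // sext case: -1 round-trips through i8, so sdot_imm can narrow to
      // splat(i8 -1) and lower to sdot.
      assert(APInt(32, -1, /*isSigned=*/true).trunc(8).sext(32) ==
             APInt(32, -1, /*isSigned=*/true));
      return 0;
    }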