Skip to content

Commit a20bced

Browse files
committed
[AArch64][SVE] Add lowering for PARTIAL_REDUCE_U/SMLA to USDOT
Add lowering for PARTIAL_REDUCE_UMLA/PARTIAL_REDUCE_SMLA nodes to USDOT instructions. The lowering applies when the second operand of the ISD node is a MUL whose two operands are extends of differing signedness (one zero-extended, the other sign-extended).
1 parent 8d9f516 commit a20bced

File tree

4 files changed

+107
-111
lines changed

4 files changed

+107
-111
lines changed

llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -924,8 +924,19 @@ SDValue DAGTypeLegalizer::CreateStackStoreLoad(SDValue Op,
924924
/// illegal ResNo in that case.
925925
bool DAGTypeLegalizer::CustomLowerNode(SDNode *N, EVT VT, bool LegalizeResult) {
926926
// See if the target wants to custom lower this node.
927-
if (TLI.getOperationAction(N->getOpcode(), VT) != TargetLowering::Custom)
928-
return false;
927+
unsigned Opcode = N->getOpcode();
928+
bool IsPRMLAOpcode =
929+
Opcode == ISD::PARTIAL_REDUCE_UMLA || Opcode == ISD::PARTIAL_REDUCE_SMLA;
930+
931+
if (IsPRMLAOpcode) {
932+
if (TLI.getPartialReduceMLAAction(N->getValueType(0),
933+
N->getOperand(1).getValueType()) !=
934+
TargetLowering::Custom)
935+
return false;
936+
} else {
937+
if (TLI.getOperationAction(Opcode, VT) != TargetLowering::Custom)
938+
return false;
939+
}
929940

930941
SmallVector<SDValue, 8> Results;
931942
if (LegalizeResult)

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 79 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7742,8 +7742,8 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
77427742
return LowerFLDEXP(Op, DAG);
77437743
case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
77447744
return LowerVECTOR_HISTOGRAM(Op, DAG);
7745-
case ISD::PARTIAL_REDUCE_SMLA:
77467745
case ISD::PARTIAL_REDUCE_UMLA:
7746+
case ISD::PARTIAL_REDUCE_SMLA:
77477747
return LowerPARTIAL_REDUCE_MLA(Op, DAG);
77487748
}
77497749
}
@@ -27533,6 +27533,10 @@ void AArch64TargetLowering::ReplaceNodeResults(
2753327533
if (SDValue Res = LowerVECTOR_COMPRESS(SDValue(N, 0), DAG))
2753427534
Results.push_back(Res);
2753527535
return;
27536+
case ISD::PARTIAL_REDUCE_UMLA:
27537+
case ISD::PARTIAL_REDUCE_SMLA:
27538+
Results.push_back(LowerPARTIAL_REDUCE_MLA(SDValue(N, 0), DAG));
27539+
return;
2753627540
case ISD::ADD:
2753727541
case ISD::FADD:
2753827542
ReplaceAddWithADDP(N, Results, DAG, Subtarget);
@@ -29515,6 +29519,80 @@ AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
2951529519
return DAG.getNode(ISD::ADD, DL, ResultVT, Acc, Extended);
2951629520
}
2951729521

29522+
// Lower PARTIAL_REDUCE_*MLA(Acc, MUL(ZEXT(MulOpLHS), SEXT(MulOpRHS)), Splat 1)
29523+
// to USDOT(Acc, MulOpLHS, MulOpRHS)
29524+
// Lower PARTIAL_REDUCE_*MLA(Acc, MUL(SEXT(MulOpLHS), ZEXT(MulOpRHS)), Splat 1)
29525+
// to USDOT(Acc, MulOpRHS, MulOpLHS)
29526+
SDValue
29527+
AArch64TargetLowering::LowerPARTIAL_REDUCE_MLAToUSDOT(SDValue Op,
29528+
SelectionDAG &DAG) const {
29529+
bool Scalable = Op.getValueType().isScalableVector();
29530+
auto &Subtarget = DAG.getSubtarget<AArch64Subtarget>();
29531+
if (Scalable && !Subtarget.isSVEorStreamingSVEAvailable())
29532+
return SDValue();
29533+
if (!Scalable && (!Subtarget.isNeonAvailable() || !Subtarget.hasDotProd()))
29534+
return SDValue();
29535+
if (!Subtarget.hasMatMulInt8())
29536+
return SDValue();
29537+
SDLoc DL(Op);
29538+
29539+
if (Op.getOperand(1).getOpcode() != ISD::MUL)
29540+
return SDValue();
29541+
29542+
SDValue Acc = Op.getOperand(0);
29543+
SDValue Mul = Op.getOperand(1);
29544+
29545+
APInt ConstantOne;
29546+
if (!ISD::isConstantSplatVector(Op.getOperand(2).getNode(), ConstantOne) ||
29547+
!ConstantOne.isOne())
29548+
return SDValue();
29549+
29550+
SDValue ExtMulOpLHS = Mul.getOperand(0);
29551+
SDValue ExtMulOpRHS = Mul.getOperand(1);
29552+
unsigned ExtMulOpLHSOpcode = ExtMulOpLHS.getOpcode();
29553+
unsigned ExtMulOpRHSOpcode = ExtMulOpRHS.getOpcode();
29554+
if (!ISD::isExtOpcode(ExtMulOpLHSOpcode) ||
29555+
!ISD::isExtOpcode(ExtMulOpRHSOpcode))
29556+
return SDValue();
29557+
29558+
SDValue MulOpLHS = ExtMulOpLHS.getOperand(0);
29559+
SDValue MulOpRHS = ExtMulOpRHS.getOperand(0);
29560+
EVT MulOpLHSVT = MulOpLHS.getValueType();
29561+
if (MulOpLHSVT != MulOpRHS.getValueType())
29562+
return SDValue();
29563+
29564+
bool LHSIsSigned = ExtMulOpLHSOpcode == ISD::SIGN_EXTEND;
29565+
bool RHSIsSigned = ExtMulOpRHSOpcode == ISD::SIGN_EXTEND;
29566+
if (LHSIsSigned == RHSIsSigned)
29567+
return SDValue();
29568+
29569+
EVT AccVT = Acc.getValueType();
29570+
// There is no nxv2i64 version of usdot
29571+
if (Scalable && AccVT != MVT::nxv4i32 && AccVT != MVT::nxv4i64)
29572+
return SDValue();
29573+
29574+
// USDOT expects the signed operand to be last
29575+
if (!RHSIsSigned)
29576+
std::swap(MulOpLHS, MulOpRHS);
29577+
29578+
unsigned Opcode = AArch64ISD::USDOT;
29579+
// Partial reduction lowering for (nx)v16i8 to (nx)v4i64 requires an i32 dot
29580+
// product followed by a zero / sign extension
29581+
// Don't want this to be split because there is no nxv2i64 version of usdot
29582+
if ((AccVT == MVT::nxv4i64 && MulOpLHSVT == MVT::nxv16i8) ||
29583+
(AccVT == MVT::v4i64 && MulOpLHSVT == MVT::v16i8)) {
29584+
EVT AccVTI32 = (AccVT.isScalableVector()) ? MVT::nxv4i32 : MVT::v4i32;
29585+
29586+
SDValue DotI32 =
29587+
DAG.getNode(Opcode, DL, AccVTI32, DAG.getConstant(0, DL, AccVTI32),
29588+
MulOpLHS, MulOpRHS);
29589+
SDValue Extended = DAG.getSExtOrTrunc(DotI32, DL, AccVT);
29590+
return DAG.getNode(ISD::ADD, DL, AccVT, Acc, Extended);
29591+
}
29592+
29593+
return DAG.getNode(Opcode, DL, AccVT, Acc, MulOpLHS, MulOpRHS);
29594+
}
29595+
2951829596
SDValue
2951929597
AArch64TargetLowering::LowerFixedLengthFPToIntToSVE(SDValue Op,
2952029598
SelectionDAG &DAG) const {

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1182,6 +1182,7 @@ class AArch64TargetLowering : public TargetLowering {
11821182
SDValue LowerVECTOR_INTERLEAVE(SDValue Op, SelectionDAG &DAG) const;
11831183
SDValue LowerVECTOR_HISTOGRAM(SDValue Op, SelectionDAG &DAG) const;
11841184
SDValue LowerPARTIAL_REDUCE_MLA(SDValue Op, SelectionDAG &DAG) const;
1185+
SDValue LowerPARTIAL_REDUCE_MLAToUSDOT(SDValue Op, SelectionDAG &DAG) const;
11851186
SDValue LowerDIV(SDValue Op, SelectionDAG &DAG) const;
11861187
SDValue LowerMUL(SDValue Op, SelectionDAG &DAG) const;
11871188
SDValue LowerVectorSRA_SRL_SHL(SDValue Op, SelectionDAG &DAG) const;

llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll

Lines changed: 14 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -106,23 +106,7 @@ define <vscale x 4 x i32> @usdot(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a,
106106
;
107107
; CHECK-NEWLOWERING-LABEL: usdot:
108108
; CHECK-NEWLOWERING: // %bb.0: // %entry
109-
; CHECK-NEWLOWERING-NEXT: uunpklo z3.h, z1.b
110-
; CHECK-NEWLOWERING-NEXT: sunpklo z4.h, z2.b
111-
; CHECK-NEWLOWERING-NEXT: ptrue p0.s
112-
; CHECK-NEWLOWERING-NEXT: uunpkhi z1.h, z1.b
113-
; CHECK-NEWLOWERING-NEXT: sunpkhi z2.h, z2.b
114-
; CHECK-NEWLOWERING-NEXT: uunpklo z5.s, z3.h
115-
; CHECK-NEWLOWERING-NEXT: sunpklo z6.s, z4.h
116-
; CHECK-NEWLOWERING-NEXT: uunpkhi z3.s, z3.h
117-
; CHECK-NEWLOWERING-NEXT: sunpkhi z4.s, z4.h
118-
; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z5.s, z6.s
119-
; CHECK-NEWLOWERING-NEXT: uunpklo z5.s, z1.h
120-
; CHECK-NEWLOWERING-NEXT: sunpklo z6.s, z2.h
121-
; CHECK-NEWLOWERING-NEXT: uunpkhi z1.s, z1.h
122-
; CHECK-NEWLOWERING-NEXT: sunpkhi z2.s, z2.h
123-
; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z3.s, z4.s
124-
; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z5.s, z6.s
125-
; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z1.s, z2.s
109+
; CHECK-NEWLOWERING-NEXT: usdot z0.s, z1.b, z2.b
126110
; CHECK-NEWLOWERING-NEXT: ret
127111
entry:
128112
%a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
@@ -161,23 +145,7 @@ define <vscale x 4 x i32> @sudot(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a,
161145
;
162146
; CHECK-NEWLOWERING-LABEL: sudot:
163147
; CHECK-NEWLOWERING: // %bb.0: // %entry
164-
; CHECK-NEWLOWERING-NEXT: sunpklo z3.h, z1.b
165-
; CHECK-NEWLOWERING-NEXT: uunpklo z4.h, z2.b
166-
; CHECK-NEWLOWERING-NEXT: ptrue p0.s
167-
; CHECK-NEWLOWERING-NEXT: sunpkhi z1.h, z1.b
168-
; CHECK-NEWLOWERING-NEXT: uunpkhi z2.h, z2.b
169-
; CHECK-NEWLOWERING-NEXT: sunpklo z5.s, z3.h
170-
; CHECK-NEWLOWERING-NEXT: uunpklo z6.s, z4.h
171-
; CHECK-NEWLOWERING-NEXT: sunpkhi z3.s, z3.h
172-
; CHECK-NEWLOWERING-NEXT: uunpkhi z4.s, z4.h
173-
; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z5.s, z6.s
174-
; CHECK-NEWLOWERING-NEXT: sunpklo z5.s, z1.h
175-
; CHECK-NEWLOWERING-NEXT: uunpklo z6.s, z2.h
176-
; CHECK-NEWLOWERING-NEXT: sunpkhi z1.s, z1.h
177-
; CHECK-NEWLOWERING-NEXT: uunpkhi z2.s, z2.h
178-
; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z3.s, z4.s
179-
; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z5.s, z6.s
180-
; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z1.s, z2.s
148+
; CHECK-NEWLOWERING-NEXT: usdot z0.s, z2.b, z1.b
181149
; CHECK-NEWLOWERING-NEXT: ret
182150
entry:
183151
%a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
@@ -331,43 +299,12 @@ define <vscale x 4 x i64> @usdot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i
331299
;
332300
; CHECK-NEWLOWERING-LABEL: usdot_8to64:
333301
; CHECK-NEWLOWERING: // %bb.0: // %entry
334-
; CHECK-NEWLOWERING-NEXT: uunpkhi z4.h, z2.b
335-
; CHECK-NEWLOWERING-NEXT: uunpklo z2.h, z2.b
336-
; CHECK-NEWLOWERING-NEXT: sunpkhi z5.h, z3.b
337-
; CHECK-NEWLOWERING-NEXT: sunpklo z3.h, z3.b
338-
; CHECK-NEWLOWERING-NEXT: ptrue p0.d
339-
; CHECK-NEWLOWERING-NEXT: uunpklo z6.s, z4.h
340-
; CHECK-NEWLOWERING-NEXT: uunpklo z7.s, z2.h
341-
; CHECK-NEWLOWERING-NEXT: sunpklo z24.s, z5.h
342-
; CHECK-NEWLOWERING-NEXT: sunpklo z25.s, z3.h
343-
; CHECK-NEWLOWERING-NEXT: uunpkhi z4.s, z4.h
344-
; CHECK-NEWLOWERING-NEXT: uunpkhi z2.s, z2.h
345-
; CHECK-NEWLOWERING-NEXT: sunpkhi z5.s, z5.h
346-
; CHECK-NEWLOWERING-NEXT: sunpkhi z3.s, z3.h
347-
; CHECK-NEWLOWERING-NEXT: uunpklo z26.d, z6.s
348-
; CHECK-NEWLOWERING-NEXT: uunpklo z27.d, z7.s
349-
; CHECK-NEWLOWERING-NEXT: sunpklo z28.d, z24.s
350-
; CHECK-NEWLOWERING-NEXT: sunpklo z29.d, z25.s
351-
; CHECK-NEWLOWERING-NEXT: uunpkhi z6.d, z6.s
352-
; CHECK-NEWLOWERING-NEXT: uunpkhi z7.d, z7.s
353-
; CHECK-NEWLOWERING-NEXT: sunpkhi z24.d, z24.s
354-
; CHECK-NEWLOWERING-NEXT: sunpkhi z25.d, z25.s
355-
; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z26.d, z28.d
356-
; CHECK-NEWLOWERING-NEXT: uunpklo z26.d, z4.s
357-
; CHECK-NEWLOWERING-NEXT: sunpklo z28.d, z5.s
358-
; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z27.d, z29.d
359-
; CHECK-NEWLOWERING-NEXT: uunpklo z27.d, z2.s
360-
; CHECK-NEWLOWERING-NEXT: sunpklo z29.d, z3.s
361-
; CHECK-NEWLOWERING-NEXT: uunpkhi z4.d, z4.s
362-
; CHECK-NEWLOWERING-NEXT: uunpkhi z2.d, z2.s
363-
; CHECK-NEWLOWERING-NEXT: sunpkhi z5.d, z5.s
364-
; CHECK-NEWLOWERING-NEXT: sunpkhi z3.d, z3.s
365-
; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z6.d, z24.d
366-
; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z7.d, z25.d
367-
; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z26.d, z28.d
368-
; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z27.d, z29.d
369-
; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z4.d, z5.d
370-
; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z2.d, z3.d
302+
; CHECK-NEWLOWERING-NEXT: mov z4.s, #0 // =0x0
303+
; CHECK-NEWLOWERING-NEXT: usdot z4.s, z2.b, z3.b
304+
; CHECK-NEWLOWERING-NEXT: sunpklo z2.d, z4.s
305+
; CHECK-NEWLOWERING-NEXT: sunpkhi z3.d, z4.s
306+
; CHECK-NEWLOWERING-NEXT: add z0.d, z0.d, z2.d
307+
; CHECK-NEWLOWERING-NEXT: add z1.d, z1.d, z3.d
371308
; CHECK-NEWLOWERING-NEXT: ret
372309
entry:
373310
%a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
@@ -432,43 +369,12 @@ define <vscale x 4 x i64> @sudot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i
432369
;
433370
; CHECK-NEWLOWERING-LABEL: sudot_8to64:
434371
; CHECK-NEWLOWERING: // %bb.0: // %entry
435-
; CHECK-NEWLOWERING-NEXT: sunpkhi z4.h, z2.b
436-
; CHECK-NEWLOWERING-NEXT: sunpklo z2.h, z2.b
437-
; CHECK-NEWLOWERING-NEXT: uunpkhi z5.h, z3.b
438-
; CHECK-NEWLOWERING-NEXT: uunpklo z3.h, z3.b
439-
; CHECK-NEWLOWERING-NEXT: ptrue p0.d
440-
; CHECK-NEWLOWERING-NEXT: sunpklo z6.s, z4.h
441-
; CHECK-NEWLOWERING-NEXT: sunpklo z7.s, z2.h
442-
; CHECK-NEWLOWERING-NEXT: uunpklo z24.s, z5.h
443-
; CHECK-NEWLOWERING-NEXT: uunpklo z25.s, z3.h
444-
; CHECK-NEWLOWERING-NEXT: sunpkhi z4.s, z4.h
445-
; CHECK-NEWLOWERING-NEXT: sunpkhi z2.s, z2.h
446-
; CHECK-NEWLOWERING-NEXT: uunpkhi z5.s, z5.h
447-
; CHECK-NEWLOWERING-NEXT: uunpkhi z3.s, z3.h
448-
; CHECK-NEWLOWERING-NEXT: sunpklo z26.d, z6.s
449-
; CHECK-NEWLOWERING-NEXT: sunpklo z27.d, z7.s
450-
; CHECK-NEWLOWERING-NEXT: uunpklo z28.d, z24.s
451-
; CHECK-NEWLOWERING-NEXT: uunpklo z29.d, z25.s
452-
; CHECK-NEWLOWERING-NEXT: sunpkhi z6.d, z6.s
453-
; CHECK-NEWLOWERING-NEXT: sunpkhi z7.d, z7.s
454-
; CHECK-NEWLOWERING-NEXT: uunpkhi z24.d, z24.s
455-
; CHECK-NEWLOWERING-NEXT: uunpkhi z25.d, z25.s
456-
; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z26.d, z28.d
457-
; CHECK-NEWLOWERING-NEXT: sunpklo z26.d, z4.s
458-
; CHECK-NEWLOWERING-NEXT: uunpklo z28.d, z5.s
459-
; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z27.d, z29.d
460-
; CHECK-NEWLOWERING-NEXT: sunpklo z27.d, z2.s
461-
; CHECK-NEWLOWERING-NEXT: uunpklo z29.d, z3.s
462-
; CHECK-NEWLOWERING-NEXT: sunpkhi z4.d, z4.s
463-
; CHECK-NEWLOWERING-NEXT: sunpkhi z2.d, z2.s
464-
; CHECK-NEWLOWERING-NEXT: uunpkhi z5.d, z5.s
465-
; CHECK-NEWLOWERING-NEXT: uunpkhi z3.d, z3.s
466-
; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z6.d, z24.d
467-
; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z7.d, z25.d
468-
; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z26.d, z28.d
469-
; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z27.d, z29.d
470-
; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z4.d, z5.d
471-
; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z2.d, z3.d
372+
; CHECK-NEWLOWERING-NEXT: mov z4.s, #0 // =0x0
373+
; CHECK-NEWLOWERING-NEXT: usdot z4.s, z3.b, z2.b
374+
; CHECK-NEWLOWERING-NEXT: sunpklo z2.d, z4.s
375+
; CHECK-NEWLOWERING-NEXT: sunpkhi z3.d, z4.s
376+
; CHECK-NEWLOWERING-NEXT: add z0.d, z0.d, z2.d
377+
; CHECK-NEWLOWERING-NEXT: add z1.d, z1.d, z3.d
472378
; CHECK-NEWLOWERING-NEXT: ret
473379
entry:
474380
%a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i64>

0 commit comments

Comments
 (0)