Skip to content

Commit 24a613f

Browse files
[AArch64][SVE] Add dot product lowering for PARTIAL_REDUCE_MLA node
Add lowering in tablegen for PARTIAL_REDUCE_U/SMLA ISD nodes. Only happens when the combine has been performed on the ISD node. Also adds in check to only do the DAG combine when the node can then eventually be lowered, so changes neon tests too.
1 parent e694bcf commit 24a613f

File tree

10 files changed

+176
-240
lines changed

10 files changed

+176
-240
lines changed

llvm/include/llvm/CodeGen/TargetLowering.h

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1639,6 +1639,25 @@ class TargetLoweringBase {
16391639
getCondCodeAction(CC, VT) == Custom;
16401640
}
16411641

1642+
/// Return how a PARTIAL_REDUCE_U/SMLA node with Acc type AccVT and Input type
1643+
/// InputVT should be treated. Either it's legal, needs to be promoted to a
1644+
/// larger size, needs to be expanded to some other code sequence, or the
1645+
/// target has a custom expander for it.
1646+
LegalizeAction getPartialReduceMLAAction(EVT AccVT, EVT InputVT) const {
1647+
unsigned AccI = (unsigned)AccVT.getSimpleVT().SimpleTy;
1648+
unsigned InputI = (unsigned)InputVT.getSimpleVT().SimpleTy;
1649+
assert(AccI < MVT::VALUETYPE_SIZE && InputI < MVT::VALUETYPE_SIZE &&
1650+
"Table isn't big enough!");
1651+
return PartialReduceMLAActions[AccI][InputI];
1652+
}
1653+
1654+
/// Return true if a PARTIAL_REDUCE_U/SMLA node with the specified types is
1655+
/// legal or custom for this target.
1656+
bool isPartialReduceMLALegalOrCustom(EVT AccVT, EVT InputVT) const {
1657+
return getPartialReduceMLAAction(AccVT, InputVT) == Legal ||
1658+
getPartialReduceMLAAction(AccVT, InputVT) == Custom;
1659+
}
1660+
16421661
/// If the action for this operation is to promote, this method returns the
16431662
/// ValueType to promote to.
16441663
MVT getTypeToPromoteTo(unsigned Op, MVT VT) const {
@@ -2704,6 +2723,16 @@ class TargetLoweringBase {
27042723
setCondCodeAction(CCs, VT, Action);
27052724
}
27062725

2726+
/// Indicate how a PARTIAL_REDUCE_U/SMLA node with Acc type AccVT and Input
2727+
/// type InputVT should be treated by the target. Either it's legal, needs to
2728+
/// be promoted to a larger size, needs to be expanded to some other code
2729+
/// sequence, or the target has a custom expander for it.
2730+
void setPartialReduceMLAAction(MVT AccVT, MVT InputVT,
2731+
LegalizeAction Action) {
2732+
assert(AccVT.isValid() && InputVT.isValid() && "Table isn't big enough!");
2733+
PartialReduceMLAActions[AccVT.SimpleTy][InputVT.SimpleTy] = Action;
2734+
}
2735+
27072736
/// If Opc/OrigVT is specified as being promoted, the promotion code defaults
27082737
/// to trying a larger integer/fp until it can find one that works. If that
27092738
/// default is insufficient, this method can be used by the target to override
@@ -3650,6 +3679,12 @@ class TargetLoweringBase {
36503679
/// up the MVT::VALUETYPE_SIZE value to the next multiple of 8.
36513680
uint32_t CondCodeActions[ISD::SETCC_INVALID][(MVT::VALUETYPE_SIZE + 7) / 8];
36523681

3682+
/// For each result type and input type for the ISD::PARTIAL_REDUCE_U/SMLA
3683+
/// nodes, keep a LegalizeAction which indicates how instruction selection
3684+
/// should deal with this operation.
3685+
LegalizeAction PartialReduceMLAActions[MVT::VALUETYPE_SIZE]
3686+
[MVT::VALUETYPE_SIZE];
3687+
36533688
ValueTypeActionImpl ValueTypeActions;
36543689

36553690
private:

llvm/include/llvm/Target/TargetSelectionDAG.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,10 @@ def SDTSubVecInsert : SDTypeProfile<1, 3, [ // subvector insert
313313
SDTCisSubVecOfVec<2, 1>, SDTCisSameAs<0,1>, SDTCisInt<3>
314314
]>;
315315

316+
def SDTPartialReduceMLA : SDTypeProfile<1, 3, [ // partial reduce mla
317+
SDTCisVec<0>, SDTCisVec<1>, SDTCisVec<2>
318+
]>;
319+
316320
def SDTPrefetch : SDTypeProfile<0, 4, [ // prefetch
317321
SDTCisPtrTy<0>, SDTCisSameAs<1, 2>, SDTCisSameAs<1, 3>, SDTCisInt<1>
318322
]>;
@@ -513,6 +517,11 @@ def vecreduce_fmax : SDNode<"ISD::VECREDUCE_FMAX", SDTFPVecReduce>;
513517
def vecreduce_fminimum : SDNode<"ISD::VECREDUCE_FMINIMUM", SDTFPVecReduce>;
514518
def vecreduce_fmaximum : SDNode<"ISD::VECREDUCE_FMAXIMUM", SDTFPVecReduce>;
515519

520+
def partial_reduce_umla : SDNode<"ISD::PARTIAL_REDUCE_UMLA",
521+
SDTPartialReduceMLA>;
522+
def partial_reduce_smla : SDNode<"ISD::PARTIAL_REDUCE_SMLA",
523+
SDTPartialReduceMLA>;
524+
516525
def fadd : SDNode<"ISD::FADD" , SDTFPBinOp, [SDNPCommutative]>;
517526
def fsub : SDNode<"ISD::FSUB" , SDTFPBinOp>;
518527
def fmul : SDNode<"ISD::FMUL" , SDTFPBinOp, [SDNPCommutative]>;

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12535,8 +12535,10 @@ SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
1253512535
if (LHSExtOpVT != RHSExtOp.getValueType())
1253612536
return SDValue();
1253712537

12538-
// FIXME: Add a check to only perform the DAG combine if there is lowering
12539-
// provided by the target
12538+
// Only perform the DAG combine if there is custom lowering provided by the
12539+
// target
12540+
if (!TLI.isPartialReduceMLALegalOrCustom(N->getValueType(0), LHSExtOpVT))
12541+
return SDValue();
1254012542

1254112543
bool LHSIsSigned = LHSOpcode == ISD::SIGN_EXTEND;
1254212544
bool RHSIsSigned = RHSOpcode == ISD::SIGN_EXTEND;

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -469,8 +469,6 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
469469
case ISD::VECTOR_COMPRESS:
470470
case ISD::SCMP:
471471
case ISD::UCMP:
472-
case ISD::PARTIAL_REDUCE_UMLA:
473-
case ISD::PARTIAL_REDUCE_SMLA:
474472
Action = TLI.getOperationAction(Node->getOpcode(), Node->getValueType(0));
475473
break;
476474
case ISD::SMULFIX:
@@ -524,6 +522,11 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
524522
Action = TLI.getOperationAction(Node->getOpcode(), OpVT);
525523
break;
526524
}
525+
case ISD::PARTIAL_REDUCE_UMLA:
526+
case ISD::PARTIAL_REDUCE_SMLA:
527+
Action = TLI.getPartialReduceMLAAction(Node->getValueType(0),
528+
Node->getOperand(1).getValueType());
529+
break;
527530

528531
#define BEGIN_REGISTER_VP_SDNODE(VPID, LEGALPOS, ...) \
529532
case ISD::VPID: { \

llvm/lib/CodeGen/TargetLoweringBase.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -836,9 +836,8 @@ void TargetLoweringBase::initActions() {
836836
setOperationAction(ISD::SET_FPENV, VT, Expand);
837837
setOperationAction(ISD::RESET_FPENV, VT, Expand);
838838

839-
// PartialReduceMLA operations default to expand.
840-
setOperationAction({ISD::PARTIAL_REDUCE_UMLA, ISD::PARTIAL_REDUCE_SMLA}, VT,
841-
Expand);
839+
for (MVT InputVT : MVT::all_valuetypes())
840+
setPartialReduceMLAAction(VT, InputVT, Expand);
842841
}
843842

844843
// Most targets ignore the @llvm.prefetch intrinsic.

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1581,6 +1581,21 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
15811581
setOperationAction(ISD::MSTORE, VT, Custom);
15821582
}
15831583

1584+
for (MVT VT : MVT::integer_scalable_vector_valuetypes()) {
1585+
if (!EnablePartialReduceNodes)
1586+
break;
1587+
for (MVT InnerVT : MVT::integer_scalable_vector_valuetypes()) {
1588+
ElementCount VTElemCount = VT.getVectorElementCount();
1589+
if (VTElemCount.getKnownMinValue() == 1)
1590+
continue;
1591+
if (VTElemCount * 4 == InnerVT.getVectorElementCount())
1592+
setPartialReduceMLAAction(VT, InnerVT, Custom);
1593+
if (InnerVT.getVectorElementType().getSizeInBits() * 4 ==
1594+
VT.getVectorElementType().getSizeInBits())
1595+
setPartialReduceMLAAction(VT, InnerVT, Legal);
1596+
}
1597+
}
1598+
15841599
// Firstly, exclude all scalable vector extending loads/truncating stores,
15851600
// include both integer and floating scalable vector.
15861601
for (MVT VT : MVT::scalable_vector_valuetypes()) {

llvm/lib/Target/AArch64/AArch64InstrInfo.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,9 @@ def HasFuseAES : Predicate<"Subtarget->hasFuseAES()">,
143143
"fuse-aes">;
144144
def HasSVE : Predicate<"Subtarget->isSVEAvailable()">,
145145
AssemblerPredicateWithAll<(all_of FeatureSVE), "sve">;
146+
def HasSVEorStreamingSVE
147+
: Predicate<"Subtarget->isSVEorStreamingSVEAvailable()">,
148+
AssemblerPredicateWithAll<(all_of FeatureSVE), "sve">;
146149
def HasSVEB16B16 : Predicate<"Subtarget->isSVEorStreamingSVEAvailable() && Subtarget->hasSVEB16B16()">,
147150
AssemblerPredicateWithAll<(all_of FeatureSVEB16B16), "sve-b16b16">;
148151
def HasSVE2 : Predicate<"Subtarget->isSVEAvailable() && Subtarget->hasSVE2()">,

llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -655,6 +655,17 @@ let Predicates = [HasSVE_or_SME] in {
655655
defm SDOT_ZZZ : sve_intx_dot<0b0, "sdot", AArch64sdot>;
656656
defm UDOT_ZZZ : sve_intx_dot<0b1, "udot", AArch64udot>;
657657

658+
let Predicates = [HasSVEorStreamingSVE] in {
659+
def : Pat<(nxv4i32 (partial_reduce_umla nxv4i32:$Acc, nxv16i8:$MulLHS, nxv16i8:$MulRHS)),
660+
(UDOT_ZZZ_S $Acc, $MulLHS, $MulRHS)>;
661+
def : Pat<(nxv4i32 (partial_reduce_smla nxv4i32:$Acc, nxv16i8:$MulLHS, nxv16i8:$MulRHS)),
662+
(SDOT_ZZZ_S $Acc, $MulLHS, $MulRHS)>;
663+
def : Pat<(nxv2i64 (partial_reduce_umla nxv2i64:$Acc, nxv8i16:$MulLHS, nxv8i16:$MulRHS)),
664+
(UDOT_ZZZ_D $Acc, $MulLHS, $MulRHS)>;
665+
def : Pat<(nxv2i64 (partial_reduce_smla nxv2i64:$Acc, nxv8i16:$MulLHS, nxv8i16:$MulRHS)),
666+
(SDOT_ZZZ_D $Acc, $MulLHS, $MulRHS)>;
667+
} // End HasSVEorStreamingSVE
668+
658669
defm SDOT_ZZZI : sve_intx_dot_by_indexed_elem<0b0, "sdot", int_aarch64_sve_sdot_lane>;
659670
defm UDOT_ZZZI : sve_intx_dot_by_indexed_elem<0b1, "udot", int_aarch64_sve_udot_lane>;
660671

llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll

Lines changed: 57 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,13 @@ define <4 x i32> @udot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
1212
;
1313
; CHECK-NODOT-LABEL: udot:
1414
; CHECK-NODOT: // %bb.0:
15-
; CHECK-NODOT-NEXT: ushll v3.8h, v1.8b, #0
16-
; CHECK-NODOT-NEXT: ushll v4.8h, v2.8b, #0
17-
; CHECK-NODOT-NEXT: ushll2 v1.8h, v1.16b, #0
18-
; CHECK-NODOT-NEXT: ushll2 v2.8h, v2.16b, #0
19-
; CHECK-NODOT-NEXT: umlal v0.4s, v4.4h, v3.4h
20-
; CHECK-NODOT-NEXT: umull v5.4s, v2.4h, v1.4h
21-
; CHECK-NODOT-NEXT: umlal2 v0.4s, v2.8h, v1.8h
22-
; CHECK-NODOT-NEXT: umlal2 v5.4s, v4.8h, v3.8h
23-
; CHECK-NODOT-NEXT: add v0.4s, v5.4s, v0.4s
15+
; CHECK-NODOT-NEXT: umull v3.8h, v2.8b, v1.8b
16+
; CHECK-NODOT-NEXT: umull2 v1.8h, v2.16b, v1.16b
17+
; CHECK-NODOT-NEXT: ushll v2.4s, v1.4h, #0
18+
; CHECK-NODOT-NEXT: uaddw v0.4s, v0.4s, v3.4h
19+
; CHECK-NODOT-NEXT: uaddw2 v2.4s, v2.4s, v3.8h
20+
; CHECK-NODOT-NEXT: uaddw2 v0.4s, v0.4s, v1.8h
21+
; CHECK-NODOT-NEXT: add v0.4s, v2.4s, v0.4s
2422
; CHECK-NODOT-NEXT: ret
2523
%u.wide = zext <16 x i8> %u to <16 x i32>
2624
%s.wide = zext <16 x i8> %s to <16 x i32>
@@ -37,19 +35,17 @@ define <2 x i32> @udot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) {
3735
;
3836
; CHECK-NODOT-LABEL: udot_narrow:
3937
; CHECK-NODOT: // %bb.0:
40-
; CHECK-NODOT-NEXT: ushll v1.8h, v1.8b, #0
41-
; CHECK-NODOT-NEXT: ushll v2.8h, v2.8b, #0
38+
; CHECK-NODOT-NEXT: umull v1.8h, v2.8b, v1.8b
4239
; CHECK-NODOT-NEXT: // kill: def $d0 killed $d0 def $q0
43-
; CHECK-NODOT-NEXT: umull v3.4s, v2.4h, v1.4h
44-
; CHECK-NODOT-NEXT: umull2 v4.4s, v2.8h, v1.8h
45-
; CHECK-NODOT-NEXT: ext v5.16b, v1.16b, v1.16b, #8
46-
; CHECK-NODOT-NEXT: ext v6.16b, v2.16b, v2.16b, #8
47-
; CHECK-NODOT-NEXT: umlal v0.4s, v2.4h, v1.4h
40+
; CHECK-NODOT-NEXT: ushll v2.4s, v1.4h, #0
41+
; CHECK-NODOT-NEXT: ushll2 v3.4s, v1.8h, #0
42+
; CHECK-NODOT-NEXT: ext v4.16b, v1.16b, v1.16b, #8
43+
; CHECK-NODOT-NEXT: uaddw v0.4s, v0.4s, v1.4h
4844
; CHECK-NODOT-NEXT: ext v3.16b, v3.16b, v3.16b, #8
49-
; CHECK-NODOT-NEXT: ext v1.16b, v4.16b, v4.16b, #8
50-
; CHECK-NODOT-NEXT: umlal v3.4s, v6.4h, v5.4h
51-
; CHECK-NODOT-NEXT: add v0.2s, v1.2s, v0.2s
45+
; CHECK-NODOT-NEXT: ext v2.16b, v2.16b, v2.16b, #8
5246
; CHECK-NODOT-NEXT: add v0.2s, v3.2s, v0.2s
47+
; CHECK-NODOT-NEXT: uaddw v1.4s, v2.4s, v4.4h
48+
; CHECK-NODOT-NEXT: add v0.2s, v1.2s, v0.2s
5349
; CHECK-NODOT-NEXT: ret
5450
%u.wide = zext <8 x i8> %u to <8 x i32>
5551
%s.wide = zext <8 x i8> %s to <8 x i32>
@@ -66,15 +62,13 @@ define <4 x i32> @sdot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
6662
;
6763
; CHECK-NODOT-LABEL: sdot:
6864
; CHECK-NODOT: // %bb.0:
69-
; CHECK-NODOT-NEXT: sshll v3.8h, v1.8b, #0
70-
; CHECK-NODOT-NEXT: sshll v4.8h, v2.8b, #0
71-
; CHECK-NODOT-NEXT: sshll2 v1.8h, v1.16b, #0
72-
; CHECK-NODOT-NEXT: sshll2 v2.8h, v2.16b, #0
73-
; CHECK-NODOT-NEXT: smlal v0.4s, v4.4h, v3.4h
74-
; CHECK-NODOT-NEXT: smull v5.4s, v2.4h, v1.4h
75-
; CHECK-NODOT-NEXT: smlal2 v0.4s, v2.8h, v1.8h
76-
; CHECK-NODOT-NEXT: smlal2 v5.4s, v4.8h, v3.8h
77-
; CHECK-NODOT-NEXT: add v0.4s, v5.4s, v0.4s
65+
; CHECK-NODOT-NEXT: smull v3.8h, v2.8b, v1.8b
66+
; CHECK-NODOT-NEXT: smull2 v1.8h, v2.16b, v1.16b
67+
; CHECK-NODOT-NEXT: sshll v2.4s, v1.4h, #0
68+
; CHECK-NODOT-NEXT: saddw v0.4s, v0.4s, v3.4h
69+
; CHECK-NODOT-NEXT: saddw2 v2.4s, v2.4s, v3.8h
70+
; CHECK-NODOT-NEXT: saddw2 v0.4s, v0.4s, v1.8h
71+
; CHECK-NODOT-NEXT: add v0.4s, v2.4s, v0.4s
7872
; CHECK-NODOT-NEXT: ret
7973
%u.wide = sext <16 x i8> %u to <16 x i32>
8074
%s.wide = sext <16 x i8> %s to <16 x i32>
@@ -91,19 +85,17 @@ define <2 x i32> @sdot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) {
9185
;
9286
; CHECK-NODOT-LABEL: sdot_narrow:
9387
; CHECK-NODOT: // %bb.0:
94-
; CHECK-NODOT-NEXT: sshll v1.8h, v1.8b, #0
95-
; CHECK-NODOT-NEXT: sshll v2.8h, v2.8b, #0
88+
; CHECK-NODOT-NEXT: smull v1.8h, v2.8b, v1.8b
9689
; CHECK-NODOT-NEXT: // kill: def $d0 killed $d0 def $q0
97-
; CHECK-NODOT-NEXT: smull v3.4s, v2.4h, v1.4h
98-
; CHECK-NODOT-NEXT: smull2 v4.4s, v2.8h, v1.8h
99-
; CHECK-NODOT-NEXT: ext v5.16b, v1.16b, v1.16b, #8
100-
; CHECK-NODOT-NEXT: ext v6.16b, v2.16b, v2.16b, #8
101-
; CHECK-NODOT-NEXT: smlal v0.4s, v2.4h, v1.4h
90+
; CHECK-NODOT-NEXT: sshll v2.4s, v1.4h, #0
91+
; CHECK-NODOT-NEXT: sshll2 v3.4s, v1.8h, #0
92+
; CHECK-NODOT-NEXT: ext v4.16b, v1.16b, v1.16b, #8
93+
; CHECK-NODOT-NEXT: saddw v0.4s, v0.4s, v1.4h
10294
; CHECK-NODOT-NEXT: ext v3.16b, v3.16b, v3.16b, #8
103-
; CHECK-NODOT-NEXT: ext v1.16b, v4.16b, v4.16b, #8
104-
; CHECK-NODOT-NEXT: smlal v3.4s, v6.4h, v5.4h
105-
; CHECK-NODOT-NEXT: add v0.2s, v1.2s, v0.2s
95+
; CHECK-NODOT-NEXT: ext v2.16b, v2.16b, v2.16b, #8
10696
; CHECK-NODOT-NEXT: add v0.2s, v3.2s, v0.2s
97+
; CHECK-NODOT-NEXT: saddw v1.4s, v2.4s, v4.4h
98+
; CHECK-NODOT-NEXT: add v0.2s, v1.2s, v0.2s
10799
; CHECK-NODOT-NEXT: ret
108100
%u.wide = sext <8 x i8> %u to <8 x i32>
109101
%s.wide = sext <8 x i8> %s to <8 x i32>
@@ -231,27 +223,19 @@ define <4 x i64> @udot_8to64(<4 x i64> %acc, <16 x i8> %a, <16 x i8> %b) {
231223
;
232224
; CHECK-NODOT-LABEL: udot_8to64:
233225
; CHECK-NODOT: // %bb.0: // %entry
234-
; CHECK-NODOT-NEXT: ushll v4.8h, v3.8b, #0
235-
; CHECK-NODOT-NEXT: ushll v5.8h, v2.8b, #0
236-
; CHECK-NODOT-NEXT: ushll2 v3.8h, v3.16b, #0
237-
; CHECK-NODOT-NEXT: ushll2 v2.8h, v2.16b, #0
238-
; CHECK-NODOT-NEXT: ushll v6.4s, v4.4h, #0
239-
; CHECK-NODOT-NEXT: ushll v7.4s, v5.4h, #0
226+
; CHECK-NODOT-NEXT: umull v4.8h, v2.8b, v3.8b
227+
; CHECK-NODOT-NEXT: umull2 v2.8h, v2.16b, v3.16b
228+
; CHECK-NODOT-NEXT: ushll v3.4s, v4.4h, #0
229+
; CHECK-NODOT-NEXT: ushll v5.4s, v2.4h, #0
240230
; CHECK-NODOT-NEXT: ushll2 v4.4s, v4.8h, #0
241-
; CHECK-NODOT-NEXT: ushll2 v5.4s, v5.8h, #0
242-
; CHECK-NODOT-NEXT: ushll2 v16.4s, v3.8h, #0
243-
; CHECK-NODOT-NEXT: ushll2 v17.4s, v2.8h, #0
244-
; CHECK-NODOT-NEXT: ushll v3.4s, v3.4h, #0
245-
; CHECK-NODOT-NEXT: ushll v2.4s, v2.4h, #0
246-
; CHECK-NODOT-NEXT: umlal2 v1.2d, v7.4s, v6.4s
247-
; CHECK-NODOT-NEXT: umlal v0.2d, v7.2s, v6.2s
248-
; CHECK-NODOT-NEXT: umull2 v18.2d, v5.4s, v4.4s
249-
; CHECK-NODOT-NEXT: umull v4.2d, v5.2s, v4.2s
250-
; CHECK-NODOT-NEXT: umlal2 v1.2d, v17.4s, v16.4s
251-
; CHECK-NODOT-NEXT: umlal v0.2d, v17.2s, v16.2s
252-
; CHECK-NODOT-NEXT: umlal2 v18.2d, v2.4s, v3.4s
253-
; CHECK-NODOT-NEXT: umlal v4.2d, v2.2s, v3.2s
254-
; CHECK-NODOT-NEXT: add v1.2d, v18.2d, v1.2d
231+
; CHECK-NODOT-NEXT: ushll2 v2.4s, v2.8h, #0
232+
; CHECK-NODOT-NEXT: uaddw2 v1.2d, v1.2d, v3.4s
233+
; CHECK-NODOT-NEXT: uaddw v0.2d, v0.2d, v3.2s
234+
; CHECK-NODOT-NEXT: uaddl2 v3.2d, v4.4s, v5.4s
235+
; CHECK-NODOT-NEXT: uaddl v4.2d, v4.2s, v5.2s
236+
; CHECK-NODOT-NEXT: uaddw2 v1.2d, v1.2d, v2.4s
237+
; CHECK-NODOT-NEXT: uaddw v0.2d, v0.2d, v2.2s
238+
; CHECK-NODOT-NEXT: add v1.2d, v3.2d, v1.2d
255239
; CHECK-NODOT-NEXT: add v0.2d, v4.2d, v0.2d
256240
; CHECK-NODOT-NEXT: ret
257241
entry:
@@ -274,27 +258,19 @@ define <4 x i64> @sdot_8to64(<4 x i64> %acc, <16 x i8> %a, <16 x i8> %b){
274258
;
275259
; CHECK-NODOT-LABEL: sdot_8to64:
276260
; CHECK-NODOT: // %bb.0: // %entry
277-
; CHECK-NODOT-NEXT: sshll v4.8h, v3.8b, #0
278-
; CHECK-NODOT-NEXT: sshll v5.8h, v2.8b, #0
279-
; CHECK-NODOT-NEXT: sshll2 v3.8h, v3.16b, #0
280-
; CHECK-NODOT-NEXT: sshll2 v2.8h, v2.16b, #0
281-
; CHECK-NODOT-NEXT: sshll v6.4s, v4.4h, #0
282-
; CHECK-NODOT-NEXT: sshll v7.4s, v5.4h, #0
261+
; CHECK-NODOT-NEXT: smull v4.8h, v2.8b, v3.8b
262+
; CHECK-NODOT-NEXT: smull2 v2.8h, v2.16b, v3.16b
263+
; CHECK-NODOT-NEXT: sshll v3.4s, v4.4h, #0
264+
; CHECK-NODOT-NEXT: sshll v5.4s, v2.4h, #0
283265
; CHECK-NODOT-NEXT: sshll2 v4.4s, v4.8h, #0
284-
; CHECK-NODOT-NEXT: sshll2 v5.4s, v5.8h, #0
285-
; CHECK-NODOT-NEXT: sshll2 v16.4s, v3.8h, #0
286-
; CHECK-NODOT-NEXT: sshll2 v17.4s, v2.8h, #0
287-
; CHECK-NODOT-NEXT: sshll v3.4s, v3.4h, #0
288-
; CHECK-NODOT-NEXT: sshll v2.4s, v2.4h, #0
289-
; CHECK-NODOT-NEXT: smlal2 v1.2d, v7.4s, v6.4s
290-
; CHECK-NODOT-NEXT: smlal v0.2d, v7.2s, v6.2s
291-
; CHECK-NODOT-NEXT: smull2 v18.2d, v5.4s, v4.4s
292-
; CHECK-NODOT-NEXT: smull v4.2d, v5.2s, v4.2s
293-
; CHECK-NODOT-NEXT: smlal2 v1.2d, v17.4s, v16.4s
294-
; CHECK-NODOT-NEXT: smlal v0.2d, v17.2s, v16.2s
295-
; CHECK-NODOT-NEXT: smlal2 v18.2d, v2.4s, v3.4s
296-
; CHECK-NODOT-NEXT: smlal v4.2d, v2.2s, v3.2s
297-
; CHECK-NODOT-NEXT: add v1.2d, v18.2d, v1.2d
266+
; CHECK-NODOT-NEXT: sshll2 v2.4s, v2.8h, #0
267+
; CHECK-NODOT-NEXT: saddw2 v1.2d, v1.2d, v3.4s
268+
; CHECK-NODOT-NEXT: saddw v0.2d, v0.2d, v3.2s
269+
; CHECK-NODOT-NEXT: saddl2 v3.2d, v4.4s, v5.4s
270+
; CHECK-NODOT-NEXT: saddl v4.2d, v4.2s, v5.2s
271+
; CHECK-NODOT-NEXT: saddw2 v1.2d, v1.2d, v2.4s
272+
; CHECK-NODOT-NEXT: saddw v0.2d, v0.2d, v2.2s
273+
; CHECK-NODOT-NEXT: add v1.2d, v3.2d, v1.2d
298274
; CHECK-NODOT-NEXT: add v0.2d, v4.2d, v0.2d
299275
; CHECK-NODOT-NEXT: ret
300276
entry:
@@ -555,10 +531,9 @@ define <4 x i64> @sdot_no_bin_op_8to64(<4 x i64> %acc, <16 x i8> %a){
555531
define <4 x i32> @not_udot(<4 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
556532
; CHECK-LABEL: not_udot:
557533
; CHECK: // %bb.0:
558-
; CHECK-NEXT: ushll v1.8h, v1.8b, #0
559-
; CHECK-NEXT: ushll v2.8h, v2.8b, #0
560-
; CHECK-NEXT: umlal v0.4s, v2.4h, v1.4h
561-
; CHECK-NEXT: umlal2 v0.4s, v2.8h, v1.8h
534+
; CHECK-NEXT: umull v1.8h, v2.8b, v1.8b
535+
; CHECK-NEXT: uaddw v0.4s, v0.4s, v1.4h
536+
; CHECK-NEXT: uaddw2 v0.4s, v0.4s, v1.8h
562537
; CHECK-NEXT: ret
563538
%u.wide = zext <8 x i8> %u to <8 x i32>
564539
%s.wide = zext <8 x i8> %s to <8 x i32>

0 commit comments

Comments
 (0)