Skip to content

Commit 1ed6ae3

Browse files
JamesChesterman authored and NickGuy-Arm committed
[SelectionDAG][AArch64] Add dot product lowering in NEON for PARTIAL_REDUCE_*MLA ISD nodes
Lowering for fixed-width vectors has been added to TableGen. There is also custom lowering to ensure that the USDOT patterns are still lowered for fixed-width vectors. The custom lowering also ensures that the v16i8 -> v4i64 partial reduction case is lowered here instead of being split (as there is no v2i64 dot product instruction).
1 parent 23f6358 commit 1ed6ae3

File tree

3 files changed

+98
-25
lines changed

3 files changed

+98
-25
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 15 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -7744,7 +7744,11 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
77447744
return LowerVECTOR_HISTOGRAM(Op, DAG);
77457745
case ISD::PARTIAL_REDUCE_SMLA:
77467746
case ISD::PARTIAL_REDUCE_UMLA:
7747-
return LowerPARTIAL_REDUCE_MLA(Op, DAG);
7747+
case ISD::PARTIAL_REDUCE_SMLA: {
7748+
if (SDValue Result = LowerPARTIAL_REDUCE_MLA(Op, DAG))
7749+
return Result;
7750+
return expandPartialReduceMLA(Op.getNode(), DAG);
7751+
}
77487752
}
77497753
}
77507754

@@ -27569,6 +27573,15 @@ void AArch64TargetLowering::ReplaceNodeResults(
2756927573
if (SDValue Res = LowerVECTOR_COMPRESS(SDValue(N, 0), DAG))
2757027574
Results.push_back(Res);
2757127575
return;
27576+
case ISD::PARTIAL_REDUCE_UMLA:
27577+
case ISD::PARTIAL_REDUCE_SMLA: {
27578+
SDValue Res;
27579+
if (Res = LowerPARTIAL_REDUCE_MLA(SDValue(N, 0), DAG))
27580+
Results.push_back(Res);
27581+
else
27582+
Results.push_back(expandPartialReduceMLA(N, DAG));
27583+
return;
27584+
}
2757227585
case ISD::ADD:
2757327586
case ISD::FADD:
2757427587
ReplaceAddWithADDP(N, Results, DAG, Subtarget);
@@ -29524,6 +29537,7 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
2952429537
SDValue
2952529538
AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
2952629539
SelectionDAG &DAG) const {
29540+
2952729541
SDLoc DL(Op);
2952829542

2952929543
SDValue Acc = Op.getOperand(0);

llvm/lib/Target/AArch64/AArch64InstrInfo.td

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1474,6 +1474,17 @@ defm SDOTlane : SIMDThreeSameVectorDotIndex<0, 0, 0b10, "sdot", AArch64sdot>;
14741474
defm UDOTlane : SIMDThreeSameVectorDotIndex<1, 0, 0b10, "udot", AArch64udot>;
14751475
}
14761476

1477+
let Predicates = [HasNEON, HasDotProd] in {
1478+
def : Pat<(v4i32 (partial_reduce_umla (v4i32 V128:$Acc), (v16i8 V128:$MulLHS), (v16i8 V128:$MulRHS))),
1479+
(v4i32 (UDOTv16i8 V128:$Acc, V128:$MulLHS, V128:$MulRHS))>;
1480+
def : Pat<(v4i32 (partial_reduce_smla (v4i32 V128:$Acc), (v16i8 V128:$MulLHS), (v16i8 V128:$MulRHS))),
1481+
(v4i32 (SDOTv16i8 V128:$Acc, V128:$MulLHS, V128:$MulRHS))>;
1482+
def : Pat<(v2i32 (partial_reduce_umla (v2i32 V64:$Acc), (v8i8 V64:$MulLHS), (v8i8 V64:$MulRHS))),
1483+
(v2i32 (UDOTv8i8 V64:$Acc, V64:$MulLHS, V64:$MulRHS))>;
1484+
def : Pat<(v2i32 (partial_reduce_smla (v2i32 V64:$Acc), (v8i8 V64:$MulLHS), (v8i8 V64:$MulRHS))),
1485+
(v2i32 (SDOTv8i8 V64:$Acc, V64:$MulLHS, V64:$MulRHS))>;
1486+
} // End HasNEON, HasDotProd
1487+
14771488
// ARMv8.6-A BFloat
14781489
let Predicates = [HasNEON, HasBF16] in {
14791490
defm BFDOT : SIMDThreeSameVectorBFDot<1, "bfdot">;

llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll

Lines changed: 72 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,8 @@
22
; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod < %s | FileCheck %s --check-prefixes=CHECK,CHECK-DOT,CHECK-NOI8MM
33
; RUN: llc -mtriple aarch64 -mattr=+neon < %s | FileCheck %s --check-prefixes=CHECK,CHECK-NOI8MM,CHECK-NODOT
44
; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod,+i8mm < %s | FileCheck %s --check-prefixes=CHECK,CHECK-DOT,CHECK-I8MM
5-
; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod,+i8mm -aarch64-enable-partial-reduce-nodes < %s | FileCheck %s --check-prefixes=CHECK,CHECK-NOI8MM
5+
; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod,+i8mm -aarch64-enable-partial-reduce-nodes < %s | FileCheck %s --check-prefixes=CHECK,CHECK-DOT,CHECK-I8MM,CHECK-NEWLOWERING-I8MM
6+
; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod -aarch64-enable-partial-reduce-nodes < %s | FileCheck %s --check-prefixes=CHECK,CHECK-DOT,CHECK-NOI8MM,CHECK-NEWLOWERING-NOI8MM
67

78
define <4 x i32> @udot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
89
; CHECK-DOT-LABEL: udot:
@@ -49,17 +50,20 @@ define <4 x i32> @udot_in_loop(ptr %p1, ptr %p2){
4950
; CHECK-NODOT-NEXT: mov x8, xzr
5051
; CHECK-NODOT-NEXT: .LBB1_1: // %vector.body
5152
; CHECK-NODOT-NEXT: // =>This Inner Loop Header: Depth=1
52-
; CHECK-NODOT-NEXT: ldr q2, [x0, x8]
53-
; CHECK-NODOT-NEXT: ldr q3, [x1, x8]
54-
; CHECK-NODOT-NEXT: mov v0.16b, v1.16b
53+
; CHECK-NODOT-NEXT: ldr q0, [x0, x8]
54+
; CHECK-NODOT-NEXT: ldr q2, [x1, x8]
5555
; CHECK-NODOT-NEXT: add x8, x8, #16
5656
; CHECK-NODOT-NEXT: umull v4.8h, v2.8b, v3.8b
5757
; CHECK-NODOT-NEXT: umull2 v2.8h, v2.16b, v3.16b
5858
; CHECK-NODOT-NEXT: cmp x8, #16
59-
; CHECK-NODOT-NEXT: uaddw v1.4s, v1.4s, v4.4h
60-
; CHECK-NODOT-NEXT: uaddw2 v1.4s, v1.4s, v4.8h
61-
; CHECK-NODOT-NEXT: uaddw v1.4s, v1.4s, v2.4h
62-
; CHECK-NODOT-NEXT: uaddw2 v1.4s, v1.4s, v2.8h
59+
; CHECK-NODOT-NEXT: umull v3.8h, v0.8b, v2.8b
60+
; CHECK-NODOT-NEXT: umull2 v2.8h, v0.16b, v2.16b
61+
; CHECK-NODOT-NEXT: mov v0.16b, v1.16b
62+
; CHECK-NODOT-NEXT: ushll v1.4s, v2.4h, #0
63+
; CHECK-NODOT-NEXT: uaddw v4.4s, v0.4s, v3.4h
64+
; CHECK-NODOT-NEXT: uaddw2 v1.4s, v1.4s, v3.8h
65+
; CHECK-NODOT-NEXT: uaddw2 v2.4s, v4.4s, v2.8h
66+
; CHECK-NODOT-NEXT: add v1.4s, v1.4s, v2.4s
6367
; CHECK-NODOT-NEXT: b.ne .LBB1_1
6468
; CHECK-NODOT-NEXT: // %bb.2: // %end
6569
; CHECK-NODOT-NEXT: ret
@@ -563,22 +567,6 @@ define <4 x i32> @udot_no_bin_op(<4 x i32> %acc, <16 x i8> %a){
563567
}
564568

565569
define <4 x i32> @udot_no_bin_op_in_loop(ptr %p){
566-
; CHECK-DOT-LABEL: udot_no_bin_op_in_loop:
567-
; CHECK-DOT: // %bb.0: // %entry
568-
; CHECK-DOT-NEXT: movi v1.2d, #0000000000000000
569-
; CHECK-DOT-NEXT: movi v2.16b, #1
570-
; CHECK-DOT-NEXT: mov x8, xzr
571-
; CHECK-DOT-NEXT: .LBB16_1: // %vector.body
572-
; CHECK-DOT-NEXT: // =>This Inner Loop Header: Depth=1
573-
; CHECK-DOT-NEXT: ldr q3, [x0, x8]
574-
; CHECK-DOT-NEXT: mov v0.16b, v1.16b
575-
; CHECK-DOT-NEXT: add x8, x8, #16
576-
; CHECK-DOT-NEXT: cmp x8, #16
577-
; CHECK-DOT-NEXT: udot v1.4s, v3.16b, v2.16b
578-
; CHECK-DOT-NEXT: b.ne .LBB16_1
579-
; CHECK-DOT-NEXT: // %bb.2: // %end
580-
; CHECK-DOT-NEXT: ret
581-
;
582570
; CHECK-NODOT-LABEL: udot_no_bin_op_in_loop:
583571
; CHECK-NODOT: // %bb.0: // %entry
584572
; CHECK-NODOT-NEXT: movi v1.2d, #0000000000000000
@@ -598,6 +586,66 @@ define <4 x i32> @udot_no_bin_op_in_loop(ptr %p){
598586
; CHECK-NODOT-NEXT: b.ne .LBB16_1
599587
; CHECK-NODOT-NEXT: // %bb.2: // %end
600588
; CHECK-NODOT-NEXT: ret
589+
;
590+
; CHECK-NEWLOWERING-I8MM-LABEL: udot_no_bin_op_in_loop:
591+
; CHECK-NEWLOWERING-I8MM: // %bb.0: // %entry
592+
; CHECK-NEWLOWERING-I8MM-NEXT: adrp x8, .LCPI16_0
593+
; CHECK-NEWLOWERING-I8MM-NEXT: movi v4.2d, #0000000000000000
594+
; CHECK-NEWLOWERING-I8MM-NEXT: adrp x9, .LCPI16_2
595+
; CHECK-NEWLOWERING-I8MM-NEXT: ldr q1, [x8, :lo12:.LCPI16_0]
596+
; CHECK-NEWLOWERING-I8MM-NEXT: adrp x8, .LCPI16_1
597+
; CHECK-NEWLOWERING-I8MM-NEXT: adrp x10, .LCPI16_3
598+
; CHECK-NEWLOWERING-I8MM-NEXT: ldr q2, [x8, :lo12:.LCPI16_1]
599+
; CHECK-NEWLOWERING-I8MM-NEXT: ldr q3, [x9, :lo12:.LCPI16_2]
600+
; CHECK-NEWLOWERING-I8MM-NEXT: ldr q5, [x10, :lo12:.LCPI16_3]
601+
; CHECK-NEWLOWERING-I8MM-NEXT: mov x8, xzr
602+
; CHECK-NEWLOWERING-I8MM-NEXT: .LBB16_1: // %vector.body
603+
; CHECK-NEWLOWERING-I8MM-NEXT: // =>This Inner Loop Header: Depth=1
604+
; CHECK-NEWLOWERING-I8MM-NEXT: ldr q6, [x0, x8]
605+
; CHECK-NEWLOWERING-I8MM-NEXT: mov v0.16b, v4.16b
606+
; CHECK-NEWLOWERING-I8MM-NEXT: add x8, x8, #16
607+
; CHECK-NEWLOWERING-I8MM-NEXT: cmp x8, #16
608+
; CHECK-NEWLOWERING-I8MM-NEXT: tbl v7.16b, { v6.16b }, v2.16b
609+
; CHECK-NEWLOWERING-I8MM-NEXT: tbl v4.16b, { v6.16b }, v1.16b
610+
; CHECK-NEWLOWERING-I8MM-NEXT: tbl v16.16b, { v6.16b }, v3.16b
611+
; CHECK-NEWLOWERING-I8MM-NEXT: tbl v6.16b, { v6.16b }, v5.16b
612+
; CHECK-NEWLOWERING-I8MM-NEXT: add v7.4s, v0.4s, v7.4s
613+
; CHECK-NEWLOWERING-I8MM-NEXT: add v6.4s, v6.4s, v16.4s
614+
; CHECK-NEWLOWERING-I8MM-NEXT: add v4.4s, v4.4s, v7.4s
615+
; CHECK-NEWLOWERING-I8MM-NEXT: add v4.4s, v6.4s, v4.4s
616+
; CHECK-NEWLOWERING-I8MM-NEXT: b.ne .LBB16_1
617+
; CHECK-NEWLOWERING-I8MM-NEXT: // %bb.2: // %end
618+
; CHECK-NEWLOWERING-I8MM-NEXT: ret
619+
;
620+
; CHECK-NEWLOWERING-NOI8MM-LABEL: udot_no_bin_op_in_loop:
621+
; CHECK-NEWLOWERING-NOI8MM: // %bb.0: // %entry
622+
; CHECK-NEWLOWERING-NOI8MM-NEXT: adrp x8, .LCPI16_0
623+
; CHECK-NEWLOWERING-NOI8MM-NEXT: movi v4.2d, #0000000000000000
624+
; CHECK-NEWLOWERING-NOI8MM-NEXT: adrp x9, .LCPI16_2
625+
; CHECK-NEWLOWERING-NOI8MM-NEXT: ldr q1, [x8, :lo12:.LCPI16_0]
626+
; CHECK-NEWLOWERING-NOI8MM-NEXT: adrp x8, .LCPI16_1
627+
; CHECK-NEWLOWERING-NOI8MM-NEXT: adrp x10, .LCPI16_3
628+
; CHECK-NEWLOWERING-NOI8MM-NEXT: ldr q2, [x8, :lo12:.LCPI16_1]
629+
; CHECK-NEWLOWERING-NOI8MM-NEXT: ldr q3, [x9, :lo12:.LCPI16_2]
630+
; CHECK-NEWLOWERING-NOI8MM-NEXT: ldr q5, [x10, :lo12:.LCPI16_3]
631+
; CHECK-NEWLOWERING-NOI8MM-NEXT: mov x8, xzr
632+
; CHECK-NEWLOWERING-NOI8MM-NEXT: .LBB16_1: // %vector.body
633+
; CHECK-NEWLOWERING-NOI8MM-NEXT: // =>This Inner Loop Header: Depth=1
634+
; CHECK-NEWLOWERING-NOI8MM-NEXT: ldr q6, [x0, x8]
635+
; CHECK-NEWLOWERING-NOI8MM-NEXT: mov v0.16b, v4.16b
636+
; CHECK-NEWLOWERING-NOI8MM-NEXT: add x8, x8, #16
637+
; CHECK-NEWLOWERING-NOI8MM-NEXT: cmp x8, #16
638+
; CHECK-NEWLOWERING-NOI8MM-NEXT: tbl v7.16b, { v6.16b }, v2.16b
639+
; CHECK-NEWLOWERING-NOI8MM-NEXT: tbl v4.16b, { v6.16b }, v1.16b
640+
; CHECK-NEWLOWERING-NOI8MM-NEXT: tbl v16.16b, { v6.16b }, v3.16b
641+
; CHECK-NEWLOWERING-NOI8MM-NEXT: tbl v6.16b, { v6.16b }, v5.16b
642+
; CHECK-NEWLOWERING-NOI8MM-NEXT: add v7.4s, v0.4s, v7.4s
643+
; CHECK-NEWLOWERING-NOI8MM-NEXT: add v6.4s, v6.4s, v16.4s
644+
; CHECK-NEWLOWERING-NOI8MM-NEXT: add v4.4s, v4.4s, v7.4s
645+
; CHECK-NEWLOWERING-NOI8MM-NEXT: add v4.4s, v6.4s, v4.4s
646+
; CHECK-NEWLOWERING-NOI8MM-NEXT: b.ne .LBB16_1
647+
; CHECK-NEWLOWERING-NOI8MM-NEXT: // %bb.2: // %end
648+
; CHECK-NEWLOWERING-NOI8MM-NEXT: ret
601649

602650
entry:
603651
br label %vector.body

0 commit comments

Comments
 (0)