Skip to content

Commit 4f3eb0e

Browse files
sdesmalen-arm authored and svkeerthy committed
[AArch64] Improve codegen for partial.reduce.add v16i8 -> v2i32 (#161833)
Rather than expanding, we can handle this case natively by widening the accumulator.
1 parent cd07251 commit 4f3eb0e

File tree

2 files changed

+61
-0
lines changed

2 files changed

+61
-0
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1458,6 +1458,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
14581458

14591459
setPartialReduceMLAAction(MLAOps, MVT::v4i32, MVT::v16i8, Legal);
14601460
setPartialReduceMLAAction(MLAOps, MVT::v2i32, MVT::v8i8, Legal);
1461+
setPartialReduceMLAAction(MLAOps, MVT::v2i32, MVT::v16i8, Custom);
14611462
setPartialReduceMLAAction(MLAOps, MVT::v2i64, MVT::v16i8, Custom);
14621463

14631464
if (Subtarget->hasMatMulInt8()) {
@@ -30769,6 +30770,17 @@ AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
3076930770
ResultVT.isFixedLengthVector() &&
3077030771
useSVEForFixedLengthVectorVT(ResultVT, /*OverrideNEON=*/true);
3077130772

30773+
// We can handle this case natively by accumulating into a wider
30774+
// zero-padded vector.
30775+
if (!ConvertToScalable && ResultVT == MVT::v2i32 && OpVT == MVT::v16i8) {
30776+
SDValue ZeroVec = DAG.getConstant(0, DL, MVT::v4i32);
30777+
SDValue WideAcc = DAG.getInsertSubvector(DL, ZeroVec, Acc, 0);
30778+
SDValue Wide =
30779+
DAG.getNode(Op.getOpcode(), DL, MVT::v4i32, WideAcc, LHS, RHS);
30780+
SDValue Reduced = DAG.getNode(AArch64ISD::ADDP, DL, MVT::v4i32, Wide, Wide);
30781+
return DAG.getExtractSubvector(DL, MVT::v2i32, Reduced, 0);
30782+
}
30783+
3077230784
if (ConvertToScalable) {
3077330785
ResultVT = getContainerForFixedLengthVector(DAG, ResultVT);
3077430786
OpVT = getContainerForFixedLengthVector(DAG, LHS.getValueType());

llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll

Lines changed: 49 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1451,3 +1451,52 @@ define <4 x i32> @partial_reduce_shl_zext_non_const_rhs(<16 x i8> %l, <4 x i32>
14511451
%red = tail call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %part, <16 x i32> %shift)
14521452
ret <4 x i32> %red
14531453
}
1454+
1455+
define <2 x i32> @udot_v16i8tov2i32(<2 x i32> %acc, <16 x i8> %input) {
1456+
; CHECK-NODOT-LABEL: udot_v16i8tov2i32:
1457+
; CHECK-NODOT: // %bb.0: // %entry
1458+
; CHECK-NODOT-NEXT: ushll v2.8h, v1.8b, #0
1459+
; CHECK-NODOT-NEXT: // kill: def $d0 killed $d0 def $q0
1460+
; CHECK-NODOT-NEXT: ushll2 v1.8h, v1.16b, #0
1461+
; CHECK-NODOT-NEXT: ushll v3.4s, v2.4h, #0
1462+
; CHECK-NODOT-NEXT: uaddw v0.4s, v0.4s, v2.4h
1463+
; CHECK-NODOT-NEXT: ushll2 v4.4s, v2.8h, #0
1464+
; CHECK-NODOT-NEXT: ext v2.16b, v2.16b, v2.16b, #8
1465+
; CHECK-NODOT-NEXT: ext v3.16b, v3.16b, v3.16b, #8
1466+
; CHECK-NODOT-NEXT: add v0.2s, v3.2s, v0.2s
1467+
; CHECK-NODOT-NEXT: ext v3.16b, v4.16b, v4.16b, #8
1468+
; CHECK-NODOT-NEXT: uaddw v0.4s, v0.4s, v2.4h
1469+
; CHECK-NODOT-NEXT: ushll v2.4s, v1.4h, #0
1470+
; CHECK-NODOT-NEXT: add v0.2s, v3.2s, v0.2s
1471+
; CHECK-NODOT-NEXT: ext v2.16b, v2.16b, v2.16b, #8
1472+
; CHECK-NODOT-NEXT: ushll2 v3.4s, v1.8h, #0
1473+
; CHECK-NODOT-NEXT: uaddw v0.4s, v0.4s, v1.4h
1474+
; CHECK-NODOT-NEXT: ext v1.16b, v1.16b, v1.16b, #8
1475+
; CHECK-NODOT-NEXT: add v0.2s, v2.2s, v0.2s
1476+
; CHECK-NODOT-NEXT: ext v2.16b, v3.16b, v3.16b, #8
1477+
; CHECK-NODOT-NEXT: uaddw v0.4s, v0.4s, v1.4h
1478+
; CHECK-NODOT-NEXT: add v0.2s, v2.2s, v0.2s
1479+
; CHECK-NODOT-NEXT: ret
1480+
;
1481+
; CHECK-DOT-LABEL: udot_v16i8tov2i32:
1482+
; CHECK-DOT: // %bb.0: // %entry
1483+
; CHECK-DOT-NEXT: movi v2.16b, #1
1484+
; CHECK-DOT-NEXT: fmov d0, d0
1485+
; CHECK-DOT-NEXT: udot v0.4s, v1.16b, v2.16b
1486+
; CHECK-DOT-NEXT: addp v0.4s, v0.4s, v0.4s
1487+
; CHECK-DOT-NEXT: // kill: def $d0 killed $d0 killed $q0
1488+
; CHECK-DOT-NEXT: ret
1489+
;
1490+
; CHECK-DOT-I8MM-LABEL: udot_v16i8tov2i32:
1491+
; CHECK-DOT-I8MM: // %bb.0: // %entry
1492+
; CHECK-DOT-I8MM-NEXT: movi v2.16b, #1
1493+
; CHECK-DOT-I8MM-NEXT: fmov d0, d0
1494+
; CHECK-DOT-I8MM-NEXT: udot v0.4s, v1.16b, v2.16b
1495+
; CHECK-DOT-I8MM-NEXT: addp v0.4s, v0.4s, v0.4s
1496+
; CHECK-DOT-I8MM-NEXT: // kill: def $d0 killed $d0 killed $q0
1497+
; CHECK-DOT-I8MM-NEXT: ret
1498+
entry:
1499+
%input.wide = zext <16 x i8> %input to <16 x i32>
1500+
%partial.reduce = tail call <2 x i32> @llvm.vector.partial.reduce.add(<2 x i32> %acc, <16 x i32> %input.wide)
1501+
ret <2 x i32> %partial.reduce
1502+
}

0 commit comments

Comments
 (0)