diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 766599d567efd..64ce3f986e9eb 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2272,6 +2272,17 @@ void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT) {
     setPartialReduceMLAAction(MLAOps, VT, MVT::getVectorVT(MVT::i8, NumElts * 2),
                               Custom);
   }
+
+  if (Subtarget->hasMatMulInt8()) {
+    if (VT.getVectorElementType() == MVT::i32)
+      setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_SUMLA, VT,
+                                MVT::getVectorVT(MVT::i8, NumElts * 4),
+                                Custom);
+    else if (VT.getVectorElementType() == MVT::i64)
+      setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_SUMLA, VT,
+                                MVT::getVectorVT(MVT::i8, NumElts * 8),
+                                Custom);
+  }
 }
 
 // Lower fixed length vector operations to scalable equivalents.
diff --git a/llvm/test/CodeGen/AArch64/sve-fixed-length-partial-reduce.ll b/llvm/test/CodeGen/AArch64/sve-fixed-length-partial-reduce.ll
index 79d766d1b9908..af813ff16a202 100644
--- a/llvm/test/CodeGen/AArch64/sve-fixed-length-partial-reduce.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fixed-length-partial-reduce.ll
@@ -1,7 +1,7 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
-; RUN: llc -mattr=+dotprod -aarch64-enable-partial-reduce-nodes=true < %s | FileCheck %s --check-prefixes=COMMON,NEON
-; RUN: llc -mattr=+sve,+dotprod -aarch64-enable-partial-reduce-nodes=true < %s | FileCheck %s --check-prefixes=COMMON,SVE
-; RUN: llc -mattr=+sme -aarch64-enable-partial-reduce-nodes=true -force-streaming < %s | FileCheck %s --check-prefix=SME
+; RUN: llc -mattr=+dotprod,+i8mm -aarch64-enable-partial-reduce-nodes=true < %s | FileCheck %s --check-prefixes=COMMON,NEON
+; RUN: llc -mattr=+sve,+dotprod,+i8mm -aarch64-enable-partial-reduce-nodes=true < %s | FileCheck %s --check-prefixes=COMMON,SVE
+; RUN: llc -mattr=+sme,+i8mm -aarch64-enable-partial-reduce-nodes=true -force-streaming < %s | FileCheck %s --check-prefix=SME
 
 target triple = "aarch64"
 
@@ -407,6 +407,154 @@ define <4 x i32> @four_way_i8_i32_vl128(ptr %accptr, ptr %uptr, ptr %sptr) {
   ret <4 x i32> %partial.reduce
 }
 
+define <4 x i32> @four_way_i8_i32_vl128_usdot(ptr %accptr, ptr %uptr, ptr %sptr) {
+; COMMON-LABEL: four_way_i8_i32_vl128_usdot:
+; COMMON:       // %bb.0:
+; COMMON-NEXT:    ldr q0, [x0]
+; COMMON-NEXT:    ldr q1, [x1]
+; COMMON-NEXT:    ldr q2, [x2]
+; COMMON-NEXT:    usdot v0.4s, v1.16b, v2.16b
+; COMMON-NEXT:    ret
+;
+; SME-LABEL: four_way_i8_i32_vl128_usdot:
+; SME:       // %bb.0:
+; SME-NEXT:    ldr q0, [x0]
+; SME-NEXT:    ldr q1, [x1]
+; SME-NEXT:    ldr q2, [x2]
+; SME-NEXT:    usdot z0.s, z1.b, z2.b
+; SME-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT:    ret
+  %acc = load <4 x i32>, ptr %accptr
+  %u = load <16 x i8>, ptr %uptr
+  %s = load <16 x i8>, ptr %sptr
+  %u.wide = zext <16 x i8> %u to <16 x i32>
+  %s.wide = sext <16 x i8> %s to <16 x i32>
+  %mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
+  %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add(<4 x i32> %acc, <16 x i32> %mult)
+  ret <4 x i32> %partial.reduce
+}
+
+define <4 x i32> @four_way_i8_i32_vl128_sudot(ptr %accptr, ptr %uptr, ptr %sptr) {
+; COMMON-LABEL: four_way_i8_i32_vl128_sudot:
+; COMMON:       // %bb.0:
+; COMMON-NEXT:    ldr q0, [x0]
+; COMMON-NEXT:    ldr q1, [x1]
+; COMMON-NEXT:    ldr q2, [x2]
+; COMMON-NEXT:    usdot v0.4s, v2.16b, v1.16b
+; COMMON-NEXT:    ret
+;
+; SME-LABEL: four_way_i8_i32_vl128_sudot:
+; SME:       // %bb.0:
+; SME-NEXT:    ldr q0, [x0]
+; SME-NEXT:    ldr q1, [x1]
+; SME-NEXT:    ldr q2, [x2]
+; SME-NEXT:    usdot z0.s, z2.b, z1.b
+; SME-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT:    ret
+  %acc = load <4 x i32>, ptr %accptr
+  %u = load <16 x i8>, ptr %uptr
+  %s = load <16 x i8>, ptr %sptr
+  %u.wide = sext <16 x i8> %u to <16 x i32>
+  %s.wide = zext <16 x i8> %s to <16 x i32>
+  %mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
+  %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add(<4 x i32> %acc, <16 x i32> %mult)
+  ret <4 x i32> %partial.reduce
+}
+
+define <2 x i64> @four_way_i8_i64_vl128_usdot(ptr %accptr, ptr %uptr, ptr %sptr) {
+; NEON-LABEL: four_way_i8_i64_vl128_usdot:
+; NEON:       // %bb.0:
+; NEON-NEXT:    movi v0.2d, #0000000000000000
+; NEON-NEXT:    ldr q1, [x1]
+; NEON-NEXT:    ldr q2, [x2]
+; NEON-NEXT:    usdot v0.4s, v1.16b, v2.16b
+; NEON-NEXT:    ldr q1, [x0]
+; NEON-NEXT:    saddw v1.2d, v1.2d, v0.2s
+; NEON-NEXT:    saddw2 v0.2d, v1.2d, v0.4s
+; NEON-NEXT:    ret
+;
+; SVE-LABEL: four_way_i8_i64_vl128_usdot:
+; SVE:       // %bb.0:
+; SVE-NEXT:    movi v0.2d, #0000000000000000
+; SVE-NEXT:    ldr q1, [x1]
+; SVE-NEXT:    ldr q2, [x2]
+; SVE-NEXT:    usdot z0.s, z1.b, z2.b
+; SVE-NEXT:    ldr q2, [x0]
+; SVE-NEXT:    sunpklo z1.d, z0.s
+; SVE-NEXT:    sunpkhi z0.d, z0.s
+; SVE-NEXT:    add z1.d, z2.d, z1.d
+; SVE-NEXT:    add z0.d, z1.d, z0.d
+; SVE-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SVE-NEXT:    ret
+;
+; SME-LABEL: four_way_i8_i64_vl128_usdot:
+; SME:       // %bb.0:
+; SME-NEXT:    mov z0.s, #0 // =0x0
+; SME-NEXT:    ldr q1, [x1]
+; SME-NEXT:    ldr q2, [x2]
+; SME-NEXT:    usdot z0.s, z1.b, z2.b
+; SME-NEXT:    ldr q1, [x0]
+; SME-NEXT:    saddwb z1.d, z1.d, z0.s
+; SME-NEXT:    saddwt z0.d, z1.d, z0.s
+; SME-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT:    ret
+  %acc = load <2 x i64>, ptr %accptr
+  %u = load <16 x i8>, ptr %uptr
+  %s = load <16 x i8>, ptr %sptr
+  %u.wide = zext <16 x i8> %u to <16 x i64>
+  %s.wide = sext <16 x i8> %s to <16 x i64>
+  %mult = mul nuw nsw <16 x i64> %s.wide, %u.wide
+  %partial.reduce = tail call <2 x i64> @llvm.experimental.vector.partial.reduce.add(<2 x i64> %acc, <16 x i64> %mult)
+  ret <2 x i64> %partial.reduce
+}
+
+define <2 x i64> @four_way_i16_i64_vl128_usdot(ptr %accptr, ptr %uptr, ptr %sptr) {
+; COMMON-LABEL: four_way_i16_i64_vl128_usdot:
+; COMMON:       // %bb.0:
+; COMMON-NEXT:    ldr q1, [x1]
+; COMMON-NEXT:    ldr q2, [x2]
+; COMMON-NEXT:    ldr q0, [x0]
+; COMMON-NEXT:    ushll v3.4s, v1.4h, #0
+; COMMON-NEXT:    sshll v4.4s, v2.4h, #0
+; COMMON-NEXT:    ushll2 v1.4s, v1.8h, #0
+; COMMON-NEXT:    sshll2 v2.4s, v2.8h, #0
+; COMMON-NEXT:    smlal v0.2d, v4.2s, v3.2s
+; COMMON-NEXT:    smlal2 v0.2d, v4.4s, v3.4s
+; COMMON-NEXT:    smlal v0.2d, v2.2s, v1.2s
+; COMMON-NEXT:    smlal2 v0.2d, v2.4s, v1.4s
+; COMMON-NEXT:    ret
+;
+; SME-LABEL: four_way_i16_i64_vl128_usdot:
+; SME:       // %bb.0:
+; SME-NEXT:    ptrue p0.d, vl2
+; SME-NEXT:    ldr q2, [x0]
+; SME-NEXT:    mov x8, #2 // =0x2
+; SME-NEXT:    ld1h { z0.d }, p0/z, [x1]
+; SME-NEXT:    ld1sh { z1.d }, p0/z, [x2]
+; SME-NEXT:    mad z0.d, p0/m, z1.d, z2.d
+; SME-NEXT:    ld1h { z1.d }, p0/z, [x1, x8, lsl #1]
+; SME-NEXT:    ld1sh { z2.d }, p0/z, [x2, x8, lsl #1]
+; SME-NEXT:    mov x8, #4 // =0x4
+; SME-NEXT:    mla z0.d, p0/m, z2.d, z1.d
+; SME-NEXT:    ld1h { z1.d }, p0/z, [x1, x8, lsl #1]
+; SME-NEXT:    ld1sh { z2.d }, p0/z, [x2, x8, lsl #1]
+; SME-NEXT:    mov x8, #6 // =0x6
+; SME-NEXT:    mla z0.d, p0/m, z2.d, z1.d
+; SME-NEXT:    ld1h { z1.d }, p0/z, [x1, x8, lsl #1]
+; SME-NEXT:    ld1sh { z2.d }, p0/z, [x2, x8, lsl #1]
+; SME-NEXT:    mla z0.d, p0/m, z2.d, z1.d
+; SME-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT:    ret
+  %acc = load <2 x i64>, ptr %accptr
+  %u = load <8 x i16>, ptr %uptr
+  %s = load <8 x i16>, ptr %sptr
+  %u.wide = zext <8 x i16> %u to <8 x i64>
+  %s.wide = sext <8 x i16> %s to <8 x i64>
+  %mult = mul nuw nsw <8 x i64> %s.wide, %u.wide
+  %partial.reduce = tail call <2 x i64> @llvm.experimental.vector.partial.reduce.add(<2 x i64> %acc, <8 x i64> %mult)
+  ret <2 x i64> %partial.reduce
+}
+
 define <8 x i32> @four_way_i8_i32_vl128_double_width(ptr %accptr, ptr %uptr, ptr %sptr) {
 ;
 ; COMMON-LABEL: four_way_i8_i32_vl128_double_width:
@@ -438,6 +586,37 @@ define <8 x i32> @four_way_i8_i32_vl128_double_width(ptr %accptr, ptr %uptr, ptr
   ret <8 x i32> %partial.reduce
 }
 
+define <8 x i32> @four_way_i8_i32_vl128_double_width_usdot(ptr %accptr, ptr %uptr, ptr %sptr) {
+;
+; COMMON-LABEL: four_way_i8_i32_vl128_double_width_usdot:
+; COMMON:       // %bb.0:
+; COMMON-NEXT:    ldp q0, q1, [x0]
+; COMMON-NEXT:    ldp q3, q2, [x1]
+; COMMON-NEXT:    ldp q5, q4, [x2]
+; COMMON-NEXT:    usdot v0.4s, v3.16b, v5.16b
+; COMMON-NEXT:    usdot v1.4s, v2.16b, v4.16b
+; COMMON-NEXT:    ret
+;
+; SME-LABEL: four_way_i8_i32_vl128_double_width_usdot:
+; SME:       // %bb.0:
+; SME-NEXT:    ldp q0, q1, [x0]
+; SME-NEXT:    ldp q3, q2, [x1]
+; SME-NEXT:    ldp q5, q4, [x2]
+; SME-NEXT:    usdot z0.s, z3.b, z5.b
+; SME-NEXT:    usdot z1.s, z2.b, z4.b
+; SME-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT:    // kill: def $q1 killed $q1 killed $z1
+; SME-NEXT:    ret
+  %acc = load <8 x i32>, ptr %accptr
+  %u = load <32 x i8>, ptr %uptr
+  %s = load <32 x i8>, ptr %sptr
+  %u.wide = zext <32 x i8> %u to <32 x i32>
+  %s.wide = sext <32 x i8> %s to <32 x i32>
+  %mult = mul nuw nsw <32 x i32> %s.wide, %u.wide
+  %partial.reduce = tail call <8 x i32> @llvm.experimental.vector.partial.reduce.add(<8 x i32> %acc, <32 x i32> %mult)
+  ret <8 x i32> %partial.reduce
+}
+
 define <8 x i32> @four_way_i8_i32_vl256(ptr %accptr, ptr %uptr, ptr %sptr) vscale_range(2,2) {
 ;
 ;
@@ -483,6 +662,51 @@ define <8 x i32> @four_way_i8_i32_vl256(ptr %accptr, ptr %uptr, ptr %sptr) vscal
   ret <8 x i32> %partial.reduce
 }
 
+define <8 x i32> @four_way_i8_i32_vl256_usdot(ptr %accptr, ptr %uptr, ptr %sptr) vscale_range(2,2) {
+;
+;
+; NEON-LABEL: four_way_i8_i32_vl256_usdot:
+; NEON:       // %bb.0:
+; NEON-NEXT:    ldp q0, q1, [x0]
+; NEON-NEXT:    ldp q3, q2, [x1]
+; NEON-NEXT:    ldp q5, q4, [x2]
+; NEON-NEXT:    usdot v0.4s, v3.16b, v5.16b
+; NEON-NEXT:    usdot v1.4s, v2.16b, v4.16b
+; NEON-NEXT:    ret
+;
+; SVE-LABEL: four_way_i8_i32_vl256_usdot:
+; SVE:       // %bb.0:
+; SVE-NEXT:    ldr z0, [x0]
+; SVE-NEXT:    ldr z1, [x1]
+; SVE-NEXT:    ldr z2, [x2]
+; SVE-NEXT:    usdot z0.s, z1.b, z2.b
+; SVE-NEXT:    mov z1.d, z0.d
+; SVE-NEXT:    ext z1.b, z1.b, z0.b, #16
+; SVE-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SVE-NEXT:    // kill: def $q1 killed $q1 killed $z1
+; SVE-NEXT:    ret
+;
+; SME-LABEL: four_way_i8_i32_vl256_usdot:
+; SME:       // %bb.0:
+; SME-NEXT:    ldr z0, [x0]
+; SME-NEXT:    ldr z1, [x1]
+; SME-NEXT:    ldr z2, [x2]
+; SME-NEXT:    usdot z0.s, z1.b, z2.b
+; SME-NEXT:    mov z1.d, z0.d
+; SME-NEXT:    ext z1.b, z1.b, z0.b, #16
+; SME-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT:    // kill: def $q1 killed $q1 killed $z1
+; SME-NEXT:    ret
+  %acc = load <8 x i32>, ptr %accptr
+  %u = load <32 x i8>, ptr %uptr
+  %s = load <32 x i8>, ptr %sptr
+  %u.wide = zext <32 x i8> %u to <32 x i32>
+  %s.wide = sext <32 x i8> %s to <32 x i32>
+  %mult = mul nuw nsw <32 x i32> %s.wide, %u.wide
+  %partial.reduce = tail call <8 x i32> @llvm.experimental.vector.partial.reduce.add(<8 x i32> %acc, <32 x i32> %mult)
+  ret <8 x i32> %partial.reduce
+}
+
 ;
 ; Four-way dot (i16 -> i64)
 ;