Skip to content

Commit 8b1ac34

Browse files
committed
[AArch64] Add fixed-length SVE USDOT support
1 parent 6fb2a80 commit 8b1ac34

File tree

2 files changed

+165
-2
lines changed

2 files changed

+165
-2
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2272,6 +2272,13 @@ void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT) {
       setPartialReduceMLAAction(MLAOps, VT,
                                 MVT::getVectorVT(MVT::i8, NumElts * 2), Custom);
     }
+
+    if (Subtarget->hasMatMulInt8()) {
+      if (VT.getVectorElementType() == MVT::i32)
+        setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_SUMLA, VT, MVT::getVectorVT(MVT::i8, NumElts * 4), Custom);
+      else if (VT.getVectorElementType() == MVT::i64)
+        setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_SUMLA, VT, MVT::getVectorVT(MVT::i8, NumElts * 8), Custom);
+    }
   }
 
   // Lower fixed length vector operations to scalable equivalents.

llvm/test/CodeGen/AArch64/sve-fixed-length-partial-reduce.ll

Lines changed: 158 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
-; RUN: llc -mattr=+dotprod -aarch64-enable-partial-reduce-nodes=true < %s | FileCheck %s --check-prefixes=COMMON,NEON
-; RUN: llc -mattr=+sve,+dotprod -aarch64-enable-partial-reduce-nodes=true < %s | FileCheck %s --check-prefixes=COMMON,SVE
+; RUN: llc -mattr=+dotprod,+i8mm -aarch64-enable-partial-reduce-nodes=true < %s | FileCheck %s --check-prefixes=COMMON,NEON
+; RUN: llc -mattr=+sve,+dotprod,+i8mm -aarch64-enable-partial-reduce-nodes=true < %s | FileCheck %s --check-prefixes=COMMON,SVE
 ; RUN: llc -mattr=+sme -aarch64-enable-partial-reduce-nodes=true -force-streaming < %s | FileCheck %s --check-prefix=SME
 
 target triple = "aarch64"
@@ -407,6 +407,46 @@ define <4 x i32> @four_way_i8_i32_vl128(ptr %accptr, ptr %uptr, ptr %sptr) {
   ret <4 x i32> %partial.reduce
 }
 
+define <4 x i32> @four_way_i8_i32_vl128_usdot(ptr %accptr, ptr %uptr, ptr %sptr) {
+; COMMON-LABEL: four_way_i8_i32_vl128_usdot:
+; COMMON:       // %bb.0:
+; COMMON-NEXT:    ldr q0, [x0]
+; COMMON-NEXT:    ldr q1, [x1]
+; COMMON-NEXT:    ldr q2, [x2]
+; COMMON-NEXT:    usdot v0.4s, v1.16b, v2.16b
+; COMMON-NEXT:    ret
+;
+; SME-LABEL: four_way_i8_i32_vl128_usdot:
+; SME:       // %bb.0:
+; SME-NEXT:    ptrue p0.s, vl4
+; SME-NEXT:    ldr q2, [x0]
+; SME-NEXT:    mov w8, #4 // =0x4
+; SME-NEXT:    ld1b { z0.s }, p0/z, [x1]
+; SME-NEXT:    ld1sb { z1.s }, p0/z, [x2]
+; SME-NEXT:    mad z0.s, p0/m, z1.s, z2.s
+; SME-NEXT:    ld1b { z1.s }, p0/z, [x1, x8]
+; SME-NEXT:    ld1sb { z2.s }, p0/z, [x2, x8]
+; SME-NEXT:    mov w8, #8 // =0x8
+; SME-NEXT:    mla z0.s, p0/m, z2.s, z1.s
+; SME-NEXT:    ld1b { z1.s }, p0/z, [x1, x8]
+; SME-NEXT:    ld1sb { z2.s }, p0/z, [x2, x8]
+; SME-NEXT:    mov w8, #12 // =0xc
+; SME-NEXT:    mla z0.s, p0/m, z2.s, z1.s
+; SME-NEXT:    ld1b { z1.s }, p0/z, [x1, x8]
+; SME-NEXT:    ld1sb { z2.s }, p0/z, [x2, x8]
+; SME-NEXT:    mla z0.s, p0/m, z2.s, z1.s
+; SME-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT:    ret
+  %acc = load <4 x i32>, ptr %accptr
+  %u = load <16 x i8>, ptr %uptr
+  %s = load <16 x i8>, ptr %sptr
+  %u.wide = zext <16 x i8> %u to <16 x i32>
+  %s.wide = sext <16 x i8> %s to <16 x i32>
+  %mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
+  %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add(<4 x i32> %acc, <16 x i32> %mult)
+  ret <4 x i32> %partial.reduce
+}
+
 define <8 x i32> @four_way_i8_i32_vl128_double_width(ptr %accptr, ptr %uptr, ptr %sptr) {
 ;
 ; COMMON-LABEL: four_way_i8_i32_vl128_double_width:
@@ -438,6 +478,67 @@ define <8 x i32> @four_way_i8_i32_vl128_double_width(ptr %accptr, ptr %uptr, ptr %sptr) {
   ret <8 x i32> %partial.reduce
 }
 
+define <8 x i32> @four_way_i8_i32_vl128_double_width_usdot(ptr %accptr, ptr %uptr, ptr %sptr) {
+;
+; COMMON-LABEL: four_way_i8_i32_vl128_double_width_usdot:
+; COMMON:       // %bb.0:
+; COMMON-NEXT:    ldp q0, q1, [x0]
+; COMMON-NEXT:    ldp q3, q2, [x1]
+; COMMON-NEXT:    ldp q5, q4, [x2]
+; COMMON-NEXT:    usdot v0.4s, v3.16b, v5.16b
+; COMMON-NEXT:    usdot v1.4s, v2.16b, v4.16b
+; COMMON-NEXT:    ret
+;
+; SME-LABEL: four_way_i8_i32_vl128_double_width_usdot:
+; SME:       // %bb.0:
+; SME-NEXT:    ptrue p0.s, vl4
+; SME-NEXT:    mov w8, #16 // =0x10
+; SME-NEXT:    mov w9, #4 // =0x4
+; SME-NEXT:    ldp q5, q4, [x0]
+; SME-NEXT:    ld1b { z0.s }, p0/z, [x1, x8]
+; SME-NEXT:    ld1b { z1.s }, p0/z, [x1]
+; SME-NEXT:    ld1sb { z2.s }, p0/z, [x2, x8]
+; SME-NEXT:    ld1sb { z3.s }, p0/z, [x2]
+; SME-NEXT:    mov w8, #20 // =0x14
+; SME-NEXT:    ld1b { z6.s }, p0/z, [x1, x8]
+; SME-NEXT:    mad z0.s, p0/m, z2.s, z4.s
+; SME-NEXT:    ld1b { z2.s }, p0/z, [x1, x9]
+; SME-NEXT:    ld1sb { z4.s }, p0/z, [x2, x9]
+; SME-NEXT:    mad z1.s, p0/m, z3.s, z5.s
+; SME-NEXT:    ld1sb { z3.s }, p0/z, [x2, x8]
+; SME-NEXT:    mov w8, #24 // =0x18
+; SME-NEXT:    mov w9, #8 // =0x8
+; SME-NEXT:    ld1b { z5.s }, p0/z, [x1, x8]
+; SME-NEXT:    mla z0.s, p0/m, z3.s, z6.s
+; SME-NEXT:    ld1sb { z3.s }, p0/z, [x2, x8]
+; SME-NEXT:    mov w8, #28 // =0x1c
+; SME-NEXT:    mla z1.s, p0/m, z4.s, z2.s
+; SME-NEXT:    ld1b { z2.s }, p0/z, [x1, x9]
+; SME-NEXT:    ld1sb { z4.s }, p0/z, [x2, x9]
+; SME-NEXT:    mov w9, #12 // =0xc
+; SME-NEXT:    ld1b { z6.s }, p0/z, [x1, x8]
+; SME-NEXT:    mla z1.s, p0/m, z4.s, z2.s
+; SME-NEXT:    movprfx z2, z0
+; SME-NEXT:    mla z2.s, p0/m, z3.s, z5.s
+; SME-NEXT:    ld1b { z0.s }, p0/z, [x1, x9]
+; SME-NEXT:    ld1sb { z3.s }, p0/z, [x2, x8]
+; SME-NEXT:    ld1sb { z4.s }, p0/z, [x2, x9]
+; SME-NEXT:    mad z0.s, p0/m, z4.s, z1.s
+; SME-NEXT:    movprfx z1, z2
+; SME-NEXT:    mla z1.s, p0/m, z3.s, z6.s
+; SME-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT:    // kill: def $q1 killed $q1 killed $z1
+; SME-NEXT:    ret
+  %acc = load <8 x i32>, ptr %accptr
+  %u = load <32 x i8>, ptr %uptr
+  %s = load <32 x i8>, ptr %sptr
+  %u.wide = zext <32 x i8> %u to <32 x i32>
+  %s.wide = sext <32 x i8> %s to <32 x i32>
+  %mult = mul nuw nsw <32 x i32> %s.wide, %u.wide
+  %partial.reduce = tail call <8 x i32> @llvm.experimental.vector.partial.reduce.add(<8 x i32> %acc, <32 x i32> %mult)
+  ret <8 x i32> %partial.reduce
+}
+
 define <8 x i32> @four_way_i8_i32_vl256(ptr %accptr, ptr %uptr, ptr %sptr) vscale_range(2,2) {
 ;
 ;
@@ -483,6 +584,61 @@ define <8 x i32> @four_way_i8_i32_vl256(ptr %accptr, ptr %uptr, ptr %sptr) vscale_range(2,2) {
   ret <8 x i32> %partial.reduce
 }
 
+define <8 x i32> @four_way_i8_i32_vl256_usdot(ptr %accptr, ptr %uptr, ptr %sptr) vscale_range(2,2) {
+;
+;
+; NEON-LABEL: four_way_i8_i32_vl256_usdot:
+; NEON:       // %bb.0:
+; NEON-NEXT:    ldp q0, q1, [x0]
+; NEON-NEXT:    ldp q3, q2, [x1]
+; NEON-NEXT:    ldp q5, q4, [x2]
+; NEON-NEXT:    usdot v0.4s, v3.16b, v5.16b
+; NEON-NEXT:    usdot v1.4s, v2.16b, v4.16b
+; NEON-NEXT:    ret
+;
+; SVE-LABEL: four_way_i8_i32_vl256_usdot:
+; SVE:       // %bb.0:
+; SVE-NEXT:    ldr z0, [x0]
+; SVE-NEXT:    ldr z1, [x1]
+; SVE-NEXT:    ldr z2, [x2]
+; SVE-NEXT:    usdot z0.s, z1.b, z2.b
+; SVE-NEXT:    mov z1.d, z0.d
+; SVE-NEXT:    ext z1.b, z1.b, z0.b, #16
+; SVE-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SVE-NEXT:    // kill: def $q1 killed $q1 killed $z1
+; SVE-NEXT:    ret
+;
+; SME-LABEL: four_way_i8_i32_vl256_usdot:
+; SME:       // %bb.0:
+; SME-NEXT:    ptrue p0.s
+; SME-NEXT:    ldr z0, [x0]
+; SME-NEXT:    ld1b { z1.s }, p0/z, [x1]
+; SME-NEXT:    ld1sb { z2.s }, p0/z, [x2]
+; SME-NEXT:    mla z0.s, p0/m, z2.s, z1.s
+; SME-NEXT:    ld1b { z1.s }, p0/z, [x1, #1, mul vl]
+; SME-NEXT:    ld1sb { z2.s }, p0/z, [x2, #1, mul vl]
+; SME-NEXT:    mla z0.s, p0/m, z2.s, z1.s
+; SME-NEXT:    ld1b { z1.s }, p0/z, [x1, #2, mul vl]
+; SME-NEXT:    ld1sb { z2.s }, p0/z, [x2, #2, mul vl]
+; SME-NEXT:    mla z0.s, p0/m, z2.s, z1.s
+; SME-NEXT:    ld1b { z1.s }, p0/z, [x1, #3, mul vl]
+; SME-NEXT:    ld1sb { z2.s }, p0/z, [x2, #3, mul vl]
+; SME-NEXT:    mla z0.s, p0/m, z2.s, z1.s
+; SME-NEXT:    mov z1.d, z0.d
+; SME-NEXT:    ext z1.b, z1.b, z0.b, #16
+; SME-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT:    // kill: def $q1 killed $q1 killed $z1
+; SME-NEXT:    ret
+  %acc = load <8 x i32>, ptr %accptr
+  %u = load <32 x i8>, ptr %uptr
+  %s = load <32 x i8>, ptr %sptr
+  %u.wide = zext <32 x i8> %u to <32 x i32>
+  %s.wide = sext <32 x i8> %s to <32 x i32>
+  %mult = mul nuw nsw <32 x i32> %s.wide, %u.wide
+  %partial.reduce = tail call <8 x i32> @llvm.experimental.vector.partial.reduce.add(<8 x i32> %acc, <32 x i32> %mult)
+  ret <8 x i32> %partial.reduce
+}
+
 ;
 ; Four-way dot (i16 -> i64)
 ;

0 commit comments

Comments
 (0)