Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -4238,6 +4238,13 @@ defm UDOT_ZZZ_HtoS : sve2p1_two_way_dot_vv<"udot", 0b1, int_aarch64_sve_udot_x2
// sdot/udot variants instantiated from the sve2p1_two_way_dot_vvi multiclass.
// The second template argument (0b0 / 0b1) differs between the sdot and udot
// instantiations, and each is bound to its int_aarch64_sve_{s,u}dot_lane_x2
// intrinsic.
defm SDOT_ZZZI_HtoS : sve2p1_two_way_dot_vvi<"sdot", 0b0, int_aarch64_sve_sdot_lane_x2>;
defm UDOT_ZZZI_HtoS : sve2p1_two_way_dot_vvi<"udot", 0b1, int_aarch64_sve_udot_lane_x2>;

// Selection patterns for partial reductions, guarded by HasSVE2p1_or_SME2:
// a partial_reduce_umla / partial_reduce_smla node with an nxv4i32 accumulator
// and two nxv8i16 multiplicands is selected to UDOT_ZZZ_HtoS / SDOT_ZZZ_HtoS
// respectively. The operands are passed through in the same order
// (accumulator, LHS, RHS).
let Predicates = [HasSVE2p1_or_SME2] in {
def : Pat<(nxv4i32 (partial_reduce_umla nxv4i32:$Acc, nxv8i16:$MulLHS, nxv8i16:$MulRHS)),
(UDOT_ZZZ_HtoS $Acc, $MulLHS, $MulRHS)>;
def : Pat<(nxv4i32 (partial_reduce_smla nxv4i32:$Acc, nxv8i16:$MulLHS, nxv8i16:$MulRHS)),
(SDOT_ZZZ_HtoS $Acc, $MulLHS, $MulRHS)>;
} // End HasSVE2p1_or_SME2

// sqcvtn / uqcvtn / sqcvtun instantiated from the
// sve2p1_multi_vec_extract_narrow multiclass. Each instantiation has its own
// second template argument (0b00 / 0b01 / 0b10) and is bound to the
// matching int_aarch64_sve_*_x2 intrinsic.
defm SQCVTN_Z2Z_StoH : sve2p1_multi_vec_extract_narrow<"sqcvtn", 0b00, int_aarch64_sve_sqcvtn_x2>;
defm UQCVTN_Z2Z_StoH : sve2p1_multi_vec_extract_narrow<"uqcvtn", 0b01, int_aarch64_sve_uqcvtn_x2>;
defm SQCVTUN_Z2Z_StoH : sve2p1_multi_vec_extract_narrow<"sqcvtun", 0b10, int_aarch64_sve_sqcvtun_x2>;
Expand Down
105 changes: 105 additions & 0 deletions llvm/test/CodeGen/AArch64/sve2p1-dots-partial-reduction.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve2p1 < %s | FileCheck %s
; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sme2 -force-streaming < %s | FileCheck %s

; Scalable-vector, unsigned case: both <vscale x 8 x i16> inputs are
; zero-extended to i32, multiplied, and partially reduced into the
; <vscale x 4 x i32> accumulator. The function carries no vscale_range
; attribute. The expected output is a single "udot z0.s, z1.h, z2.h" followed
; by ret.
define <vscale x 4 x i32> @udot_vl128(<vscale x 4 x i32> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
; CHECK-LABEL: udot_vl128:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: udot z0.s, z1.h, z2.h
; CHECK-NEXT: ret
entry:
%a.wide = zext <vscale x 8 x i16> %a to <vscale x 8 x i32>
%b.wide = zext <vscale x 8 x i16> %b to <vscale x 8 x i32>
%mult = mul nuw nsw <vscale x 8 x i32> %a.wide, %b.wide
%partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add(<vscale x 4 x i32> %acc, <vscale x 8 x i32> %mult)
ret <vscale x 4 x i32> %partial.reduce
}

; Scalable-vector, signed case: both <vscale x 8 x i16> inputs are
; sign-extended to i32, multiplied, and partially reduced into the
; <vscale x 4 x i32> accumulator. The function carries no vscale_range
; attribute. The expected output is a single "sdot z0.s, z1.h, z2.h" followed
; by ret.
define <vscale x 4 x i32> @sdot_vl128(<vscale x 4 x i32> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
; CHECK-LABEL: sdot_vl128:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: sdot z0.s, z1.h, z2.h
; CHECK-NEXT: ret
entry:
%a.wide = sext <vscale x 8 x i16> %a to <vscale x 8 x i32>
%b.wide = sext <vscale x 8 x i16> %b to <vscale x 8 x i32>
%mult = mul nuw nsw <vscale x 8 x i32> %a.wide, %b.wide
%partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add(<vscale x 4 x i32> %acc, <vscale x 8 x i32> %mult)
ret <vscale x 4 x i32> %partial.reduce
}

; Fixed-length, unsigned case with vscale_range(2,2): the <8 x i32>
; accumulator and the two <16 x i16> inputs are loaded from memory, the inputs
; are zero-extended and multiplied, and the partial reduction is stored back
; through %accptr. The expected output is three ldr z-register loads, one
; "udot z0.s, z1.h, z2.h", and a str of z0 to [x0].
define void @udot_vl256(ptr %accptr, ptr %aptr, ptr %bptr) vscale_range(2,2) {
; CHECK-LABEL: udot_vl256:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: ldr z0, [x0]
; CHECK-NEXT: ldr z1, [x1]
; CHECK-NEXT: ldr z2, [x2]
; CHECK-NEXT: udot z0.s, z1.h, z2.h
; CHECK-NEXT: str z0, [x0]
; CHECK-NEXT: ret
entry:
%acc = load <8 x i32>, ptr %accptr
%a = load <16 x i16>, ptr %aptr
%b = load <16 x i16>, ptr %bptr
%a.wide = zext <16 x i16> %a to <16 x i32>
%b.wide = zext <16 x i16> %b to <16 x i32>
%mult = mul nuw nsw <16 x i32> %a.wide, %b.wide
%partial.reduce = tail call <8 x i32> @llvm.experimental.vector.partial.reduce.add(<8 x i32> %acc, <16 x i32> %mult)
store <8 x i32> %partial.reduce, ptr %accptr
ret void
}

; Fixed-length, signed case with vscale_range(2,2): the <8 x i32>
; accumulator and the two <16 x i16> inputs are loaded from memory, the inputs
; are sign-extended and multiplied, and the partial reduction is stored back
; through %accptr. The expected output is three ldr z-register loads, one
; "sdot z0.s, z1.h, z2.h", and a str of z0 to [x0].
define void @sdot_vl256(ptr %accptr, ptr %aptr, ptr %bptr) vscale_range(2,2) {
; CHECK-LABEL: sdot_vl256:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: ldr z0, [x0]
; CHECK-NEXT: ldr z1, [x1]
; CHECK-NEXT: ldr z2, [x2]
; CHECK-NEXT: sdot z0.s, z1.h, z2.h
; CHECK-NEXT: str z0, [x0]
; CHECK-NEXT: ret
entry:
%acc = load <8 x i32>, ptr %accptr
%a = load <16 x i16>, ptr %aptr
%b = load <16 x i16>, ptr %bptr
%a.wide = sext <16 x i16> %a to <16 x i32>
%b.wide = sext <16 x i16> %b to <16 x i32>
%mult = mul nuw nsw <16 x i32> %a.wide, %b.wide
%partial.reduce = tail call <8 x i32> @llvm.experimental.vector.partial.reduce.add(<8 x i32> %acc, <16 x i32> %mult)
store <8 x i32> %partial.reduce, ptr %accptr
ret void
}

; Fixed-length unsigned case passed in registers: a <4 x i32> accumulator and
; two <8 x i16> inputs, with no vscale_range attribute. The inputs are
; zero-extended and multiplied before the partial reduction. The expected
; output uses "udot z0.s, z1.h, z2.h"; the surrounding kill annotations show
; q0-q2 being redefined as z0-z2 before the instruction and z0 narrowed back
; to q0 afterwards.
define <4 x i32> @fixed_udot_s_h(<4 x i32> %acc, <8 x i16> %a, <8 x i16> %b) {
; CHECK-LABEL: fixed_udot_s_h:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
; CHECK-NEXT: udot z0.s, z1.h, z2.h
; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
; CHECK-NEXT: ret
entry:
%a.wide = zext <8 x i16> %a to <8 x i32>
%b.wide = zext <8 x i16> %b to <8 x i32>
%mult = mul nuw nsw <8 x i32> %a.wide, %b.wide
%partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add(<4 x i32> %acc, <8 x i32> %mult)
ret <4 x i32> %partial.reduce
}

; Fixed-length signed case passed in registers: a <4 x i32> accumulator and
; two <8 x i16> inputs, with no vscale_range attribute. The inputs are
; sign-extended and multiplied before the partial reduction. The expected
; output uses "sdot z0.s, z1.h, z2.h"; the surrounding kill annotations show
; q0-q2 being redefined as z0-z2 before the instruction and z0 narrowed back
; to q0 afterwards.
define <4 x i32> @fixed_sdot_s_h(<4 x i32> %acc, <8 x i16> %a, <8 x i16> %b) {
; CHECK-LABEL: fixed_sdot_s_h:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
; CHECK-NEXT: sdot z0.s, z1.h, z2.h
; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
; CHECK-NEXT: ret
entry:
%a.wide = sext <8 x i16> %a to <8 x i32>
%b.wide = sext <8 x i16> %b to <8 x i32>
%mult = mul nuw nsw <8 x i32> %a.wide, %b.wide
%partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add(<4 x i32> %acc, <8 x i32> %mult)
ret <4 x i32> %partial.reduce
}