Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -4238,6 +4238,13 @@ defm UDOT_ZZZ_HtoS : sve2p1_two_way_dot_vv<"udot", 0b1, int_aarch64_sve_udot_x2
defm SDOT_ZZZI_HtoS : sve2p1_two_way_dot_vvi<"sdot", 0b0, int_aarch64_sve_sdot_lane_x2>;
defm UDOT_ZZZI_HtoS : sve2p1_two_way_dot_vvi<"udot", 0b1, int_aarch64_sve_udot_lane_x2>;

let Predicates = [HasSVE2p1_or_SME2] in {
def : Pat<(nxv4i32 (partial_reduce_umla nxv4i32:$Acc, nxv8i16:$MulLHS, nxv8i16:$MulRHS)),
(UDOT_ZZZ_HtoS $Acc, $MulLHS, $MulRHS)>;
def : Pat<(nxv4i32 (partial_reduce_smla nxv4i32:$Acc, nxv8i16:$MulLHS, nxv8i16:$MulRHS)),
(SDOT_ZZZ_HtoS $Acc, $MulLHS, $MulRHS)>;
} // End HasSVE2p1_or_SME2

defm SQCVTN_Z2Z_StoH : sve2p1_multi_vec_extract_narrow<"sqcvtn", 0b00, int_aarch64_sve_sqcvtn_x2>;
defm UQCVTN_Z2Z_StoH : sve2p1_multi_vec_extract_narrow<"uqcvtn", 0b01, int_aarch64_sve_uqcvtn_x2>;
defm SQCVTUN_Z2Z_StoH : sve2p1_multi_vec_extract_narrow<"sqcvtun", 0b10, int_aarch64_sve_sqcvtun_x2>;
Expand Down
88 changes: 88 additions & 0 deletions llvm/test/CodeGen/AArch64/sve2p1-dots-partial-reduction.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve2p1 < %s | FileCheck %s

define <vscale x 4 x i32> @udot_vl128(<vscale x 4 x i32> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
; CHECK-LABEL: udot_vl128:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: udot z0.s, z1.h, z2.h
; CHECK-NEXT: ret
entry:
%a.wide = zext <vscale x 8 x i16> %a to <vscale x 8 x i32>
%b.wide = zext <vscale x 8 x i16> %b to <vscale x 8 x i32>
%mult = mul nuw nsw <vscale x 8 x i32> %a.wide, %b.wide
%partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add(<vscale x 4 x i32> %acc, <vscale x 8 x i32> %mult)
ret <vscale x 4 x i32> %partial.reduce
}

define <vscale x 4 x i32> @sdot_vl128(<vscale x 4 x i32> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
; CHECK-LABEL: sdot_vl128:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: sdot z0.s, z1.h, z2.h
; CHECK-NEXT: ret
entry:
%a.wide = sext <vscale x 8 x i16> %a to <vscale x 8 x i32>
%b.wide = sext <vscale x 8 x i16> %b to <vscale x 8 x i32>
%mult = mul nuw nsw <vscale x 8 x i32> %a.wide, %b.wide
%partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add(<vscale x 4 x i32> %acc, <vscale x 8 x i32> %mult)
ret <vscale x 4 x i32> %partial.reduce
}

define <vscale x 4 x i32> @udot_vl256(<vscale x 4 x i32> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) vscale_range(2,2) {
; CHECK-LABEL: udot_vl256:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: udot z0.s, z1.h, z2.h
; CHECK-NEXT: ret
entry:
%a.wide = zext <vscale x 8 x i16> %a to <vscale x 8 x i32>
%b.wide = zext <vscale x 8 x i16> %b to <vscale x 8 x i32>
%mult = mul nuw nsw <vscale x 8 x i32> %a.wide, %b.wide
%partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add(<vscale x 4 x i32> %acc, <vscale x 8 x i32> %mult)
ret <vscale x 4 x i32> %partial.reduce
}

define <vscale x 4 x i32> @sdot_vl256(<vscale x 4 x i32> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) vscale_range(2,2) {
; CHECK-LABEL: sdot_vl256:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: sdot z0.s, z1.h, z2.h
; CHECK-NEXT: ret
entry:
%a.wide = sext <vscale x 8 x i16> %a to <vscale x 8 x i32>
%b.wide = sext <vscale x 8 x i16> %b to <vscale x 8 x i32>
%mult = mul nuw nsw <vscale x 8 x i32> %a.wide, %b.wide
%partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add(<vscale x 4 x i32> %acc, <vscale x 8 x i32> %mult)
ret <vscale x 4 x i32> %partial.reduce
}

define <4 x i32> @fixed_udot_s_h(<4 x i32> %acc, <8 x i16> %a, <8 x i16> %b) {
; CHECK-LABEL: fixed_udot_s_h:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
; CHECK-NEXT: udot z0.s, z1.h, z2.h
; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
; CHECK-NEXT: ret
entry:
%a.wide = zext <8 x i16> %a to <8 x i32>
%b.wide = zext <8 x i16> %b to <8 x i32>
%mult = mul nuw nsw <8 x i32> %a.wide, %b.wide
%partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add(<4 x i32> %acc, <8 x i32> %mult)
ret <4 x i32> %partial.reduce
}

define <4 x i32> @fixed_sdot_s_h(<4 x i32> %acc, <8 x i16> %a, <8 x i16> %b) {
; CHECK-LABEL: fixed_sdot_s_h:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
; CHECK-NEXT: sdot z0.s, z1.h, z2.h
; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
; CHECK-NEXT: ret
entry:
%a.wide = sext <8 x i16> %a to <8 x i32>
%b.wide = sext <8 x i16> %b to <8 x i32>
%mult = mul nuw nsw <8 x i32> %a.wide, %b.wide
%partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add(<4 x i32> %acc, <8 x i32> %mult)
ret <4 x i32> %partial.reduce
}