Skip to content

Commit 0619820

Browse files
[AArch64][SVE] Add dot product codegen for partial reductions with
no binary operation on input Add codegen for when the input type has 4 times as many elements as the output type and the input to the partial reduction does not have a binary operation performed on it.
1 parent 8dc23ef commit 0619820

File tree

2 files changed

+238
-0
lines changed

2 files changed

+238
-0
lines changed

llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,166 @@ entry:
367367
ret <4 x i64> %partial.reduce
368368
}
369369

370+
define <4 x i32> @udot_no_bin_op(<4 x i32> %acc, <16 x i8> %a){
371+
; CHECK-DOT-LABEL: udot_no_bin_op:
372+
; CHECK-DOT: // %bb.0:
373+
; CHECK-DOT-NEXT: movi v2.16b, #1
374+
; CHECK-DOT-NEXT: udot v0.4s, v1.16b, v2.16b
375+
; CHECK-DOT-NEXT: ret
376+
;
377+
; CHECK-NODOT-LABEL: udot_no_bin_op:
378+
; CHECK-NODOT: // %bb.0:
379+
; CHECK-NODOT-NEXT: ushll v2.8h, v1.8b, #0
380+
; CHECK-NODOT-NEXT: ushll2 v1.8h, v1.16b, #0
381+
; CHECK-NODOT-NEXT: ushll v3.4s, v1.4h, #0
382+
; CHECK-NODOT-NEXT: uaddw v0.4s, v0.4s, v2.4h
383+
; CHECK-NODOT-NEXT: uaddw2 v2.4s, v3.4s, v2.8h
384+
; CHECK-NODOT-NEXT: uaddw2 v0.4s, v0.4s, v1.8h
385+
; CHECK-NODOT-NEXT: add v0.4s, v2.4s, v0.4s
386+
; CHECK-NODOT-NEXT: ret
387+
%a.wide = zext <16 x i8> %a to <16 x i32>
388+
%partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %a.wide)
389+
ret <4 x i32> %partial.reduce
390+
}
391+
392+
define <4 x i32> @sdot_no_bin_op(<4 x i32> %acc, <16 x i8> %a){
393+
; CHECK-DOT-LABEL: sdot_no_bin_op:
394+
; CHECK-DOT: // %bb.0:
395+
; CHECK-DOT-NEXT: movi v2.16b, #1
396+
; CHECK-DOT-NEXT: sdot v0.4s, v1.16b, v2.16b
397+
; CHECK-DOT-NEXT: ret
398+
;
399+
; CHECK-NODOT-LABEL: sdot_no_bin_op:
400+
; CHECK-NODOT: // %bb.0:
401+
; CHECK-NODOT-NEXT: sshll v2.8h, v1.8b, #0
402+
; CHECK-NODOT-NEXT: sshll2 v1.8h, v1.16b, #0
403+
; CHECK-NODOT-NEXT: sshll v3.4s, v1.4h, #0
404+
; CHECK-NODOT-NEXT: saddw v0.4s, v0.4s, v2.4h
405+
; CHECK-NODOT-NEXT: saddw2 v2.4s, v3.4s, v2.8h
406+
; CHECK-NODOT-NEXT: saddw2 v0.4s, v0.4s, v1.8h
407+
; CHECK-NODOT-NEXT: add v0.4s, v2.4s, v0.4s
408+
; CHECK-NODOT-NEXT: ret
409+
%a.wide = sext <16 x i8> %a to <16 x i32>
410+
%partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %a.wide)
411+
ret <4 x i32> %partial.reduce
412+
}
413+
414+
define <2 x i32> @udot_no_bin_op_narrow(<2 x i32> %acc, <8 x i8> %a){
415+
; CHECK-DOT-LABEL: udot_no_bin_op_narrow:
416+
; CHECK-DOT: // %bb.0:
417+
; CHECK-DOT-NEXT: movi v2.8b, #1
418+
; CHECK-DOT-NEXT: udot v0.2s, v1.8b, v2.8b
419+
; CHECK-DOT-NEXT: ret
420+
;
421+
; CHECK-NODOT-LABEL: udot_no_bin_op_narrow:
422+
; CHECK-NODOT: // %bb.0:
423+
; CHECK-NODOT-NEXT: ushll v1.8h, v1.8b, #0
424+
; CHECK-NODOT-NEXT: // kill: def $d0 killed $d0 def $q0
425+
; CHECK-NODOT-NEXT: ushll v2.4s, v1.4h, #0
426+
; CHECK-NODOT-NEXT: ushll2 v3.4s, v1.8h, #0
427+
; CHECK-NODOT-NEXT: ext v4.16b, v1.16b, v1.16b, #8
428+
; CHECK-NODOT-NEXT: uaddw v0.4s, v0.4s, v1.4h
429+
; CHECK-NODOT-NEXT: ext v3.16b, v3.16b, v3.16b, #8
430+
; CHECK-NODOT-NEXT: ext v2.16b, v2.16b, v2.16b, #8
431+
; CHECK-NODOT-NEXT: add v0.2s, v3.2s, v0.2s
432+
; CHECK-NODOT-NEXT: uaddw v1.4s, v2.4s, v4.4h
433+
; CHECK-NODOT-NEXT: add v0.2s, v1.2s, v0.2s
434+
; CHECK-NODOT-NEXT: ret
435+
%a.wide = zext <8 x i8> %a to <8 x i32>
436+
%partial.reduce = tail call <2 x i32> @llvm.experimental.vector.partial.reduce.add.v2i32.v8i32(<2 x i32> %acc, <8 x i32> %a.wide)
437+
ret <2 x i32> %partial.reduce
438+
}
439+
440+
define <2 x i32> @sdot_no_bin_op_narrow(<2 x i32> %acc, <8 x i8> %a){
441+
; CHECK-DOT-LABEL: sdot_no_bin_op_narrow:
442+
; CHECK-DOT: // %bb.0:
443+
; CHECK-DOT-NEXT: movi v2.8b, #1
444+
; CHECK-DOT-NEXT: sdot v0.2s, v1.8b, v2.8b
445+
; CHECK-DOT-NEXT: ret
446+
;
447+
; CHECK-NODOT-LABEL: sdot_no_bin_op_narrow:
448+
; CHECK-NODOT: // %bb.0:
449+
; CHECK-NODOT-NEXT: sshll v1.8h, v1.8b, #0
450+
; CHECK-NODOT-NEXT: // kill: def $d0 killed $d0 def $q0
451+
; CHECK-NODOT-NEXT: sshll v2.4s, v1.4h, #0
452+
; CHECK-NODOT-NEXT: sshll2 v3.4s, v1.8h, #0
453+
; CHECK-NODOT-NEXT: ext v4.16b, v1.16b, v1.16b, #8
454+
; CHECK-NODOT-NEXT: saddw v0.4s, v0.4s, v1.4h
455+
; CHECK-NODOT-NEXT: ext v3.16b, v3.16b, v3.16b, #8
456+
; CHECK-NODOT-NEXT: ext v2.16b, v2.16b, v2.16b, #8
457+
; CHECK-NODOT-NEXT: add v0.2s, v3.2s, v0.2s
458+
; CHECK-NODOT-NEXT: saddw v1.4s, v2.4s, v4.4h
459+
; CHECK-NODOT-NEXT: add v0.2s, v1.2s, v0.2s
460+
; CHECK-NODOT-NEXT: ret
461+
%a.wide = sext <8 x i8> %a to <8 x i32>
462+
%partial.reduce = tail call <2 x i32> @llvm.experimental.vector.partial.reduce.add.v2i32.v8i32(<2 x i32> %acc, <8 x i32> %a.wide)
463+
ret <2 x i32> %partial.reduce
464+
}
465+
466+
define <4 x i64> @udot_no_bin_op_8to64(<4 x i64> %acc, <16 x i8> %a){
467+
; CHECK-DOT-LABEL: udot_no_bin_op_8to64:
468+
; CHECK-DOT: // %bb.0:
469+
; CHECK-DOT-NEXT: movi v3.16b, #1
470+
; CHECK-DOT-NEXT: movi v4.2d, #0000000000000000
471+
; CHECK-DOT-NEXT: udot v4.4s, v2.16b, v3.16b
472+
; CHECK-DOT-NEXT: saddw2 v1.2d, v1.2d, v4.4s
473+
; CHECK-DOT-NEXT: saddw v0.2d, v0.2d, v4.2s
474+
; CHECK-DOT-NEXT: ret
475+
;
476+
; CHECK-NODOT-LABEL: udot_no_bin_op_8to64:
477+
; CHECK-NODOT: // %bb.0:
478+
; CHECK-NODOT-NEXT: ushll v3.8h, v2.8b, #0
479+
; CHECK-NODOT-NEXT: ushll2 v2.8h, v2.16b, #0
480+
; CHECK-NODOT-NEXT: ushll v4.4s, v3.4h, #0
481+
; CHECK-NODOT-NEXT: ushll v5.4s, v2.4h, #0
482+
; CHECK-NODOT-NEXT: ushll2 v3.4s, v3.8h, #0
483+
; CHECK-NODOT-NEXT: ushll2 v2.4s, v2.8h, #0
484+
; CHECK-NODOT-NEXT: uaddw2 v1.2d, v1.2d, v4.4s
485+
; CHECK-NODOT-NEXT: uaddw v0.2d, v0.2d, v4.2s
486+
; CHECK-NODOT-NEXT: uaddl2 v4.2d, v3.4s, v5.4s
487+
; CHECK-NODOT-NEXT: uaddl v3.2d, v3.2s, v5.2s
488+
; CHECK-NODOT-NEXT: uaddw2 v1.2d, v1.2d, v2.4s
489+
; CHECK-NODOT-NEXT: uaddw v0.2d, v0.2d, v2.2s
490+
; CHECK-NODOT-NEXT: add v1.2d, v4.2d, v1.2d
491+
; CHECK-NODOT-NEXT: add v0.2d, v3.2d, v0.2d
492+
; CHECK-NODOT-NEXT: ret
493+
%a.wide = zext <16 x i8> %a to <16 x i64>
494+
%partial.reduce = tail call <4 x i64> @llvm.experimental.vector.partial.reduce.add.v4i64.v16i64(<4 x i64> %acc, <16 x i64> %a.wide)
495+
ret <4 x i64> %partial.reduce
496+
}
497+
498+
define <4 x i64> @sdot_no_bin_op_8to64(<4 x i64> %acc, <16 x i8> %a){
499+
; CHECK-DOT-LABEL: sdot_no_bin_op_8to64:
500+
; CHECK-DOT: // %bb.0:
501+
; CHECK-DOT-NEXT: movi v3.16b, #1
502+
; CHECK-DOT-NEXT: movi v4.2d, #0000000000000000
503+
; CHECK-DOT-NEXT: sdot v4.4s, v2.16b, v3.16b
504+
; CHECK-DOT-NEXT: saddw2 v1.2d, v1.2d, v4.4s
505+
; CHECK-DOT-NEXT: saddw v0.2d, v0.2d, v4.2s
506+
; CHECK-DOT-NEXT: ret
507+
;
508+
; CHECK-NODOT-LABEL: sdot_no_bin_op_8to64:
509+
; CHECK-NODOT: // %bb.0:
510+
; CHECK-NODOT-NEXT: sshll v3.8h, v2.8b, #0
511+
; CHECK-NODOT-NEXT: sshll2 v2.8h, v2.16b, #0
512+
; CHECK-NODOT-NEXT: sshll v4.4s, v3.4h, #0
513+
; CHECK-NODOT-NEXT: sshll v5.4s, v2.4h, #0
514+
; CHECK-NODOT-NEXT: sshll2 v3.4s, v3.8h, #0
515+
; CHECK-NODOT-NEXT: sshll2 v2.4s, v2.8h, #0
516+
; CHECK-NODOT-NEXT: saddw2 v1.2d, v1.2d, v4.4s
517+
; CHECK-NODOT-NEXT: saddw v0.2d, v0.2d, v4.2s
518+
; CHECK-NODOT-NEXT: saddl2 v4.2d, v3.4s, v5.4s
519+
; CHECK-NODOT-NEXT: saddl v3.2d, v3.2s, v5.2s
520+
; CHECK-NODOT-NEXT: saddw2 v1.2d, v1.2d, v2.4s
521+
; CHECK-NODOT-NEXT: saddw v0.2d, v0.2d, v2.2s
522+
; CHECK-NODOT-NEXT: add v1.2d, v4.2d, v1.2d
523+
; CHECK-NODOT-NEXT: add v0.2d, v3.2d, v0.2d
524+
; CHECK-NODOT-NEXT: ret
525+
%a.wide = sext <16 x i8> %a to <16 x i64>
526+
%partial.reduce = tail call <4 x i64> @llvm.experimental.vector.partial.reduce.add.v4i64.v16i64(<4 x i64> %acc, <16 x i64> %a.wide)
527+
ret <4 x i64> %partial.reduce
528+
}
529+
370530
define <4 x i32> @not_udot(<4 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
371531
; CHECK-LABEL: not_udot:
372532
; CHECK: // %bb.0:

llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,84 @@ entry:
316316
ret <vscale x 4 x i64> %partial.reduce
317317
}
318318

319+
define <vscale x 4 x i32> @udot_no_bin_op(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a){
320+
; CHECK-LABEL: udot_no_bin_op:
321+
; CHECK: // %bb.0:
322+
; CHECK-NEXT: mov z2.b, #1 // =0x1
323+
; CHECK-NEXT: udot z0.s, z1.b, z2.b
324+
; CHECK-NEXT: ret
325+
%a.ext = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
326+
%partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %a.ext)
327+
ret <vscale x 4 x i32> %partial.reduce
328+
}
329+
330+
define <vscale x 4 x i32> @sdot_no_bin_op(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a){
331+
; CHECK-LABEL: sdot_no_bin_op:
332+
; CHECK: // %bb.0:
333+
; CHECK-NEXT: mov z2.b, #1 // =0x1
334+
; CHECK-NEXT: sdot z0.s, z1.b, z2.b
335+
; CHECK-NEXT: ret
336+
%a.ext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
337+
%partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %a.ext)
338+
ret <vscale x 4 x i32> %partial.reduce
339+
}
340+
341+
define <vscale x 2 x i64> @udot_no_bin_op_wide(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b){
342+
; CHECK-LABEL: udot_no_bin_op_wide:
343+
; CHECK: // %bb.0: // %entry
344+
; CHECK-NEXT: mov z2.h, #1 // =0x1
345+
; CHECK-NEXT: udot z0.d, z1.h, z2.h
346+
; CHECK-NEXT: ret
347+
entry:
348+
%a.wide = zext <vscale x 8 x i16> %a to <vscale x 8 x i64>
349+
%partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> %acc, <vscale x 8 x i64> %a.wide)
350+
ret <vscale x 2 x i64> %partial.reduce
351+
}
352+
353+
define <vscale x 2 x i64> @sdot_no_bin_op_wide(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b){
354+
; CHECK-LABEL: sdot_no_bin_op_wide:
355+
; CHECK: // %bb.0: // %entry
356+
; CHECK-NEXT: mov z2.h, #1 // =0x1
357+
; CHECK-NEXT: sdot z0.d, z1.h, z2.h
358+
; CHECK-NEXT: ret
359+
entry:
360+
%a.wide = sext <vscale x 8 x i16> %a to <vscale x 8 x i64>
361+
%partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> %acc, <vscale x 8 x i64> %a.wide)
362+
ret <vscale x 2 x i64> %partial.reduce
363+
}
364+
365+
define <vscale x 4 x i64> @udot_no_bin_op_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i8> %a){
366+
; CHECK-LABEL: udot_no_bin_op_8to64:
367+
; CHECK: // %bb.0:
368+
; CHECK-NEXT: mov z3.b, #1 // =0x1
369+
; CHECK-NEXT: mov z4.s, #0 // =0x0
370+
; CHECK-NEXT: udot z4.s, z2.b, z3.b
371+
; CHECK-NEXT: sunpklo z2.d, z4.s
372+
; CHECK-NEXT: sunpkhi z3.d, z4.s
373+
; CHECK-NEXT: add z0.d, z0.d, z2.d
374+
; CHECK-NEXT: add z1.d, z1.d, z3.d
375+
; CHECK-NEXT: ret
376+
%a.ext = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
377+
%partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv16i64(<vscale x 4 x i64> %acc, <vscale x 16 x i64> %a.ext)
378+
ret <vscale x 4 x i64> %partial.reduce
379+
}
380+
381+
define <vscale x 4 x i64> @sdot_no_bin_op_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i8> %a){
382+
; CHECK-LABEL: sdot_no_bin_op_8to64:
383+
; CHECK: // %bb.0:
384+
; CHECK-NEXT: mov z3.b, #1 // =0x1
385+
; CHECK-NEXT: mov z4.s, #0 // =0x0
386+
; CHECK-NEXT: sdot z4.s, z2.b, z3.b
387+
; CHECK-NEXT: sunpklo z2.d, z4.s
388+
; CHECK-NEXT: sunpkhi z3.d, z4.s
389+
; CHECK-NEXT: add z0.d, z0.d, z2.d
390+
; CHECK-NEXT: add z1.d, z1.d, z3.d
391+
; CHECK-NEXT: ret
392+
%a.ext = sext <vscale x 16 x i8> %a to <vscale x 16 x i64>
393+
%partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv16i64(<vscale x 4 x i64> %acc, <vscale x 16 x i64> %a.ext)
394+
ret <vscale x 4 x i64> %partial.reduce
395+
}
396+
319397
define <vscale x 4 x i32> @not_udot(<vscale x 4 x i32> %acc, <vscale x 8 x i8> %a, <vscale x 8 x i8> %b) {
320398
; CHECK-LABEL: not_udot:
321399
; CHECK: // %bb.0: // %entry

0 commit comments

Comments
 (0)