
Commit 1938819

skylake dgemm: Add a 16x8 kernel
The next step for the AVX512 dgemm code is adding a 16x8 kernel. In the 8x8 kernel, each FMA has a matching load (the broadcast); in the 16x8 kernel we can reuse that load for two FMAs, which in turn reduces pressure on the CPU's load ports and gives a nice performance boost (in the 25% range).
1 parent a980953 commit 1938819
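
The idea, sketched with AVX512 intrinsics rather than the inline asm actually committed (the helper and variable names below are hypothetical, for illustration only): the two 8-wide halves of the A column are loaded once per k step, and each broadcast element of B is then shared by two FMAs.

#include <immintrin.h>  /* AVX512F intrinsics */

/* Hypothetical sketch, not the committed kernel: one broadcast of a B
 * element (b_kj) is reused by two FMAs, one per 8-wide half of the
 * 16-row A column, so the load:FMA ratio drops versus the 8x8 kernel. */
static inline void fma_step_16x1(__m512d a_lo, __m512d a_hi, double b_kj,
                                 __m512d *acc_lo, __m512d *acc_hi)
{
	__m512d b = _mm512_set1_pd(b_kj);            /* one broadcast load ...  */
	*acc_lo = _mm512_fmadd_pd(b, a_lo, *acc_lo); /* ... reused by this FMA  */
	*acc_hi = _mm512_fmadd_pd(b, a_hi, *acc_hi); /* ... and by this one too */
}

In the committed asm the same pattern shows up as one vbroadcastsd followed by two vfmadd231pd per column of B.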

1 file changed
kernel/x86_64/dgemm_kernel_4x8_skylakex.c

Lines changed: 154 additions & 1 deletion
@@ -849,18 +849,171 @@ CNAME(BLASLONG m, BLASLONG n, BLASLONG k, double alpha, double * __restrict__ A,
 
 	i = m;
 
-	while (i >= 8) {
+	while (i >= 16) {
 		double *BO;
+		double *A1;
 		int kloop = K;
 
 		BO = B + 12;
+		A1 = AO + 8 * K;
 		/*
 		 * This is the inner loop for the hot hot path
 		 * Written in inline asm because compilers like GCC 8 and earlier
 		 * struggle with register allocation and are not good at using
 		 * the AVX512 built in broadcast ability (1to8)
 		 */
 		asm(
+			"vxorpd %%zmm1, %%zmm1, %%zmm1\n"
+			"vmovapd %%zmm1, %%zmm2\n"
+			"vmovapd %%zmm1, %%zmm3\n"
+			"vmovapd %%zmm1, %%zmm4\n"
+			"vmovapd %%zmm1, %%zmm5\n"
+			"vmovapd %%zmm1, %%zmm6\n"
+			"vmovapd %%zmm1, %%zmm7\n"
+			"vmovapd %%zmm1, %%zmm8\n"
+			"vmovapd %%zmm1, %%zmm11\n"
+			"vmovapd %%zmm1, %%zmm12\n"
+			"vmovapd %%zmm1, %%zmm13\n"
+			"vmovapd %%zmm1, %%zmm14\n"
+			"vmovapd %%zmm1, %%zmm15\n"
+			"vmovapd %%zmm1, %%zmm16\n"
+			"vmovapd %%zmm1, %%zmm17\n"
+			"vmovapd %%zmm1, %%zmm18\n"
+			"jmp .label16\n"
+			".align 32\n"
+			/* Inner math loop */
+			".label16:\n"
+			"vmovupd -128(%[AO]),%%zmm0\n"
+			"vmovupd -128(%[A1]),%%zmm10\n"
+
+			"vbroadcastsd -96(%[BO]), %%zmm9\n"
+			"vfmadd231pd %%zmm9, %%zmm0, %%zmm1\n"
+			"vfmadd231pd %%zmm9, %%zmm10, %%zmm11\n"
+
+			"vbroadcastsd -88(%[BO]), %%zmm9\n"
+			"vfmadd231pd %%zmm9, %%zmm0, %%zmm2\n"
+			"vfmadd231pd %%zmm9, %%zmm10, %%zmm12\n"
+
+			"vbroadcastsd -80(%[BO]), %%zmm9\n"
+			"vfmadd231pd %%zmm9, %%zmm0, %%zmm3\n"
+			"vfmadd231pd %%zmm9, %%zmm10, %%zmm13\n"
+
+			"vbroadcastsd -72(%[BO]), %%zmm9\n"
+			"vfmadd231pd %%zmm9, %%zmm0, %%zmm4\n"
+			"vfmadd231pd %%zmm9, %%zmm10, %%zmm14\n"
+
+			"vbroadcastsd -64(%[BO]), %%zmm9\n"
+			"vfmadd231pd %%zmm9, %%zmm0, %%zmm5\n"
+			"vfmadd231pd %%zmm9, %%zmm10, %%zmm15\n"
+
+			"vbroadcastsd -56(%[BO]), %%zmm9\n"
+			"vfmadd231pd %%zmm9, %%zmm0, %%zmm6\n"
+			"vfmadd231pd %%zmm9, %%zmm10, %%zmm16\n"
+
+			"vbroadcastsd -48(%[BO]), %%zmm9\n"
+			"vfmadd231pd %%zmm9, %%zmm0, %%zmm7\n"
+			"vfmadd231pd %%zmm9, %%zmm10, %%zmm17\n"
+
+			"vbroadcastsd -40(%[BO]), %%zmm9\n"
+			"vfmadd231pd %%zmm9, %%zmm0, %%zmm8\n"
+			"vfmadd231pd %%zmm9, %%zmm10, %%zmm18\n"
+			"add $64, %[AO]\n"
+			"add $64, %[A1]\n"
+			"add $64, %[BO]\n"
+			"prefetch 512(%[AO])\n"
+			"prefetch 512(%[A1])\n"
+			"prefetch 512(%[BO])\n"
+			"subl $1, %[kloop]\n"
+			"jg .label16\n"
+			/* multiply the result by alpha */
+			"vbroadcastsd (%[alpha]), %%zmm9\n"
+			"vmulpd %%zmm9, %%zmm1, %%zmm1\n"
+			"vmulpd %%zmm9, %%zmm2, %%zmm2\n"
+			"vmulpd %%zmm9, %%zmm3, %%zmm3\n"
+			"vmulpd %%zmm9, %%zmm4, %%zmm4\n"
+			"vmulpd %%zmm9, %%zmm5, %%zmm5\n"
+			"vmulpd %%zmm9, %%zmm6, %%zmm6\n"
+			"vmulpd %%zmm9, %%zmm7, %%zmm7\n"
+			"vmulpd %%zmm9, %%zmm8, %%zmm8\n"
+			"vmulpd %%zmm9, %%zmm11, %%zmm11\n"
+			"vmulpd %%zmm9, %%zmm12, %%zmm12\n"
+			"vmulpd %%zmm9, %%zmm13, %%zmm13\n"
+			"vmulpd %%zmm9, %%zmm14, %%zmm14\n"
+			"vmulpd %%zmm9, %%zmm15, %%zmm15\n"
+			"vmulpd %%zmm9, %%zmm16, %%zmm16\n"
+			"vmulpd %%zmm9, %%zmm17, %%zmm17\n"
+			"vmulpd %%zmm9, %%zmm18, %%zmm18\n"
+			/* And store additively in C */
+			"vaddpd (%[C0]), %%zmm1, %%zmm1\n"
+			"vaddpd (%[C1]), %%zmm2, %%zmm2\n"
+			"vaddpd (%[C2]), %%zmm3, %%zmm3\n"
+			"vaddpd (%[C3]), %%zmm4, %%zmm4\n"
+			"vaddpd (%[C4]), %%zmm5, %%zmm5\n"
+			"vaddpd (%[C5]), %%zmm6, %%zmm6\n"
+			"vaddpd (%[C6]), %%zmm7, %%zmm7\n"
+			"vaddpd (%[C7]), %%zmm8, %%zmm8\n"
+			"vmovupd %%zmm1, (%[C0])\n"
+			"vmovupd %%zmm2, (%[C1])\n"
+			"vmovupd %%zmm3, (%[C2])\n"
+			"vmovupd %%zmm4, (%[C3])\n"
+			"vmovupd %%zmm5, (%[C4])\n"
+			"vmovupd %%zmm6, (%[C5])\n"
+			"vmovupd %%zmm7, (%[C6])\n"
+			"vmovupd %%zmm8, (%[C7])\n"
+
+			"vaddpd 64(%[C0]), %%zmm11, %%zmm11\n"
+			"vaddpd 64(%[C1]), %%zmm12, %%zmm12\n"
+			"vaddpd 64(%[C2]), %%zmm13, %%zmm13\n"
+			"vaddpd 64(%[C3]), %%zmm14, %%zmm14\n"
+			"vaddpd 64(%[C4]), %%zmm15, %%zmm15\n"
+			"vaddpd 64(%[C5]), %%zmm16, %%zmm16\n"
+			"vaddpd 64(%[C6]), %%zmm17, %%zmm17\n"
+			"vaddpd 64(%[C7]), %%zmm18, %%zmm18\n"
+			"vmovupd %%zmm11, 64(%[C0])\n"
+			"vmovupd %%zmm12, 64(%[C1])\n"
+			"vmovupd %%zmm13, 64(%[C2])\n"
+			"vmovupd %%zmm14, 64(%[C3])\n"
+			"vmovupd %%zmm15, 64(%[C4])\n"
+			"vmovupd %%zmm16, 64(%[C5])\n"
+			"vmovupd %%zmm17, 64(%[C6])\n"
+			"vmovupd %%zmm18, 64(%[C7])\n"
+
+			:
+			[AO]	"+r" (AO),
+			[A1]	"+r" (A1),
+			[BO]	"+r" (BO),
+			[C0]	"+r" (CO1),
+			[kloop]	"+r" (kloop)
+			:
+			[alpha]	"r" (&alpha),
+			[C1]	"r" (CO1 + 1 * ldc),
+			[C2]	"r" (CO1 + 2 * ldc),
+			[C3]	"r" (CO1 + 3 * ldc),
+			[C4]	"r" (CO1 + 4 * ldc),
+			[C5]	"r" (CO1 + 5 * ldc),
+			[C6]	"r" (CO1 + 6 * ldc),
+			[C7]	"r" (CO1 + 7 * ldc)
+
+			: "memory", "zmm0", "zmm1", "zmm2", "zmm3", "zmm4", "zmm5", "zmm6", "zmm7", "zmm8", "zmm9",
+			  "zmm10", "zmm11", "zmm12", "zmm13", "zmm14", "zmm15", "zmm16", "zmm17", "zmm18"
+		);
+		CO1 += 16;
+		AO += 8 * K;
+		i -= 16;
+	}
+
+	while (i >= 8) {
+		double *BO;
+		int kloop = K;
+
+		BO = B + 12;
+		/*
+		 * This is the inner loop for the hot hot path
+		 * Written in inline asm because compilers like GCC 8 and earlier
+		 * struggle with register allocation and are not good at using
+		 * the AVX512 built in broadcast ability (1to8)
+		 */
+		asm(
 			"vxorpd %%zmm1, %%zmm1, %%zmm1\n"
 			"vmovapd %%zmm1, %%zmm2\n"
 			"vmovapd %%zmm1, %%zmm3\n"