Commit 66b43af

Add a 24x8 kernel to the skylakex dgemm implementation
Minor gains for small matrices, but at 512x512 and above the gain becomes more significant.
1 parent 1938819 commit 66b43af
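
For context on the performance claim in the commit message, below is a minimal benchmark sketch; it is not part of this commit. It assumes an OpenBLAS build that exposes the CBLAS interface (link with -lopenblas) and simply times repeated 512x512 dgemm calls, the size at which the message says the new kernel starts to pay off noticeably.

/* Hypothetical benchmark sketch (not from this commit): times square dgemm
 * at n = 512. Build e.g.: cc -O2 bench_dgemm.c -lopenblas -o bench_dgemm */
#define _POSIX_C_SOURCE 199309L
#include <cblas.h>
#include <stdio.h>
#include <stdlib.h>
#include <time.h>

int main(void)
{
    const int n = 512;                      /* size mentioned in the commit message */
    const int reps = 10;
    double *A = malloc(sizeof(double) * n * n);
    double *B = malloc(sizeof(double) * n * n);
    double *C = calloc((size_t)n * n, sizeof(double));
    if (!A || !B || !C) return 1;

    for (int i = 0; i < n * n; i++) {       /* arbitrary test data */
        A[i] = (double)(i % 100) / 100.0;
        B[i] = (double)(i % 50) / 50.0;
    }

    struct timespec t0, t1;
    clock_gettime(CLOCK_MONOTONIC, &t0);
    for (int r = 0; r < reps; r++)
        cblas_dgemm(CblasColMajor, CblasNoTrans, CblasNoTrans,
                    n, n, n, 1.0, A, n, B, n, 1.0, C, n);
    clock_gettime(CLOCK_MONOTONIC, &t1);

    double secs = (t1.tv_sec - t0.tv_sec) + (t1.tv_nsec - t0.tv_nsec) * 1e-9;
    /* 2*n^3 floating-point operations per dgemm call */
    printf("%.2f GFLOP/s\n", reps * 2.0 * n * n * n / secs / 1e9);

    free(A); free(B); free(C);
    return 0;
}

Running such a benchmark before and after this commit on an AVX-512 (Skylake-X) machine is one way to observe the reported difference.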

1 file changed: 201 additions, 0 deletions


kernel/x86_64/dgemm_kernel_4x8_skylakex.c

Lines changed: 201 additions & 0 deletions
@@ -849,6 +849,207 @@ CNAME(BLASLONG m, BLASLONG n, BLASLONG k, double alpha, double * __restrict__ A,
 
   i = m;
 
+  while (i >= 24) {
+    double *BO;
+    double *A1, *A2;
+    int kloop = K;
+
+    BO = B + 12;
+    A1 = AO + 8 * K;
+    A2 = AO + 16 * K;
+    /*
+     * This is the inner loop for the hot hot path
+     * Written in inline asm because compilers like GCC 8 and earlier
+     * struggle with register allocation and are not good at using
+     * the AVX512 built in broadcast ability (1to8)
+     */
+    asm(
+      "vxorpd %%zmm1, %%zmm1, %%zmm1\n"
+      "vmovapd %%zmm1, %%zmm2\n"
+      "vmovapd %%zmm1, %%zmm3\n"
+      "vmovapd %%zmm1, %%zmm4\n"
+      "vmovapd %%zmm1, %%zmm5\n"
+      "vmovapd %%zmm1, %%zmm6\n"
+      "vmovapd %%zmm1, %%zmm7\n"
+      "vmovapd %%zmm1, %%zmm8\n"
+      "vmovapd %%zmm1, %%zmm11\n"
+      "vmovapd %%zmm1, %%zmm12\n"
+      "vmovapd %%zmm1, %%zmm13\n"
+      "vmovapd %%zmm1, %%zmm14\n"
+      "vmovapd %%zmm1, %%zmm15\n"
+      "vmovapd %%zmm1, %%zmm16\n"
+      "vmovapd %%zmm1, %%zmm17\n"
+      "vmovapd %%zmm1, %%zmm18\n"
+      "vmovapd %%zmm1, %%zmm21\n"
+      "vmovapd %%zmm1, %%zmm22\n"
+      "vmovapd %%zmm1, %%zmm23\n"
+      "vmovapd %%zmm1, %%zmm24\n"
+      "vmovapd %%zmm1, %%zmm25\n"
+      "vmovapd %%zmm1, %%zmm26\n"
+      "vmovapd %%zmm1, %%zmm27\n"
+      "vmovapd %%zmm1, %%zmm28\n"
+      "jmp .label24\n"
+      ".align 32\n"
+      /* Inner math loop */
+      ".label24:\n"
+      "vmovupd -128(%[AO]),%%zmm0\n"
+      "vmovupd -128(%[A1]),%%zmm10\n"
+      "vmovupd -128(%[A2]),%%zmm20\n"
+
+      "vbroadcastsd -96(%[BO]), %%zmm9\n"
+      "vfmadd231pd %%zmm9, %%zmm0, %%zmm1\n"
+      "vfmadd231pd %%zmm9, %%zmm10, %%zmm11\n"
+      "vfmadd231pd %%zmm9, %%zmm20, %%zmm21\n"
+
+      "vbroadcastsd -88(%[BO]), %%zmm9\n"
+      "vfmadd231pd %%zmm9, %%zmm0, %%zmm2\n"
+      "vfmadd231pd %%zmm9, %%zmm10, %%zmm12\n"
+      "vfmadd231pd %%zmm9, %%zmm20, %%zmm22\n"
+
+      "vbroadcastsd -80(%[BO]), %%zmm9\n"
+      "vfmadd231pd %%zmm9, %%zmm0, %%zmm3\n"
+      "vfmadd231pd %%zmm9, %%zmm10, %%zmm13\n"
+      "vfmadd231pd %%zmm9, %%zmm20, %%zmm23\n"
+
+      "vbroadcastsd -72(%[BO]), %%zmm9\n"
+      "vfmadd231pd %%zmm9, %%zmm0, %%zmm4\n"
+      "vfmadd231pd %%zmm9, %%zmm10, %%zmm14\n"
+      "vfmadd231pd %%zmm9, %%zmm20, %%zmm24\n"
+
+      "vbroadcastsd -64(%[BO]), %%zmm9\n"
+      "vfmadd231pd %%zmm9, %%zmm0, %%zmm5\n"
+      "vfmadd231pd %%zmm9, %%zmm10, %%zmm15\n"
+      "vfmadd231pd %%zmm9, %%zmm20, %%zmm25\n"
+
+      "vbroadcastsd -56(%[BO]), %%zmm9\n"
+      "vfmadd231pd %%zmm9, %%zmm0, %%zmm6\n"
+      "vfmadd231pd %%zmm9, %%zmm10, %%zmm16\n"
+      "vfmadd231pd %%zmm9, %%zmm20, %%zmm26\n"
+
+      "vbroadcastsd -48(%[BO]), %%zmm9\n"
+      "vfmadd231pd %%zmm9, %%zmm0, %%zmm7\n"
+      "vfmadd231pd %%zmm9, %%zmm10, %%zmm17\n"
+      "vfmadd231pd %%zmm9, %%zmm20, %%zmm27\n"
+
+      "vbroadcastsd -40(%[BO]), %%zmm9\n"
+      "vfmadd231pd %%zmm9, %%zmm0, %%zmm8\n"
+      "vfmadd231pd %%zmm9, %%zmm10, %%zmm18\n"
+      "vfmadd231pd %%zmm9, %%zmm20, %%zmm28\n"
+      "add $64, %[AO]\n"
+      "add $64, %[A1]\n"
+      "add $64, %[A2]\n"
+      "add $64, %[BO]\n"
+      "prefetch 512(%[AO])\n"
+      "prefetch 512(%[A1])\n"
+      "prefetch 512(%[A2])\n"
+      "prefetch 512(%[BO])\n"
+      "subl $1, %[kloop]\n"
+      "jg .label24\n"
+      /* multiply the result by alpha */
+      "vbroadcastsd (%[alpha]), %%zmm9\n"
+      "vmulpd %%zmm9, %%zmm1, %%zmm1\n"
+      "vmulpd %%zmm9, %%zmm2, %%zmm2\n"
+      "vmulpd %%zmm9, %%zmm3, %%zmm3\n"
+      "vmulpd %%zmm9, %%zmm4, %%zmm4\n"
+      "vmulpd %%zmm9, %%zmm5, %%zmm5\n"
+      "vmulpd %%zmm9, %%zmm6, %%zmm6\n"
+      "vmulpd %%zmm9, %%zmm7, %%zmm7\n"
+      "vmulpd %%zmm9, %%zmm8, %%zmm8\n"
+      "vmulpd %%zmm9, %%zmm11, %%zmm11\n"
+      "vmulpd %%zmm9, %%zmm12, %%zmm12\n"
+      "vmulpd %%zmm9, %%zmm13, %%zmm13\n"
+      "vmulpd %%zmm9, %%zmm14, %%zmm14\n"
+      "vmulpd %%zmm9, %%zmm15, %%zmm15\n"
+      "vmulpd %%zmm9, %%zmm16, %%zmm16\n"
+      "vmulpd %%zmm9, %%zmm17, %%zmm17\n"
+      "vmulpd %%zmm9, %%zmm18, %%zmm18\n"
+      "vmulpd %%zmm9, %%zmm21, %%zmm21\n"
+      "vmulpd %%zmm9, %%zmm22, %%zmm22\n"
+      "vmulpd %%zmm9, %%zmm23, %%zmm23\n"
+      "vmulpd %%zmm9, %%zmm24, %%zmm24\n"
+      "vmulpd %%zmm9, %%zmm25, %%zmm25\n"
+      "vmulpd %%zmm9, %%zmm26, %%zmm26\n"
+      "vmulpd %%zmm9, %%zmm27, %%zmm27\n"
+      "vmulpd %%zmm9, %%zmm28, %%zmm28\n"
+      /* And store additively in C */
+      "vaddpd (%[C0]), %%zmm1, %%zmm1\n"
+      "vaddpd (%[C1]), %%zmm2, %%zmm2\n"
+      "vaddpd (%[C2]), %%zmm3, %%zmm3\n"
+      "vaddpd (%[C3]), %%zmm4, %%zmm4\n"
+      "vaddpd (%[C4]), %%zmm5, %%zmm5\n"
+      "vaddpd (%[C5]), %%zmm6, %%zmm6\n"
+      "vaddpd (%[C6]), %%zmm7, %%zmm7\n"
+      "vaddpd (%[C7]), %%zmm8, %%zmm8\n"
+      "vmovupd %%zmm1, (%[C0])\n"
+      "vmovupd %%zmm2, (%[C1])\n"
+      "vmovupd %%zmm3, (%[C2])\n"
+      "vmovupd %%zmm4, (%[C3])\n"
+      "vmovupd %%zmm5, (%[C4])\n"
+      "vmovupd %%zmm6, (%[C5])\n"
+      "vmovupd %%zmm7, (%[C6])\n"
+      "vmovupd %%zmm8, (%[C7])\n"
+
+      "vaddpd 64(%[C0]), %%zmm11, %%zmm11\n"
+      "vaddpd 64(%[C1]), %%zmm12, %%zmm12\n"
+      "vaddpd 64(%[C2]), %%zmm13, %%zmm13\n"
+      "vaddpd 64(%[C3]), %%zmm14, %%zmm14\n"
+      "vaddpd 64(%[C4]), %%zmm15, %%zmm15\n"
+      "vaddpd 64(%[C5]), %%zmm16, %%zmm16\n"
+      "vaddpd 64(%[C6]), %%zmm17, %%zmm17\n"
+      "vaddpd 64(%[C7]), %%zmm18, %%zmm18\n"
+      "vmovupd %%zmm11, 64(%[C0])\n"
+      "vmovupd %%zmm12, 64(%[C1])\n"
+      "vmovupd %%zmm13, 64(%[C2])\n"
+      "vmovupd %%zmm14, 64(%[C3])\n"
+      "vmovupd %%zmm15, 64(%[C4])\n"
+      "vmovupd %%zmm16, 64(%[C5])\n"
+      "vmovupd %%zmm17, 64(%[C6])\n"
+      "vmovupd %%zmm18, 64(%[C7])\n"
+
+      "vaddpd 128(%[C0]), %%zmm21, %%zmm21\n"
+      "vaddpd 128(%[C1]), %%zmm22, %%zmm22\n"
+      "vaddpd 128(%[C2]), %%zmm23, %%zmm23\n"
+      "vaddpd 128(%[C3]), %%zmm24, %%zmm24\n"
+      "vaddpd 128(%[C4]), %%zmm25, %%zmm25\n"
+      "vaddpd 128(%[C5]), %%zmm26, %%zmm26\n"
+      "vaddpd 128(%[C6]), %%zmm27, %%zmm27\n"
+      "vaddpd 128(%[C7]), %%zmm28, %%zmm28\n"
+      "vmovupd %%zmm21, 128(%[C0])\n"
+      "vmovupd %%zmm22, 128(%[C1])\n"
+      "vmovupd %%zmm23, 128(%[C2])\n"
+      "vmovupd %%zmm24, 128(%[C3])\n"
+      "vmovupd %%zmm25, 128(%[C4])\n"
+      "vmovupd %%zmm26, 128(%[C5])\n"
+      "vmovupd %%zmm27, 128(%[C6])\n"
+      "vmovupd %%zmm28, 128(%[C7])\n"
+
+      :
+        [AO] "+r" (AO),
+        [A1] "+r" (A1),
+        [A2] "+r" (A2),
+        [BO] "+r" (BO),
+        [C0] "+r" (CO1),
+        [kloop] "+r" (kloop)
+      :
+        [alpha] "r" (&alpha),
+        [C1] "r" (CO1 + 1 * ldc),
+        [C2] "r" (CO1 + 2 * ldc),
+        [C3] "r" (CO1 + 3 * ldc),
+        [C4] "r" (CO1 + 4 * ldc),
+        [C5] "r" (CO1 + 5 * ldc),
+        [C6] "r" (CO1 + 6 * ldc),
+        [C7] "r" (CO1 + 7 * ldc)
+
+      : "memory", "zmm0", "zmm1", "zmm2", "zmm3", "zmm4", "zmm5", "zmm6", "zmm7", "zmm8", "zmm9",
+        "zmm10", "zmm11", "zmm12", "zmm13", "zmm14", "zmm15", "zmm16", "zmm17", "zmm18",
+        "zmm20", "zmm21", "zmm22", "zmm23", "zmm24", "zmm25", "zmm26", "zmm27", "zmm28"
+    );
+    CO1 += 24;
+    AO += 16 * K;
+    i -= 24;
+  }
+
   while (i >= 16) {
     double *BO;
     double *A1;
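
The in-code comment above explains why the hot loop is written in inline assembly: GCC 8 and earlier struggle with register allocation here and do not make good use of the AVX512 embedded broadcast (1to8) form. As an illustrative aid only, not part of the commit, the sketch below expresses one k-step of the same 24x8 micro-kernel with C intrinsics: three 8-double slices of the A panel are each multiplied by eight broadcast B values and accumulated into 24 zmm accumulators. The function name dgemm_24x8_step and the parameter names a0/a1/a2, b, acc are hypothetical.

/* Illustrative sketch (not from this commit) of one k-step of the 24x8
 * micro-kernel; compile with -mavx512f. */
#include <immintrin.h>

static inline void dgemm_24x8_step(const double *a0, const double *a1,
                                   const double *a2, const double *b,
                                   __m512d acc[3][8])
{
    __m512d x0 = _mm512_loadu_pd(a0);   /* rows 0..7  of the 24-row A panel */
    __m512d x1 = _mm512_loadu_pd(a1);   /* rows 8..15 */
    __m512d x2 = _mm512_loadu_pd(a2);   /* rows 16..23 */

    for (int j = 0; j < 8; j++) {       /* one column of C per B element */
        __m512d bj = _mm512_set1_pd(b[j]);          /* scalar broadcast ("1to8") */
        acc[0][j] = _mm512_fmadd_pd(bj, x0, acc[0][j]);
        acc[1][j] = _mm512_fmadd_pd(bj, x1, acc[1][j]);
        acc[2][j] = _mm512_fmadd_pd(bj, x2, acc[2][j]);
    }
}

The committed asm additionally advances the A and B pointers, prefetches ahead, scales the accumulators by alpha, and adds them into C, as seen in the diff above.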
