Skip to content

Commit 75eeb26

Browse files
authored
[WIP] Refactor the driver code for direct SGEMM (#2782)
Move "direct SGEMM" functionality out of the SkylakeX SGEMM kernel and make it available (on x86_64 targets only for now) in DYNAMIC_ARCH builds * Add sgemm_direct targets in the kernel Makefile.L3 and CMakeLists.txt * Add sgemm_direct functions to the gotoblas struct in common_param.h * Move sgemm_direct_performant helper to a separate file * Update gemm.c to use macros for sgemm_direct to support DYNAMIC_ARCH naming via common_s.h * (Conditionally) add sgemm_direct functions in setparam-ref.c
1 parent 2c72972 commit 75eeb26

File tree

10 files changed

+107
-10
lines changed

10 files changed

+107
-10
lines changed

common_level3.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,12 @@ __global__ void cuda_dgemm_kernel(int, int, int, double *, double *, double *);
4747
extern "C" {
4848
#endif
4949

50-
extern void sgemm_kernel_direct(BLASLONG M, BLASLONG N, BLASLONG K,
50+
void sgemm_direct(BLASLONG M, BLASLONG N, BLASLONG K,
5151
float * A, BLASLONG strideA,
5252
float * B, BLASLONG strideB,
5353
float * R, BLASLONG strideR);
5454

55-
extern int sgemm_kernel_direct_performant(BLASLONG M, BLASLONG N, BLASLONG K);
55+
int sgemm_direct_performant(BLASLONG M, BLASLONG N, BLASLONG K);
5656

5757

5858
int shgemm_beta(BLASLONG, BLASLONG, BLASLONG, float,

common_param.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,11 @@ BLASLONG (*ismin_k) (BLASLONG, float *, BLASLONG);
175175
int (*ssymv_L) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG, float *);
176176
int (*ssymv_U) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG, float *);
177177

178+
#ifdef ARCH_X86_64
179+
void (*sgemm_direct) (BLASLONG, BLASLONG, BLASLONG, float *, BLASLONG , float *, BLASLONG , float * , BLASLONG);
180+
int (*sgemm_direct_performant) (BLASLONG M, BLASLONG N, BLASLONG K);
181+
#endif
182+
178183
int (*sgemm_kernel )(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG);
179184
int (*sgemm_beta )(BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG);
180185

common_s.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,10 @@
4545
#define SSYMV_THREAD_U ssymv_thread_U
4646
#define SSYMV_THREAD_L ssymv_thread_L
4747

48+
49+
#define SGEMM_DIRECT_PERFORMANT sgemm_direct_performant
50+
#define SGEMM_DIRECT sgemm_direct
51+
4852
#define SGEMM_ONCOPY sgemm_oncopy
4953
#define SGEMM_OTCOPY sgemm_otcopy
5054

@@ -204,6 +208,14 @@
204208
#define SSYMV_THREAD_U ssymv_thread_U
205209
#define SSYMV_THREAD_L ssymv_thread_L
206210

211+
#ifdef ARCH_X86_64
212+
#define SGEMM_DIRECT_PERFORMANT gotoblas -> sgemm_direct_performant
213+
#define SGEMM_DIRECT gotoblas -> sgemm_direct
214+
#else
215+
#define SGEMM_DIRECT_PERFORMANT sgemm_direct_performant
216+
#define SGEMM_DIRECT sgemm_direct
217+
#endif
218+
207219
#define SGEMM_ONCOPY gotoblas -> sgemm_oncopy
208220
#define SGEMM_OTCOPY gotoblas -> sgemm_otcopy
209221
#define SGEMM_INCOPY gotoblas -> sgemm_incopy

interface/gemm.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -275,8 +275,8 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS
275275
#ifdef DYNAMIC_ARCH
276276
if (support_avx512() )
277277
#endif
278-
if (beta == 0 && alpha == 1.0 && order == CblasRowMajor && TransA == CblasNoTrans && TransB == CblasNoTrans && sgemm_kernel_direct_performant(m,n,k)) {
279-
sgemm_kernel_direct(m, n, k, a, lda, b, ldb, c, ldc);
278+
if (beta == 0 && alpha == 1.0 && order == CblasRowMajor && TransA == CblasNoTrans && TransB == CblasNoTrans && SGEMM_DIRECT_PERFORMANT(m,n,k)) {
279+
SGEMM_DIRECT(m, n, k, a, lda, b, ldb, c, ldc);
280280
return;
281281
}
282282

kernel/CMakeLists.txt

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,20 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS)
134134
set(USE_TRMM true)
135135
endif ()
136136

137+
set(USE_DIRECT_SGEMM false)
138+
if (X86_64)
139+
set(USE_DIRECT_SGEMM true)
140+
endif()
141+
142+
if (USE_DIRECT_SGEMM)
143+
# if (NOT DEFINED SGEMMDIRECTKERNEL)
144+
set (SGEMMDIRECTKERNEL sgemm_direct_skylakex.c)
145+
set (SGEMMDIRECTPERFORMANT sgemm_direct_performant.c)
146+
# endif()
147+
GenerateNamedObjects("${KERNELDIR}/${SGEMMDIRECTKERNEL}" "" "gemm_direct" false "" "" false SINGLE)
148+
GenerateNamedObjects("${KERNELDIR}/${SGEMMDIRECTPERFORMANT}" "" "gemm_direct_performant" false "" "" false SINGLE)
149+
endif()
150+
137151
foreach (float_type SINGLE DOUBLE HALF)
138152
string(SUBSTRING ${float_type} 0 1 float_char)
139153
if (${float_type} STREQUAL "HALF")

kernel/Makefile.L3

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@ ifeq ($(ARCH), x86_64)
99
USE_GEMM3M = 1
1010
endif
1111

12+
ifeq ($(ARCH), x86_64)
13+
USE_DIRECT_SGEMM = 1
14+
endif
15+
1216
ifeq ($(ARCH), ia64)
1317
USE_GEMM3M = 1
1418
endif
@@ -65,6 +69,13 @@ ifeq ($(CORE), Z14)
6569
USE_TRMM = 1
6670
endif
6771

72+
ifdef USE_DIRECT_SGEMM
73+
ifndef SGEMMDIRECTKERNEL
74+
SGEMMDIRECTKERNEL = sgemm_direct_skylakex.c
75+
SGEMMDIRECTPERFORMANT = sgemm_direct_performant.c
76+
endif
77+
endif
78+
6879
ifeq ($(BUILD_HALF), 1)
6980
ifndef SHGEMMKERNEL
7081
SHGEMM_BETA = ../generic/gemm_beta.c
@@ -90,6 +101,12 @@ SKERNELOBJS += \
90101
$(SGEMMINCOPYOBJ) $(SGEMMITCOPYOBJ) \
91102
$(SGEMMONCOPYOBJ) $(SGEMMOTCOPYOBJ)
92103

104+
ifdef USE_DIRECT_SGEMM
105+
SKERNELOBJS += \
106+
sgemm_direct$(TSUFFIX).$(SUFFIX) \
107+
sgemm_direct_performant$(TSUFFIX).$(SUFFIX)
108+
endif
109+
93110
DKERNELOBJS += \
94111
dgemm_kernel$(TSUFFIX).$(SUFFIX) \
95112
$(DGEMMINCOPYOBJ) $(DGEMMITCOPYOBJ) \
@@ -668,6 +685,13 @@ else
668685
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@
669686
endif
670687

688+
ifdef USE_DIRECT_SGEMM
689+
$(KDIR)sgemm_direct_performant$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMMDIRECTPERFORMANT)
690+
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@
691+
$(KDIR)sgemm_direct$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMMDIRECTKERNEL)
692+
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@
693+
endif
694+
671695
ifeq ($(BUILD_HALF), 1)
672696

673697
$(KDIR)shgemm_kernel$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SHGEMMKERNEL) $(SHGEMMDEPEND)

kernel/setparam-ref.c

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,11 @@ gotoblas_t TABLE_NAME = {
135135
sgemv_nTS, sgemv_tTS, sger_kTS,
136136
ssymv_LTS, ssymv_UTS,
137137

138+
#ifdef ARCH_X86_64
139+
sgemm_directTS,
140+
sgemm_direct_performantTS,
141+
#endif
142+
138143
sgemm_kernelTS, sgemm_betaTS,
139144
#if SGEMM_DEFAULT_UNROLL_M != SGEMM_DEFAULT_UNROLL_N
140145
sgemm_incopyTS, sgemm_itcopyTS,
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
#include "common.h"
2+
/* helper for the direct sgemm code written by Arjan van der Ven */
3+
4+
5+
6+
7+
int CNAME(BLASLONG M, BLASLONG N, BLASLONG K)
8+
{
9+
unsigned long long mnk = M * N * K;
10+
/* large matrices -> not performant */
11+
if (mnk >= 28 * 512 * 512)
12+
return 0;
13+
14+
/*
15+
* if the B matrix is not a nice multiple of 4 we get many unaligned accesses,
16+
* and the regular sgemm copy/realignment of data pays off much quicker
17+
*/
18+
if ((N & 3) != 0 && (mnk >= 8 * 512 * 512))
19+
return 0;
20+
21+
#ifdef SMP
22+
/* if we can run multithreaded, the threading changes the base threshold */
23+
if (mnk > 2 * 350 * 512 && num_cpu_avail(3)> 1)
24+
return 0;
25+
#endif
26+
27+
return 1;
28+
}
29+
30+

kernel/x86_64/sgemm_direct_skylakex.c

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
1+
#if defined(SKYLAKEX) || defined (COOPERLAKE)
22
/* the direct sgemm code written by Arjan van der Ven */
3-
//#include <immintrin.h>
4-
3+
#include <immintrin.h>
4+
#include "common.h"
55
/*
66
* "Direct sgemm" code. This code operates directly on the inputs and outputs
77
* of the sgemm call, avoiding the copies, memory realignments and threading,
@@ -38,6 +38,7 @@
3838
#define MATMUL_SCALAR(N,M) result##N##M += Aval##M * Bval##N;
3939
#define STORE_SCALAR(N,M) R[(i+M) * strideR + j + N] = result##N##M;
4040

41+
#if 0
4142
int sgemm_kernel_direct_performant(BLASLONG M, BLASLONG N, BLASLONG K)
4243
{
4344
unsigned long long mnk = M * N * K;
@@ -61,9 +62,10 @@ int sgemm_kernel_direct_performant(BLASLONG M, BLASLONG N, BLASLONG K)
6162
return 1;
6263
}
6364

65+
#endif
6466

65-
66-
void sgemm_kernel_direct (BLASLONG M, BLASLONG N, BLASLONG K, float * __restrict A, BLASLONG strideA, float * __restrict B, BLASLONG strideB , float * __restrict R, BLASLONG strideR)
67+
//void sgemm_kernel_direct (BLASLONG M, BLASLONG N, BLASLONG K, float * __restrict A, BLASLONG strideA, float * __restrict B, BLASLONG strideB , float * __restrict R, BLASLONG strideR)
68+
void CNAME (BLASLONG M, BLASLONG N, BLASLONG K, float * __restrict A, BLASLONG strideA, float * __restrict B, BLASLONG strideB , float * __restrict R, BLASLONG strideR)
6769
{
6870
int i, j, k;
6971

@@ -465,3 +467,8 @@ void sgemm_kernel_direct (BLASLONG M, BLASLONG N, BLASLONG K, float * __restrict
465467
}
466468
}
467469
}
470+
#else
471+
#include "common.h"
472+
void CNAME (BLASLONG M, BLASLONG N, BLASLONG K, float * __restrict A, BLASLONG strideA, float * __restrict B, BLASLONG strideB , float * __restrict R, BLASLONG strideR)
473+
{}
474+
#endif

kernel/x86_64/sgemm_kernel_16x4_skylakex_3.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -512,4 +512,4 @@ CNAME(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, float * __restrict__ A, f
512512
return 0;
513513
}
514514
#include <immintrin.h>
515-
#include "sgemm_direct_skylakex.c"
515+
//#include "sgemm_direct_skylakex.c"

0 commit comments

Comments
 (0)