Skip to content

Commit 78d877b

Browse files
authored
Merge pull request #1914 from fenrus75/smallmatrix
Add a "sgemm direct" mode for small matrixes
2 parents 8771880 + cdc668d commit 78d877b

File tree

4 files changed

+483
-1
lines changed

4 files changed

+483
-1
lines changed

common_level3.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,14 @@ __global__ void cuda_dgemm_kernel(int, int, int, double *, double *, double *);
4747
extern "C" {
4848
#endif
4949

50+
extern void sgemm_kernel_direct(BLASLONG M, BLASLONG N, BLASLONG K,
51+
float * A, BLASLONG strideA,
52+
float * B, BLASLONG strideB,
53+
float * R, BLASLONG strideR);
54+
55+
extern int sgemm_kernel_direct_performant(BLASLONG M, BLASLONG N, BLASLONG K);
56+
57+
5058
int sgemm_beta(BLASLONG, BLASLONG, BLASLONG, float,
5159
float *, BLASLONG, float *, BLASLONG, float *, BLASLONG);
5260
int dgemm_beta(BLASLONG, BLASLONG, BLASLONG, double,

interface/gemm.c

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,14 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS
271271

272272
PRINT_DEBUG_CNAME;
273273

274+
#if !defined(COMPLEX) && !defined(DOUBLE) && defined(USE_SGEMM_KERNEL_DIRECT)
275+
if (beta == 0 && alpha == 1.0 && order == CblasRowMajor && TransA == CblasNoTrans && TransB == CblasNoTrans && sgemm_kernel_direct_performant(m,n,k)) {
276+
sgemm_kernel_direct(m, n, k, a, lda, b, ldb, c, ldc);
277+
return;
278+
}
279+
280+
#endif
281+
274282
#ifndef COMPLEX
275283
args.alpha = (void *)&alpha;
276284
args.beta = (void *)&beta;

0 commit comments

Comments
 (0)