Skip to content

Commit eb931de

Browse files
authored
Add BLAS interface to ?GEMM_BATCH
1 parent 79a1f38 commit eb931de

File tree

2 files changed

+75
-10
lines changed

2 files changed

+75
-10
lines changed

interface/Makefile

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,15 +72,16 @@ SBLAS3OBJS = \
7272
sgemm.$(SUFFIX) ssymm.$(SUFFIX) strmm.$(SUFFIX) \
7373
strsm.$(SUFFIX) ssyrk.$(SUFFIX) ssyr2k.$(SUFFIX) \
7474
somatcopy.$(SUFFIX) simatcopy.$(SUFFIX)\
75-
sgeadd.$(SUFFIX) sgemmt.$(SUFFIX) sgemmtr.$(SUFFIX)
75+
sgeadd.$(SUFFIX) sgemmt.$(SUFFIX) sgemmtr.$(SUFFIX) \
76+
sgemm_batch.$(SUFFIX)
7677

7778
ifeq ($(BUILD_BFLOAT16),1)
7879
BBLAS3OBJS = bgemm.$(SUFFIX)
7980
BBLAS2OBJS = bgemv.$(SUFFIX)
8081
BBLAS1OBJS = bscal.$(SUFFIX)
8182
SBBLAS1OBJS = sbdot.$(SUFFIX)
8283
SBBLAS2OBJS = sbgemv.$(SUFFIX)
83-
SBBLAS3OBJS = sbgemm.$(SUFFIX) sbgemmt.$(SUFFIX) sbgemmtr.$(SUFFIX)
84+
SBBLAS3OBJS = sbgemm.$(SUFFIX) sbgemmt.$(SUFFIX) sbgemmtr.$(SUFFIX) sbgemm_batch.$(SUFFIX)
8485
SBEXTOBJS = sbstobf16.$(SUFFIX) sbdtobf16.$(SUFFIX) sbf16tos.$(SUFFIX) dbf16tod.$(SUFFIX)
8586
endif
8687

@@ -111,7 +112,8 @@ DBLAS3OBJS = \
111112
dgemm.$(SUFFIX) dsymm.$(SUFFIX) dtrmm.$(SUFFIX) \
112113
dtrsm.$(SUFFIX) dsyrk.$(SUFFIX) dsyr2k.$(SUFFIX) \
113114
domatcopy.$(SUFFIX) dimatcopy.$(SUFFIX)\
114-
dgeadd.$(SUFFIX) dgemmt.$(SUFFIX) dgemmtr.$(SUFFIX)
115+
dgeadd.$(SUFFIX) dgemmt.$(SUFFIX) dgemmtr.$(SUFFIX) \
116+
dgemm_batch.$(SUFFIX)
115117

116118
CBLAS1OBJS = \
117119
caxpy.$(SUFFIX) caxpyc.$(SUFFIX) cswap.$(SUFFIX) \
@@ -140,7 +142,8 @@ CBLAS3OBJS = \
140142
ctrsm.$(SUFFIX) csyrk.$(SUFFIX) csyr2k.$(SUFFIX) \
141143
chemm.$(SUFFIX) cherk.$(SUFFIX) cher2k.$(SUFFIX) \
142144
comatcopy.$(SUFFIX) cimatcopy.$(SUFFIX)\
143-
cgeadd.$(SUFFIX) cgemmt.$(SUFFIX) cgemmtr.$(SUFFIX)
145+
cgeadd.$(SUFFIX) cgemmt.$(SUFFIX) cgemmtr.$(SUFFIX) \
146+
cgemm_batch.$(SUFFIX)
144147

145148
ZBLAS1OBJS = \
146149
zaxpy.$(SUFFIX) zaxpyc.$(SUFFIX) zswap.$(SUFFIX) \
@@ -169,7 +172,8 @@ ZBLAS3OBJS = \
169172
ztrsm.$(SUFFIX) zsyrk.$(SUFFIX) zsyr2k.$(SUFFIX) \
170173
zhemm.$(SUFFIX) zherk.$(SUFFIX) zher2k.$(SUFFIX) \
171174
zomatcopy.$(SUFFIX) zimatcopy.$(SUFFIX)\
172-
zgeadd.$(SUFFIX) zgemmt.$(SUFFIX) zgemmtr.$(SUFFIX)
175+
zgeadd.$(SUFFIX) zgemmt.$(SUFFIX) zgemmtr.$(SUFFIX) \
176+
zgemm_batch.$(SUFFIX)
173177

174178
ifeq ($(SUPPORT_GEMM3M), 1)
175179

@@ -2539,3 +2543,19 @@ cblas_cgemm_batch.$(SUFFIX) cblas_cgemm_batch.$(PSUFFIX) : gemm_batch.c ../param
25392543

25402544
cblas_zgemm_batch.$(SUFFIX) cblas_zgemm_batch.$(PSUFFIX) : gemm_batch.c ../param.h
25412545
$(CC) -c $(CFLAGS) -DCBLAS $< -o $(@F)
2546+
2547+
# Non-CBLAS builds of gemm_batch.c: each rule below compiles the same source
# with -UCBLAS (so the "#ifndef CBLAS" entry point is selected) into an object
# named after its target ($(@F)): sbgemm_batch, sgemm_batch, dgemm_batch,
# cgemm_batch and zgemm_batch. The recipes are identical; only the target
# names differ.
# NOTE(review): no type-selecting -D flags appear in these recipes; presumably
# they arrive via $(CFLAGS) or pattern-specific settings elsewhere - confirm.
sbgemm_batch.$(SUFFIX) sbgemm_batch.$(PSUFFIX) : gemm_batch.c ../param.h
2548+
$(CC) -c $(CFLAGS) -UCBLAS $< -o $(@F)
2549+
2550+
sgemm_batch.$(SUFFIX) sgemm_batch.$(PSUFFIX) : gemm_batch.c ../param.h
2551+
$(CC) -c $(CFLAGS) -UCBLAS $< -o $(@F)
2552+
2553+
dgemm_batch.$(SUFFIX) dgemm_batch.$(PSUFFIX) : gemm_batch.c ../param.h
2554+
$(CC) -c $(CFLAGS) -UCBLAS $< -o $(@F)
2555+
2556+
cgemm_batch.$(SUFFIX) cgemm_batch.$(PSUFFIX) : gemm_batch.c ../param.h
2557+
$(CC) -c $(CFLAGS) -UCBLAS $< -o $(@F)
2558+
2559+
zgemm_batch.$(SUFFIX) zgemm_batch.$(PSUFFIX) : gemm_batch.c ../param.h
2560+
$(CC) -c $(CFLAGS) -UCBLAS $< -o $(@F)
2561+

interface/gemm_batch.c

Lines changed: 50 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,17 @@ static size_t zgemm_small_kernel_b0[] = {
114114
#endif
115115
#endif
116116

117+
#ifndef CBLAS
118+
void CNAME(char *transa_array, char *transb_array,
119+
blasint * m_array, blasint * n_array, blasint * k_array,
120+
FLOAT * alpha_array,
121+
IFLOAT ** a_array, blasint * lda_array,
122+
IFLOAT ** b_array, blasint * ldb_array,
123+
FLOAT * beta_array,
124+
FLOAT ** c_array, blasint * ldc_array, blasint group_count, blasint * group_size) {
125+
126+
#else
127+
117128
void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE * transa_array, enum CBLAS_TRANSPOSE * transb_array,
118129
blasint * m_array, blasint * n_array, blasint * k_array,
119130
#ifndef COMPLEX
@@ -134,8 +145,11 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE * transa_array, enum CB
134145
FLOAT ** a_array=(FLOAT**)va_array;
135146
FLOAT ** b_array=(FLOAT**)vb_array;
136147
FLOAT ** c_array=(FLOAT**)vc_array;
137-
138148
#endif
149+
#endif
150+
BLASLONG group_m, group_n, group_k;
151+
BLASLONG group_lda, group_ldb, group_ldc;
152+
139153
blas_arg_t * args_array=NULL;
140154

141155
int mode=0, group_mode=0;
@@ -148,8 +162,6 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE * transa_array, enum CB
148162
blasint info;
149163

150164
void * group_alpha, * group_beta;
151-
BLASLONG group_m, group_n, group_k;
152-
BLASLONG group_lda, group_ldb, group_ldc;
153165
void * group_routine=NULL;
154166
#ifdef SMALL_MATRIX_OPT
155167
void * group_small_matrix_opt_routine=NULL;
@@ -201,7 +213,8 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE * transa_array, enum CB
201213
group_transa = -1;
202214
group_transb = -1;
203215
info = 0;
204-
216+
217+
#if defined(CBLAS)
205218
if (order == CblasColMajor) {
206219
group_m = m_array[i];
207220
group_n = n_array[i];
@@ -254,7 +267,7 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE * transa_array, enum CB
254267
group_lda = ldb_array[i];
255268
group_ldb = lda_array[i];
256269
group_ldc = ldc_array[i];
257-
270+
258271
if (transb_array[i] == CblasNoTrans) group_transa = 0;
259272
if (transb_array[i] == CblasTrans) group_transa = 1;
260273
#ifndef COMPLEX
@@ -273,6 +286,32 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE * transa_array, enum CB
273286
if (transa_array[i] == CblasConjNoTrans) group_transb = 2;
274287
if (transa_array[i] == CblasConjTrans) group_transb = 3;
275288
#endif
289+
290+
#else
/* Non-CBLAS entry point: there is no `order` argument, so the per-group
 * dimensions and leading dimensions are copied from the caller's arrays
 * without any swapping. */
291+
group_m = m_array[i];
292+
group_n = n_array[i];
293+
group_k = k_array[i];
294+
295+
group_lda = lda_array[i];
296+
group_ldb = ldb_array[i];
297+
group_ldc = ldc_array[i];
298+
299+
/* Decode the transpose characters into the internal codes
 * (0 = 'N', 1 = 'T'; 'C' is 1 for real builds and 3 for COMPLEX builds).
 * Only the exact upper-case characters are matched; any other value leaves
 * the code at -1, which the later checks report as info = 1 / info = 2.
 *
 * NOTE(review): group_transa is derived from transb_array and group_transb
 * from transa_array, as in the row-major CBLAS branch. That branch also
 * swaps lda/ldb, whereas m/n and lda/ldb are NOT swapped here. Confirm this
 * cross-mapping is intended for this interface.
 * NOTE(review): unlike the CBLAS branch (CblasConjNoTrans -> 2), no
 * character maps to code 2 in COMPLEX builds, and lower-case 'n'/'t'/'c'
 * are rejected - confirm whether that is intended. */
if (transb_array[i] == 'N') group_transa = 0;
300+
if (transb_array[i] == 'T') group_transa = 1;
301+
#ifndef COMPLEX
302+
if (transb_array[i] == 'C') group_transa = 1;
303+
#else
304+
if (transb_array[i] == 'C') group_transa = 3;
305+
#endif
306+
if (transa_array[i] == 'N') group_transb = 0;
307+
if (transa_array[i] == 'T') group_transb = 1;
308+
#ifndef COMPLEX
309+
if (transa_array[i] == 'C') group_transb = 1;
310+
#else
311+
if (transa_array[i] == 'C') group_transb = 3;
312+
#endif
313+
#endif
314+
276315
group_nrowa = group_m;
277316
if (group_transa & 1) group_nrowa = group_k;
278317
group_nrowb = group_k;
@@ -288,7 +327,9 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE * transa_array, enum CB
288327
if (group_m < 0) info = 3;
289328
if (group_transb < 0) info = 2;
290329
if (group_transa < 0) info = 1;
330+
#if defined(CBLAS)
291331
}
332+
#endif
292333

293334
if (info >= 0) {
294335
BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME));
@@ -344,13 +385,17 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE * transa_array, enum CB
344385
args_array[count].alpha=group_alpha;
345386
args_array[count].beta=group_beta;
346387

388+
/* Select which caller matrix becomes operand a / operand b for this entry.
 * CBLAS build: column-major keeps (a, b); row-major swaps them.
 * Non-CBLAS build: the preprocessor removes the if/else scaffolding, so
 * only the swapped assignment below is compiled. */
#if defined(CBLAS)
347389
if (order == CblasColMajor) {
348390
args_array[count].a=(a_array[matrix_idx+j]);
349391
args_array[count].b=(b_array[matrix_idx+j]);
350392
}else if(order == CblasRowMajor){
393+
#endif
/* NOTE(review): with CBLAS undefined this swap (a <- b_array, b <- a_array)
 * is unconditional, i.e. the same operand order as the row-major CBLAS
 * case. Confirm it agrees with how m/n, lda/ldb and the transpose flags are
 * set up for the non-CBLAS entry point. */
351394
args_array[count].a=(b_array[matrix_idx+j]);
352395
args_array[count].b=(a_array[matrix_idx+j]);
396+
#if defined(CBLAS)
353397
}
398+
#endif
354399

355400
args_array[count].c=(c_array[matrix_idx+j]);
356401

0 commit comments

Comments
 (0)