Skip to content

Commit ac8cbfd

Browse files
authored
Merge pull request #5381 from Mousius/bgemv-infrastructure
Add infrastructure for BGEMV
2 parents 1742dec + 947d7af commit ac8cbfd

34 files changed

+919
-167
lines changed

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,9 @@ test/ZBLAT2.SUMM
8181
test/ZBLAT3.SUMM
8282
test/ZBLAT3_3M.SUMM
8383
test/SHBLAT3.SUMM
84+
test/SBBLAT2.SUMM
8485
test/SBBLAT3.SUMM
86+
test/BBLAT2.SUMM
8587
test/BBLAT3.SUMM
8688
test/cblat1
8789
test/cblat2
@@ -97,7 +99,9 @@ test/sblat3
9799
test/sblat3_3m
98100
test/test_shgemm
99101
test/test_sbgemm
102+
test/test_sbgemv
100103
test/test_bgemm
104+
test/test_bgemv
101105
test/zblat1
102106
test/zblat2
103107
test/zblat3

cblas.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -465,6 +465,7 @@ void cblas_sbdtobf16(OPENBLAS_CONST blasint n, OPENBLAS_CONST double *in, OPEN
465465
void cblas_sbf16tos(OPENBLAS_CONST blasint n, OPENBLAS_CONST bfloat16 *in, OPENBLAS_CONST blasint incin, float *out, OPENBLAS_CONST blasint incout);
466466
/* convert BFLOAT16 array to double array */
467467
void cblas_dbf16tod(OPENBLAS_CONST blasint n, OPENBLAS_CONST bfloat16 *in, OPENBLAS_CONST blasint incin, double *out, OPENBLAS_CONST blasint incout);
468+
void cblas_bgemv(OPENBLAS_CONST enum CBLAS_ORDER order, OPENBLAS_CONST enum CBLAS_TRANSPOSE trans, OPENBLAS_CONST blasint m, OPENBLAS_CONST blasint n, OPENBLAS_CONST bfloat16 alpha, OPENBLAS_CONST bfloat16 *a, OPENBLAS_CONST blasint lda, OPENBLAS_CONST bfloat16 *x, OPENBLAS_CONST blasint incx, OPENBLAS_CONST bfloat16 beta, bfloat16 *y, OPENBLAS_CONST blasint incy);
468469
/* dot product of BFLOAT16 input arrays, with the result returned as float */
469470
float cblas_sbdot(OPENBLAS_CONST blasint n, OPENBLAS_CONST bfloat16 *x, OPENBLAS_CONST blasint incx, OPENBLAS_CONST bfloat16 *y, OPENBLAS_CONST blasint incy);
470471
void cblas_sbgemv(OPENBLAS_CONST enum CBLAS_ORDER order, OPENBLAS_CONST enum CBLAS_TRANSPOSE trans, OPENBLAS_CONST blasint m, OPENBLAS_CONST blasint n, OPENBLAS_CONST float alpha, OPENBLAS_CONST bfloat16 *a, OPENBLAS_CONST blasint lda, OPENBLAS_CONST bfloat16 *x, OPENBLAS_CONST blasint incx, OPENBLAS_CONST float beta, float *y, OPENBLAS_CONST blasint incy);

cmake/kernel.cmake

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ macro(SetDefaultL1)
110110
SetFallback(SROTMKERNEL rotm.S)
111111
SetFallback(DROTMKERNEL rotm.S)
112112
SetFallback(QROTMKERNEL rotm.S)
113+
SetFallback(BSCALKERNEL ../generic/scal.c)
113114
SetFallback(SSCALKERNEL scal.S)
114115
SetFallback(DSCALKERNEL scal.S)
115116
SetFallback(CSCALKERNEL zscal.S)
@@ -169,6 +170,8 @@ if (BUILD_BFLOAT16)
169170
SetFallback(SHSWAPKERNEL ../arm/swap.c)
170171
SetFallback(TOBF16KERNEL ../x86_64/tobf16.c)
171172
SetFallback(BF16TOKERNEL ../x86_64/bf16to.c)
173+
SetFallback(BGEMVNKERNEL ../generic/gemv_n.c)
174+
SetFallback(BGEMVTKERNEL ../generic/gemv_t.c)
172175
SetFallback(SBGEMVNKERNEL ../x86_64/sbgemv_n.c)
173176
SetFallback(SBGEMVTKERNEL ../x86_64/sbgemv_t.c)
174177
endif ()
@@ -221,6 +224,8 @@ macro(SetDefaultL2)
221224
SetFallback(XHEMV_V_KERNEL ../generic/zhemv_k.c)
222225
SetFallback(XHEMV_M_KERNEL ../generic/zhemv_k.c)
223226
if (BUILD_BFLOAT16)
227+
SetFallback(BGEMVNKERNEL ../generic/gemv_n.c)
228+
SetFallback(BGEMVTKERNEL ../generic/gemv_t.c)
224229
SetFallback(SBGEMVNKERNEL ../x86_64/sbgemv_n.c)
225230
SetFallback(SBGEMVTKERNEL ../x86_64/sbgemv_t.c)
226231
SetFallback(SHGERKERNEL ../generic/ger.c)

cmake/utils.cmake

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -375,7 +375,7 @@ function(GenerateNamedObjects sources_in)
375375
if (NOT no_float_type)
376376
string(SUBSTRING ${float_type} 0 1 float_char)
377377
string(TOLOWER ${float_char} float_char)
378-
if (${float_type} STREQUAL "BFLOAT16" AND NOT "${defines_in}" MATCHES "BGEMM")
378+
if (${float_type} STREQUAL "BFLOAT16" AND NOT "${defines_in}" MATCHES "BGEM")
379379
set (float_char "sb")
380380
endif ()
381381
endif ()

common_b.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,11 @@
3030
#define COMMON_B_H
3131

3232
#ifndef DYNAMIC_ARCH
33+
#define BGEMV_N_K bgemv_n
34+
#define BGEMV_T_K bgemv_t
35+
36+
#define BSCAL_K bscal_k
37+
3338
#define BGEMM_ONCOPY bgemm_oncopy
3439
#define BGEMM_OTCOPY bgemm_otcopy
3540

@@ -45,6 +50,10 @@
4550
#define BGEMM_KERNEL bgemm_kernel
4651

4752
#else
53+
#define BGEMV_N_K gotoblas->bgemv_n
54+
#define BGEMV_T_K gotoblas->bgemv_t
55+
56+
#define BSCAL_K gotoblas->bscal_k
4857

4958
#define BGEMM_ONCOPY gotoblas->bgemm_oncopy
5059
#define BGEMM_OTCOPY gotoblas->bgemm_otcopy

common_interface.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ double BLASFUNC(dsdot) (blasint *, float *, blasint *, float *, blasint *);
6060
double BLASFUNC(ddot) (blasint *, double *, blasint *, double *, blasint *);
6161
xdouble BLASFUNC(qdot) (blasint *, xdouble *, blasint *, xdouble *, blasint *);
6262

63+
void BLASFUNC(bscal) (blasint *, bfloat16 *, bfloat16 *, blasint *);
6364
float BLASFUNC(sbdot) (blasint *, bfloat16 *, blasint *, bfloat16 *, blasint *);
6465
void BLASFUNC(sbstobf16) (blasint *, float *, blasint *, bfloat16 *, blasint *);
6566
void BLASFUNC(sbdtobf16) (blasint *, double *, blasint *, bfloat16 *, blasint *);
@@ -256,6 +257,8 @@ void BLASFUNC(xgeru)(blasint *, blasint *, xdouble *, xdouble *, blasint *,
256257
void BLASFUNC(xgerc)(blasint *, blasint *, xdouble *, xdouble *, blasint *,
257258
xdouble *, blasint *, xdouble *, blasint *);
258259

260+
void BLASFUNC(bgemv)(char *, blasint *, blasint *, bfloat16 *, bfloat16 *, blasint *,
261+
bfloat16 *, blasint *, bfloat16 *, bfloat16 *, blasint *);
259262
void BLASFUNC(sbgemv)(char *, blasint *, blasint *, float *, bfloat16 *, blasint *,
260263
bfloat16 *, blasint *, float *, float *, blasint *);
261264
void BLASFUNC(sgemv)(char *, blasint *, blasint *, float *, float *, blasint *,

common_level1.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
/*********************************************************************/
2+
/* Copyright 2025 The OpenBLAS Project. */
23
/* Copyright 2009, 2010 The University of Texas at Austin. */
34
/* All rights reserved. */
45
/* */
@@ -169,6 +170,9 @@ BLASLONG icmin_k(BLASLONG, float *, BLASLONG);
169170
BLASLONG izmin_k(BLASLONG, double *, BLASLONG);
170171
BLASLONG ixmin_k(BLASLONG, xdouble *, BLASLONG);
171172

173+
174+
int bscal_k(BLASLONG, BLASLONG, BLASLONG, bfloat16,
175+
bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG);
172176
int sscal_k(BLASLONG, BLASLONG, BLASLONG, float,
173177
float *, BLASLONG, float *, BLASLONG, float *, BLASLONG);
174178
int dscal_k(BLASLONG, BLASLONG, BLASLONG, double,

common_level2.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
/*********************************************************************/
2+
/* Copyright 2025 The OpenBLAS Project */
23
/* Copyright 2009, 2010 The University of Texas at Austin. */
34
/* All rights reserved. */
45
/* */
@@ -44,6 +45,11 @@
4445
extern "C" {
4546
#endif
4647

48+
49+
int bgemv_n(BLASLONG, BLASLONG, bfloat16, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, bfloat16, bfloat16 *, BLASLONG);
50+
int bgemv_t(BLASLONG, BLASLONG, bfloat16, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, bfloat16, bfloat16 *, BLASLONG);
51+
int bgemv_thread_n(BLASLONG, BLASLONG, bfloat16, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, bfloat16, bfloat16 *, BLASLONG, int);
52+
int bgemv_thread_t(BLASLONG, BLASLONG, bfloat16, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, bfloat16, bfloat16 *, BLASLONG, int);
4753
int sbgemv_n(BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float, float *, BLASLONG);
4854
int sbgemv_t(BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float, float *, BLASLONG);
4955
int sbgemv_thread_n(BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float, float *, BLASLONG, int);

common_macro.h

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -705,6 +705,11 @@
705705

706706

707707
#elif defined(BFLOAT16) && defined(BGEMM)
708+
#define SCAL_K BSCAL_K
709+
710+
#define GEMV_N BGEMV_N_K
711+
#define GEMV_T BGEMV_T_K
712+
708713
#define GEMM_BETA BGEMM_BETA
709714
#define GEMM_KERNEL_N BGEMM_KERNEL
710715
#define GEMM_KERNEL_L BGEMM_KERNEL
@@ -754,8 +759,8 @@
754759
#define D_BF16_TO_K DBF16TOD_K
755760
#define S_TO_BF16_K SBSTOBF16_K
756761
#define S_BF16_TO_K SBF16TOS_K
757-
#define SBGEMV_N SBGEMV_N_K
758-
#define SBGEMV_T SBGEMV_T_K
762+
#define GEMV_N SBGEMV_N_K
763+
#define GEMV_T SBGEMV_T_K
759764

760765
#define AMAX_K SAMAX_K
761766
#define AMIN_K SAMIN_K
@@ -773,8 +778,6 @@
773778
#define AXPYC_K SAXPYC_K
774779
#define AXPBY_K SAXPBY_K
775780
#define SCAL_K SSCAL_K
776-
#define GEMV_N SGEMV_N
777-
#define GEMV_T SGEMV_T
778781
#define SYMV_U SSYMV_U
779782
#define SYMV_L SSYMV_L
780783
#define GERU_K SGERU_K

common_param.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,10 +98,14 @@ int (*shgemm_otcopy )(BLASLONG, BLASLONG, hfloat16 *, BLASLONG, hfloat16 *);
9898
int (*sbrot_k) (BLASLONG, float *, BLASLONG, float *, BLASLONG, float, float);
9999
int (*sbrotm_k) (BLASLONG, float *, BLASLONG, float *, BLASLONG, float *);
100100

101+
int (*bscal_k) (BLASLONG, BLASLONG, BLASLONG, bfloat16, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG);
101102
int (*sbaxpy_k) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG);
102103
int (*sbscal_k) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG);
103104
int (*sbswap_k) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG);
104105

106+
int (*bgemv_n) (BLASLONG, BLASLONG, bfloat16, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, bfloat16, bfloat16 *, BLASLONG);
107+
int (*bgemv_t) (BLASLONG, BLASLONG, bfloat16, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, bfloat16, bfloat16 *, BLASLONG);
108+
105109
int (*sbgemv_n) (BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float, float *, BLASLONG);
106110
int (*sbgemv_t) (BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float, float *, BLASLONG);
107111
int (*sbger_k) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG, float *);

0 commit comments

Comments
 (0)