Skip to content

Commit 91c84e1

Browse files
authored
Merge pull request #2796 from Guobing-Chen/BF16_dot_coversion_apis
Add bfloat16-based dot product and conversions to/from single and double precision
2 parents 1ee1e7b + deaeb6c commit 91c84e1

31 files changed

+1392
-82
lines changed

Makefile.tail

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,14 @@ QBLASOBJS_P = $(QBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
55
CBLASOBJS_P = $(CBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
66
ZBLASOBJS_P = $(ZBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
77
XBLASOBJS_P = $(XBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
8+
SHEXTOBJS_P = $(SHEXTOBJS:.$(SUFFIX)=.$(PSUFFIX))
89

910
COMMONOBJS_P = $(COMMONOBJS:.$(SUFFIX)=.$(PSUFFIX))
1011

1112
HPLOBJS_P = $(HPLOBJS:.$(SUFFIX)=.$(PSUFFIX))
1213

13-
BLASOBJS = $(SHBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS)
14-
BLASOBJS_P = $(SHBLASOBJS_P) $(SBLASOBJS_P) $(DBLASOBJS_P) $(CBLASOBJS_P) $(ZBLASOBJS_P)
14+
BLASOBJS = $(SHEXTOBJS) $(SHBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS)
15+
BLASOBJS_P = $(SHEXTOBJS_P) $(SHBLASOBJS_P) $(SBLASOBJS_P) $(DBLASOBJS_P) $(CBLASOBJS_P) $(ZBLASOBJS_P)
1516

1617
ifdef EXPRECISION
1718
BLASOBJS += $(QBLASOBJS) $(XBLASOBJS)
@@ -30,6 +31,7 @@ $(QBLASOBJS) $(QBLASOBJS_P) : override CFLAGS += -DXDOUBLE -UCOMPLEX
3031
$(CBLASOBJS) $(CBLASOBJS_P) : override CFLAGS += -UDOUBLE -DCOMPLEX
3132
$(ZBLASOBJS) $(ZBLASOBJS_P) : override CFLAGS += -DDOUBLE -DCOMPLEX
3233
$(XBLASOBJS) $(XBLASOBJS_P) : override CFLAGS += -DXDOUBLE -DCOMPLEX
34+
$(SHEXTOBJS) $(SHEXTOBJS_P) : override CFLAGS += -DHALF -UDOUBLE -UCOMPLEX
3335

3436
$(SHBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
3537
$(SBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
@@ -38,6 +40,7 @@ $(QBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
3840
$(CBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
3941
$(ZBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
4042
$(XBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
43+
$(SHEXTOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
4144

4245
libs :: $(BLASOBJS) $(COMMONOBJS)
4346
$(AR) $(ARFLAGS) -ru $(TOPDIR)/$(LIBNAME) $^

cblas.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,17 @@ void cblas_cgeadd(OPENBLAS_CONST enum CBLAS_ORDER CORDER,OPENBLAS_CONST blasint
382382
void cblas_zgeadd(OPENBLAS_CONST enum CBLAS_ORDER CORDER,OPENBLAS_CONST blasint crows, OPENBLAS_CONST blasint ccols, OPENBLAS_CONST double *calpha, double *a, OPENBLAS_CONST blasint clda, OPENBLAS_CONST double *cbeta,
383383
double *c, OPENBLAS_CONST blasint cldc);
384384

385+
/*** BFLOAT16 and INT8 extensions ***/
386+
/* convert float array to BFLOAT16 array by rounding */
387+
void cblas_shstobf16(OPENBLAS_CONST blasint n, OPENBLAS_CONST float *in, OPENBLAS_CONST blasint incin, bfloat16 *out, OPENBLAS_CONST blasint incout);
388+
/* convert double array to BFLOAT16 array by rounding */
389+
void cblas_shdtobf16(OPENBLAS_CONST blasint n, OPENBLAS_CONST double *in, OPENBLAS_CONST blasint incin, bfloat16 *out, OPENBLAS_CONST blasint incout);
390+
/* convert BFLOAT16 array to float array */
391+
void cblas_sbf16tos(OPENBLAS_CONST blasint n, OPENBLAS_CONST bfloat16 *in, OPENBLAS_CONST blasint incin, float *out, OPENBLAS_CONST blasint incout);
392+
/* convert BFLOAT16 array to double array */
393+
void cblas_dbf16tod(OPENBLAS_CONST blasint n, OPENBLAS_CONST bfloat16 *in, OPENBLAS_CONST blasint incin, double *out, OPENBLAS_CONST blasint incout);
394+
/* dot product of BFLOAT16 input arrays, with the result returned as float */
395+
float cblas_shdot(OPENBLAS_CONST blasint n, OPENBLAS_CONST bfloat16 *x, OPENBLAS_CONST blasint incx, OPENBLAS_CONST bfloat16 *y, OPENBLAS_CONST blasint incy);
385396

386397
#ifdef __cplusplus
387398
}

cmake/kernel.cmake

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,12 +126,14 @@ if (BUILD_HALF)
126126
set(SHAXPYKERNEL ../arm/axpy.c)
127127
set(SHAXPBYKERNEL ../arm/axpby.c)
128128
set(SHCOPYKERNEL ../arm/copy.c)
129-
set(SHDOTKERNEL ../arm/dot.c)
129+
set(SHDOTKERNEL ../x86_64/shdot.c)
130130
set(SHROTKERNEL ../arm/rot.c)
131131
set(SHSCALKERNEL ../arm/scal.c)
132132
set(SHNRM2KERNEL ../arm/nrm2.c)
133133
set(SHSUMKERNEL ../arm/sum.c)
134134
set(SHSWAPKERNEL ../arm/swap.c)
135+
set(TOBF16KERNEL ../x86_64/tobf16.c)
136+
set(BF16TOKERNEL ../x86_64/bf16to.c)
135137
endif ()
136138
endmacro ()
137139

common.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,8 @@ typedef unsigned long BLASULONG;
258258
#endif
259259

260260
#ifndef BFLOAT16
261-
typedef unsigned short bfloat16;
261+
#include <stdint.h>
262+
typedef uint16_t bfloat16;
262263
#define HALFCONVERSION 1
263264
#endif
264265

common_interface.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,11 @@ double BLASFUNC(dsdot) (blasint *, float *, blasint *, float *, blasint *);
5454
double BLASFUNC(ddot) (blasint *, double *, blasint *, double *, blasint *);
5555
xdouble BLASFUNC(qdot) (blasint *, xdouble *, blasint *, xdouble *, blasint *);
5656

57+
float BLASFUNC(shdot) (blasint *, bfloat16 *, blasint *, bfloat16 *, blasint *);
58+
void BLASFUNC(shstobf16) (blasint *, float *, blasint *, bfloat16 *, blasint *);
59+
void BLASFUNC(shdtobf16) (blasint *, double *, blasint *, bfloat16 *, blasint *);
60+
void BLASFUNC(sbf16tos) (blasint *, bfloat16 *, blasint *, float *, blasint *);
61+
void BLASFUNC(dbf16tod) (blasint *, bfloat16 *, blasint *, double *, blasint *);
5762

5863
#ifdef RETURN_BY_STRUCT
5964
typedef struct {

common_level1.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,12 @@ float sdot_k(BLASLONG, float *, BLASLONG, float *, BLASLONG);
4646
double dsdot_k(BLASLONG, float *, BLASLONG, float *, BLASLONG);
4747
double ddot_k(BLASLONG, double *, BLASLONG, double *, BLASLONG);
4848
xdouble qdot_k(BLASLONG, xdouble *, BLASLONG, xdouble *, BLASLONG);
49+
float shdot_k(BLASLONG, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG);
50+
51+
void shstobf16_k(BLASLONG, float *, BLASLONG, bfloat16 *, BLASLONG);
52+
void shdtobf16_k(BLASLONG, double *, BLASLONG, bfloat16 *, BLASLONG);
53+
void sbf16tos_k (BLASLONG, bfloat16 *, BLASLONG, float *, BLASLONG);
54+
void dbf16tod_k (BLASLONG, bfloat16 *, BLASLONG, double *, BLASLONG);
4955

5056
openblas_complex_float cdotc_k (BLASLONG, float *, BLASLONG, float *, BLASLONG);
5157
openblas_complex_float cdotu_k (BLASLONG, float *, BLASLONG, float *, BLASLONG);

common_macro.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -646,6 +646,11 @@
646646

647647
#elif defined(HALF)
648648

649+
#define D_TO_BF16_K SHDTOBF16_K
650+
#define D_BF16_TO_K DBF16TOD_K
651+
#define S_TO_BF16_K SHSTOBF16_K
652+
#define S_BF16_TO_K SBF16TOS_K
653+
649654
#define AMAX_K SAMAX_K
650655
#define AMIN_K SAMIN_K
651656
#define MAX_K SMAX_K
@@ -657,6 +662,7 @@
657662
#define ASUM_K SASUM_K
658663
#define DOTU_K SDOTU_K
659664
#define DOTC_K SDOTC_K
665+
#define BF16_DOT_K SHDOT_K
660666
#define AXPYU_K SAXPYU_K
661667
#define AXPYC_K SAXPYC_K
662668
#define AXPBY_K SAXPBY_K

common_param.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,11 @@ typedef struct {
5151
int shgemm_p, shgemm_q, shgemm_r;
5252
int shgemm_unroll_m, shgemm_unroll_n, shgemm_unroll_mn;
5353

54+
void (*shstobf16_k) (BLASLONG, float *, BLASLONG, bfloat16 *, BLASLONG);
55+
void (*shdtobf16_k) (BLASLONG, double *, BLASLONG, bfloat16 *, BLASLONG);
56+
void (*sbf16tos_k) (BLASLONG, bfloat16 *, BLASLONG, float *, BLASLONG);
57+
void (*dbf16tod_k) (BLASLONG, bfloat16 *, BLASLONG, double *, BLASLONG);
58+
5459
float (*shamax_k) (BLASLONG, float *, BLASLONG);
5560
float (*shamin_k) (BLASLONG, float *, BLASLONG);
5661
float (*shmax_k) (BLASLONG, float *, BLASLONG);
@@ -64,7 +69,7 @@ BLASLONG (*ishmin_k) (BLASLONG, float *, BLASLONG);
6469
float (*shasum_k) (BLASLONG, float *, BLASLONG);
6570
float (*shsum_k) (BLASLONG, float *, BLASLONG);
6671
int (*shcopy_k) (BLASLONG, float *, BLASLONG, float *, BLASLONG);
67-
float (*shdot_k) (BLASLONG, float *, BLASLONG, float *, BLASLONG);
72+
float (*shdot_k) (BLASLONG, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG);
6873
double (*dshdot_k) (BLASLONG, float *, BLASLONG, float *, BLASLONG);
6974

7075
int (*shrot_k) (BLASLONG, float *, BLASLONG, float *, BLASLONG, float, float);

common_sh.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,12 @@
33

44
#ifndef DYNAMIC_ARCH
55

6+
#define SHDOT_K shdot_k
7+
#define SHSTOBF16_K shstobf16_k
8+
#define SHDTOBF16_K shdtobf16_k
9+
#define SBF16TOS_K sbf16tos_k
10+
#define DBF16TOD_K dbf16tod_k
11+
612
#define SHGEMM_ONCOPY shgemm_oncopy
713
#define SHGEMM_OTCOPY shgemm_otcopy
814

@@ -18,6 +24,12 @@
1824

1925
#else
2026

27+
#define SHDOT_K gotoblas -> shdot_k
28+
#define SHSTOBF16_K gotoblas -> shstobf16_k
29+
#define SHDTOBF16_K gotoblas -> shdtobf16_k
30+
#define SBF16TOS_K gotoblas -> sbf16tos_k
31+
#define DBF16TOD_K gotoblas -> dbf16tod_k
32+
2133
#define SHGEMM_ONCOPY gotoblas -> shgemm_oncopy
2234
#define SHGEMM_OTCOPY gotoblas -> shgemm_otcopy
2335
#define SHGEMM_INCOPY gotoblas -> shgemm_incopy

common_thread.h

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -59,12 +59,19 @@ extern int blas_omp_linked;
5959
#define BLAS_PTHREAD 0x4000U
6060
#define BLAS_NODE 0x2000U
6161

62-
#define BLAS_PREC 0x0003U
63-
#define BLAS_SINGLE 0x0000U
64-
#define BLAS_DOUBLE 0x0001U
65-
#define BLAS_XDOUBLE 0x0002U
66-
#define BLAS_REAL 0x0000U
67-
#define BLAS_COMPLEX 0x0004U
62+
#define BLAS_PREC 0x000FU
63+
#define BLAS_INT8 0x0000U
64+
#define BLAS_BFLOAT16 0x0001U
65+
#define BLAS_SINGLE 0x0002U
66+
#define BLAS_DOUBLE 0x0003U
67+
#define BLAS_XDOUBLE 0x0004U
68+
#define BLAS_STOBF16 0x0008U
69+
#define BLAS_DTOBF16 0x0009U
70+
#define BLAS_BF16TOS 0x000AU
71+
#define BLAS_BF16TOD 0x000BU
72+
73+
#define BLAS_REAL 0x0000U
74+
#define BLAS_COMPLEX 0x1000U
6875

6976
#define BLAS_TRANSA 0x0030U /* 2bit */
7077
#define BLAS_TRANSA_N 0x0000U

0 commit comments

Comments
 (0)