Skip to content

Commit dbf5d31

Browse files
ikawrakowIwan Kawrakow
andauthored
Better BF16 support on AVX2 (#175)
* Adding BF16 support for AVX2 PP performance is the same as fp16 (~153 t/s on Ryzen-5975WX), but TG is quite a bit lower (3.65 t/s vs 4.72 t/s at 8 threads). Why? * Slightly faster fp16/bf16 gemv on AVX2 It still saturates at the same lower peformance for bf16 --------- Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent 6d23495 commit dbf5d31

File tree

1 file changed

+24
-3
lines changed

1 file changed

+24
-3
lines changed

ggml/src/iqk/iqk_mul_mat.cpp

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6970,6 +6970,9 @@ struct QFBase {
69706970
using Acc = __m256;
69716971
static inline Data load(const ggml_half * x) { return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)x)); }
69726972
static inline Data load(const float * x) { return _mm256_loadu_ps(x); }
6973+
static inline Data load(const ggml_bf16_t * x) {
6974+
return _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i*)x)), 16));
6975+
}
69736976
static inline Acc acc(Acc prev, const Data& y, const Data& x) {
69746977
return _mm256_fmadd_ps(y, x, prev);
69756978
}
@@ -7003,6 +7006,9 @@ struct QFBase {
70037006
#endif
70047007
static inline __m128 load128(const ggml_half * x) { return _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)x)); }
70057008
static inline __m128 load128(const float * x) { return _mm_loadu_ps(x); }
7009+
static inline __m128 load128(const ggml_bf16_t * x) {
7010+
return _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i*)x)), 16));
7011+
}
70067012
};
70077013
template <typename Float, int nrc_in> struct QFT final : public QFBase {
70087014
constexpr static int nrc = nrc_in;
@@ -7142,7 +7148,7 @@ void mul_mat_fX_fY_T(int n, const void * vx, size_t bx, const DataInfo& info, in
71427148
#ifdef __AVX512F__
71437149
constexpr int k_nx = 5;
71447150
#else
7145-
constexpr int k_nx = 2;
7151+
constexpr int k_nx = nrc_y == 1 ? 4 : 2;
71467152
#endif
71477153
const char * cx = (const char *)vx;
71487154
for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
@@ -7151,14 +7157,26 @@ void mul_mat_fX_fY_T(int n, const void * vx, size_t bx, const DataInfo& info, in
71517157
int last_x = k_nx*(nrc_x/k_nx);
71527158
if (last_x == nrc_x) return;
71537159
int nx = nrc_x - last_x;
7160+
#ifdef __AVX512F__
71547161
switch (nx) {
71557162
case 1: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 1>>(n, cx, bx, last_x, info); break;
7156-
#ifdef __AVX512F__
71577163
case 2: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 2>>(n, cx, bx, last_x, info); break;
71587164
case 3: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 3>>(n, cx, bx, last_x, info); break;
71597165
case 4: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 4>>(n, cx, bx, last_x, info); break;
7160-
#endif
71617166
}
7167+
#else
7168+
if constexpr (nrc_y == 1) {
7169+
switch (nx) {
7170+
case 1: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 1>>(n, cx, bx, last_x, info); break;
7171+
case 2: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 2>>(n, cx, bx, last_x, info); break;
7172+
case 3: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 3>>(n, cx, bx, last_x, info); break;
7173+
}
7174+
} else {
7175+
switch (nx) {
7176+
case 1: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 1>>(n, cx, bx, last_x, info); break;
7177+
}
7178+
}
7179+
#endif
71627180
}
71637181

71647182
#ifdef __AVX512BF16__
@@ -7456,6 +7474,9 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
74567474
switch (typeB) {
74577475
#ifdef __AVX512BF16__
74587476
case GGML_TYPE_BF16: set_mul_mat_bf16(mm); break;
7477+
#else
7478+
case GGML_TYPE_BF16: set_mul_mat_f<ggml_bf16_t, ggml_bf16_t>(mm); break;
7479+
case GGML_TYPE_F32: set_mul_mat_f<ggml_bf16_t, float>(mm); break;
74597480
#endif
74607481
default: return false;
74617482
}

0 commit comments

Comments
 (0)