@@ -6970,6 +6970,9 @@ struct QFBase {
6970
6970
using Acc = __m256;
6971
6971
// Widening load for half-precision input: reads 128 bits from x with an
// unaligned load and converts the eight packed 16-bit halves into eight
// single-precision lanes.
// NOTE(review): assumes ggml_half is a 16-bit IEEE half storage type — confirm.
static inline Data load(const ggml_half * x) {
    const __m128i packed_halves = _mm_loadu_si128(reinterpret_cast<const __m128i *>(x));
    return _mm256_cvtph_ps(packed_halves);
}
6972
6972
// Plain load for single-precision input: reads eight consecutive floats from x
// with an unaligned load; no alignment requirement is placed on x.
static inline Data load(const float * x) {
    const __m256 lanes = _mm256_loadu_ps(x);
    return lanes;
}
6973
// Widening load for bf16 input: reads 128 bits (eight 16-bit values) from x
// with an unaligned load, zero-extends each value to 32 bits, shifts it left by
// 16 so the original bits occupy the upper half of the lane, and reinterprets
// the eight 32-bit lanes as floats (a bit cast, no numeric conversion).
// NOTE(review): presumably ggml_bf16_t stores the raw upper 16 bits of an
// IEEE-754 binary32 value — confirm against its declaration.
+ static inline Data load(const ggml_bf16_t * x) {
6974
+ return _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i*)x)), 16));
6975
+ }
6973
6976
// Accumulation step: returns y * x + prev computed lane-wise with a single
// fused multiply-add over the eight float lanes (Acc is __m256).
static inline Acc acc(Acc prev, const Data& y, const Data& x) {
6974
6977
return _mm256_fmadd_ps(y, x, prev);
6975
6978
}
@@ -7003,6 +7006,9 @@ struct QFBase {
7003
7006
#endif
7004
7007
// 128-bit variant of the half-precision load: reads 64 bits (four 16-bit
// halves) from x into the low half of an integer register and converts those
// four values to four single-precision lanes.
// NOTE(review): assumes ggml_half is a 16-bit IEEE half storage type — confirm.
static inline __m128 load128(const ggml_half * x) {
    const __m128i packed_halves = _mm_loadl_epi64(reinterpret_cast<const __m128i *>(x));
    return _mm_cvtph_ps(packed_halves);
}
7005
7008
// 128-bit variant of the float load: reads four consecutive floats from x with
// an unaligned load; no alignment requirement is placed on x.
static inline __m128 load128(const float * x) {
    const __m128 lanes = _mm_loadu_ps(x);
    return lanes;
}
7009
// 128-bit variant of the bf16 load: reads 64 bits (four 16-bit values) from x,
// zero-extends each value to 32 bits, shifts it left by 16 so the original bits
// occupy the upper half of the lane, and reinterprets the four 32-bit lanes as
// floats (a bit cast, no numeric conversion).
// NOTE(review): presumably ggml_bf16_t stores the raw upper 16 bits of an
// IEEE-754 binary32 value — confirm against its declaration.
+ static inline __m128 load128(const ggml_bf16_t * x) {
7010
+ return _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i*)x)), 16));
7011
+ }
7006
7012
};
7007
7013
template <typename Float, int nrc_in> struct QFT final : public QFBase {
7008
7014
constexpr static int nrc = nrc_in;
@@ -7142,7 +7148,7 @@ void mul_mat_fX_fY_T(int n, const void * vx, size_t bx, const DataInfo& info, in
7142
7148
#ifdef __AVX512F__
7143
7149
constexpr int k_nx = 5;
7144
7150
#else
7145
- constexpr int k_nx = 2;
7151
+ constexpr int k_nx = nrc_y == 1 ? 4 : 2;
7146
7152
#endif
7147
7153
const char * cx = (const char *)vx;
7148
7154
for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
@@ -7151,14 +7157,26 @@ void mul_mat_fX_fY_T(int n, const void * vx, size_t bx, const DataInfo& info, in
7151
7157
int last_x = k_nx*(nrc_x/k_nx);
7152
7158
if (last_x == nrc_x) return;
7153
7159
int nx = nrc_x - last_x;
7160
+ #ifdef __AVX512F__
7154
7161
switch (nx) {
7155
7162
case 1: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 1>>(n, cx, bx, last_x, info); break;
7156
- #ifdef __AVX512F__
7157
7163
case 2: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 2>>(n, cx, bx, last_x, info); break;
7158
7164
case 3: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 3>>(n, cx, bx, last_x, info); break;
7159
7165
case 4: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 4>>(n, cx, bx, last_x, info); break;
7160
- #endif
7161
7166
}
7167
+ #else
7168
+ if constexpr (nrc_y == 1) {
7169
+ switch (nx) {
7170
+ case 1: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 1>>(n, cx, bx, last_x, info); break;
7171
+ case 2: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 2>>(n, cx, bx, last_x, info); break;
7172
+ case 3: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 3>>(n, cx, bx, last_x, info); break;
7173
+ }
7174
+ } else {
7175
+ switch (nx) {
7176
+ case 1: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 1>>(n, cx, bx, last_x, info); break;
7177
+ }
7178
+ }
7179
+ #endif
7162
7180
}
7163
7181
7164
7182
#ifdef __AVX512BF16__
@@ -7456,6 +7474,9 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
7456
7474
switch (typeB) {
7457
7475
#ifdef __AVX512BF16__
7458
7476
case GGML_TYPE_BF16: set_mul_mat_bf16(mm); break;
7477
+ #else
7478
+ case GGML_TYPE_BF16: set_mul_mat_f<ggml_bf16_t, ggml_bf16_t>(mm); break;
7479
+ case GGML_TYPE_F32: set_mul_mat_f<ggml_bf16_t, float>(mm); break;
7459
7480
#endif
7460
7481
default: return false;
7461
7482
}
0 commit comments