Commit c44664b

Always favor fp16 arithmetic in tinyBLAS
It was assumed earlier that upcasting fp16 to fp32 would help precision. However, that turned out not to be the case: measuring the Levenshtein distance of whisperfile output shows this change makes the results noticeably better. So we now skip the fp32 upcast and stay in fp16 arithmetic whenever the ISA supports it. This should be safe and accurate even for large sums, since tinyBLAS::gemm now uses the ruler reduction divide-and-conquer approach.
1 parent 6287b60 commit c44664b
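
The ruler reduction mentioned above is essentially pairwise summation driven by a binary counter: partial sums are merged whenever two runs of equal power-of-two length are complete, so each element takes part in only O(log k) additions and roundoff grows logarithmically in k instead of linearly. Below is a minimal scalar sketch of the idea; the name ruler_sum and the fixed stack size are illustrative, not the actual tinyBLAS kernel, which applies the same pattern to SIMD accumulators inside tinyBLAS::gemm.

    #include <cstddef>

    // Pairwise ("ruler") summation: keep one partial sum per power-of-two run
    // length and merge according to the trailing 1 bits of the element index.
    float ruler_sum(const float *x, size_t k) {
        float stack[64];
        size_t depth = 0;
        for (size_t i = 0; i < k; ++i) {
            float s = x[i];
            for (size_t b = i; b & 1; b >>= 1)  // ruler function of i + 1
                s += stack[--depth];
            stack[depth++] = s;
        }
        float total = 0;  // fold leftover runs when k is not a power of two
        while (depth)
            total += stack[--depth];
        return total;
    }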

9 files changed: 42 additions and 74 deletions


llama.cpp/ggml.c

Lines changed: 2 additions & 4 deletions
@@ -11091,8 +11091,7 @@ static void ggml_compute_forward_mul_mat(
                              ith, nth,
                              src0->type,
                              src1->type,
-                             dst->type,
-                             dst->op_params[0]))
+                             dst->type))
             goto UseGgmlGemm1;
         return;
     }
@@ -11153,8 +11152,7 @@ UseGgmlGemm1:;
                              ith, nth,
                              src0->type,
                              vec_dot_type,
-                             dst->type,
-                             dst->op_params[0]))
+                             dst->type))
             goto UseGgmlGemm2;
         return;
     }

llamafile/sgemm.cpp

Lines changed: 2 additions & 4 deletions
@@ -123,13 +123,11 @@ static const struct GemmFuncs {
  * @param Atype is GGML data type of `A`
  * @param Btype is GGML data type of `B`
  * @param Ctype is GGML data type of `C`
- * @param precision may be used to control the internal compute type
  * @return true if this function was able to service the matmul request
  */
 bool llamafile_sgemm(long m, long n, long k, const void *A, long lda, const void *B, long ldb,
-                     void *C, long ldc, int ith, int nth, int Atype, int Btype, int Ctype,
-                     int precision) {
-    return funcs.sgemm(m, n, k, A, lda, B, ldb, C, ldc, ith, nth, Atype, Btype, Ctype, precision);
+                     void *C, long ldc, int ith, int nth, int Atype, int Btype, int Ctype) {
+    return funcs.sgemm(m, n, k, A, lda, B, ldb, C, ldc, ith, nth, Atype, Btype, Ctype);
 }
 
 /**

llamafile/sgemm.h

Lines changed: 10 additions & 10 deletions
@@ -21,30 +21,30 @@ bool iqk_mul_mat_moe_unsupported(long, long, long, int, int, const void *, const
                                  long, long, const void *, int, int);
 
 bool llamafile_sgemm(long, long, long, const void *, long, const void *, long, void *, long, int,
-                     int, int, int, int, int);
+                     int, int, int, int);
 bool llamafile_mixmul(const struct ggml_compute_params *, const struct ggml_tensor *,
                       const struct ggml_tensor *, const struct ggml_tensor *, struct ggml_tensor *);
 size_t llamafile_mixmul_needs(const struct ggml_tensor *, const struct ggml_tensor *,
                               const struct ggml_tensor *);
 
 bool llamafile_sgemm_unsupported(long, long, long, const void *, long, const void *, long, void *,
-                                 long, int, int, int, int, int, int);
+                                 long, int, int, int, int, int);
 bool llamafile_sgemm_amd_avx(long, long, long, const void *, long, const void *, long, void *, long,
-                             int, int, int, int, int, int);
+                             int, int, int, int, int);
 bool llamafile_sgemm_amd_fma(long, long, long, const void *, long, const void *, long, void *, long,
-                             int, int, int, int, int, int);
+                             int, int, int, int, int);
 bool llamafile_sgemm_amd_avx2(long, long, long, const void *, long, const void *, long, void *,
-                              long, int, int, int, int, int, int);
+                              long, int, int, int, int, int);
 bool llamafile_sgemm_amd_avxvnni(long, long, long, const void *, long, const void *, long, void *,
-                                 long, int, int, int, int, int, int);
+                                 long, int, int, int, int, int);
 bool llamafile_sgemm_amd_avx512f(long, long, long, const void *, long, const void *, long, void *,
-                                 long, int, int, int, int, int, int);
+                                 long, int, int, int, int, int);
 bool llamafile_sgemm_amd_zen4(long, long, long, const void *, long, const void *, long, void *,
-                              long, int, int, int, int, int, int);
+                              long, int, int, int, int, int);
 bool llamafile_sgemm_arm80(long, long, long, const void *, long, const void *, long, void *, long,
-                           int, int, int, int, int, int);
+                           int, int, int, int, int);
 bool llamafile_sgemm_arm82(long, long, long, const void *, long, const void *, long, void *, long,
-                           int, int, int, int, int, int);
+                           int, int, int, int, int);
 
 bool llamafile_mixmul_unsupported(const struct ggml_compute_params *, const struct ggml_tensor *,
                                   const struct ggml_tensor *, const struct ggml_tensor *,
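
With the precision parameter gone, a call against the new prototype matches the doc-comment example further down in this commit. The snippet below is a sketch rather than new API surface; ith/nth of 0 and 1 mean thread 0 of 1, i.e. single-threaded.

    // Single-threaded single-precision GEMM: the trailing int that used to
    // carry GGML_PREC_* has been dropped, so the argument list ends at Ctype.
    bool ok = llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc, 0, 1,
                              GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32);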

llamafile/sgemm_matmul_test.cpp

Lines changed: 3 additions & 5 deletions
@@ -32,13 +32,11 @@
 int cpu_get_num_math();
 
 void llamafile_sgemm_openmp(long m, long n, long k, const void *A, long lda, const void *B,
-                            long ldb, void *C, long ldc, int Atype, int Btype, int Ctype,
-                            int precision) {
+                            long ldb, void *C, long ldc, int Atype, int Btype, int Ctype) {
     static int nth = cpu_get_num_math();
 #pragma omp parallel for
     for (int ith = 0; ith < nth; ++ith) {
-        bool res = llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc, ith, nth, Atype, Btype, Ctype,
-                                   precision);
+        bool res = llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc, ith, nth, Atype, Btype, Ctype);
         assert(res);
     }
 }
@@ -63,7 +61,7 @@ int test(void) {
 
     BENCH(ansiBLAS::sgemm(m, n, k, A, lda, B, ldb, G, ldc));
     BENCH(llamafile_sgemm_openmp(m, n, k, A, lda, B, ldb, C, ldc, GGML_TYPE_F32, GGML_TYPE_F32,
-                                 GGML_TYPE_F32, GGML_PREC_DEFAULT));
+                                 GGML_TYPE_F32));
 
     int flips = 0;
     double err_sum = 0;

llamafile/sgemm_sss_test.cpp

Lines changed: 3 additions & 5 deletions
@@ -31,13 +31,11 @@
 int cpu_get_num_math();
 
 void llamafile_sgemm_openmp(long m, long n, long k, const void *A, long lda, const void *B,
-                            long ldb, void *C, long ldc, int Atype, int Btype, int Ctype,
-                            int precision) {
+                            long ldb, void *C, long ldc, int Atype, int Btype, int Ctype) {
     static int nth = cpu_get_num_math();
 #pragma omp parallel for
     for (int ith = 0; ith < nth; ++ith) {
-        bool res = llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc, ith, nth, Atype, Btype, Ctype,
-                                   precision);
+        bool res = llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc, ith, nth, Atype, Btype, Ctype);
         assert(res);
     }
 }
@@ -63,7 +61,7 @@ int test(void) {
 
     BENCH(ansiBLAS::sgemm(m, n, k, A, lda, B, ldb, G, ldc));
     BENCH(llamafile_sgemm_openmp(m, n, k, A, lda, B, ldb, C, ldc, GGML_TYPE_F32, GGML_TYPE_F32,
-                                 GGML_TYPE_F32, GGML_PREC_DEFAULT));
+                                 GGML_TYPE_F32));
 
     double err_sum = 0;
     long long err_worst = 0;

llamafile/sgemm_vecdot_test.cpp

Lines changed: 3 additions & 5 deletions
@@ -30,13 +30,11 @@
 int cpu_get_num_math();
 
 void llamafile_sgemm_openmp(long m, long n, long k, const void *A, long lda, const void *B,
-                            long ldb, void *C, long ldc, int Atype, int Btype, int Ctype,
-                            int precision) {
+                            long ldb, void *C, long ldc, int Atype, int Btype, int Ctype) {
     static int nth = cpu_get_num_math();
 #pragma omp parallel for
     for (int ith = 0; ith < nth; ++ith) {
-        bool res = llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc, ith, nth, Atype, Btype, Ctype,
-                                   precision);
+        bool res = llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc, ith, nth, Atype, Btype, Ctype);
         assert(res);
     }
 }
@@ -61,7 +59,7 @@ int test(void) {
 
     BENCH(ansiBLAS::sgemm(m, n, k, A, lda, B, ldb, G, ldc));
     BENCH(llamafile_sgemm_openmp(m, n, k, A, lda, B, ldb, C, ldc, GGML_TYPE_F32, GGML_TYPE_F32,
-                                 GGML_TYPE_F32, GGML_PREC_DEFAULT));
+                                 GGML_TYPE_F32));
 
     double err_sum = 0;
     long long err_worst = 0;

llamafile/tinyblas_cpu_mixmul.inc

Lines changed: 4 additions & 11 deletions
@@ -224,17 +224,10 @@ class MixMul {
         return false;
     }
 #elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
-        if (result->op_params[0] == GGML_PREC_F32) {
-            return mixmat<
-                4, 1,
-                tinyBLAS<NCB | NCC, 4, float32x4_t, float32x4_t, ggml_fp16_t, ggml_fp16_t, TC>,
-                ggml_fp16_t, ggml_fp16_t, TC>();
-        } else {
-            return mixmat<
-                8, 1,
-                tinyBLAS<NCB | NCC, 8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, TC>,
-                ggml_fp16_t, ggml_fp16_t, TC>();
-        }
+        return mixmat<
+            8, 1,
+            tinyBLAS<NCB | NCC, 8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, TC>,
+            ggml_fp16_t, ggml_fp16_t, TC>();
 #elif defined(__ARM_NEON) && !defined(_MSC_VER)
         return mixmat<
             4, 1,

llamafile/tinyblas_cpu_sgemm.inc

Lines changed: 14 additions & 29 deletions
@@ -34,10 +34,8 @@
 // have excellent performance[1] for matrices that fit in the CPU cache
 // without imposing any overhead such as cache filling or malloc calls.
 //
-// This implementation does not guarantee any upper bound with rounding
-// errors, which grow along with k. Our goal's to maximally exploit the
-// hardware for performance, and then use whatever resources remain for
-// improving numerical accuracy.
+// With the F32, F16, and BF16 data types, the accumulation of roundoff
+// errors will only grow logarithmically, thanks to the ruler function.
 //
 // [1] J. Tunney, ‘LLaMA Now Goes Faster on CPUs’, Mar. 2024. [Online].
 //     Available: https://justine.lol/matmul/. [Accessed: 29-Mar-2024].
@@ -46,8 +44,7 @@ namespace {
 
 template <typename TC>
 bool llamafile_sgemm_impl(long m, long n, long k, const void *A, long lda, const void *B, long ldb,
-                          TC *C, long ldc, int ith, int nth, int Atype, int Btype, int Ctype,
-                          int precision) {
+                          TC *C, long ldc, int ith, int nth, int Atype, int Btype, int Ctype) {
 
     switch (Atype) {
 
@@ -160,23 +157,14 @@ bool llamafile_sgemm_impl(long m, long n, long k, const void *A, long lda, const
         if (n < 2)
             // TODO(jart): Why is ggml_vec_dot_f16_unroll() so fast at matvec?
             return NOT_PROFITABLE;
-        if (precision == GGML_PREC_F32) {
-            if (Btype != GGML_TYPE_F32)
-                return NOT_SUPPORTED;
-            tinyBLAS<0, 4, float32x4_t, float32x4_t, ggml_fp16_t, float, TC> tb{
-                k, (const ggml_fp16_t *)A, lda, (const float *)B, ldb, C, ldc, ith, nth};
-            tb.matmul(m, n);
-            return true;
-        } else {
-            if (Btype == GGML_TYPE_F32)
-                return WANT_QUANTIZATION;
-            if (Btype != GGML_TYPE_F16)
-                return NOT_SUPPORTED;
-            tinyBLAS<0, 8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, TC> tb{
-                k, (const ggml_fp16_t *)A, lda, (const ggml_fp16_t *)B, ldb, C, ldc, ith, nth};
-            tb.matmul(m, n);
-            return true;
-        }
+        if (Btype == GGML_TYPE_F32)
+            return WANT_QUANTIZATION;
+        if (Btype != GGML_TYPE_F16)
+            return NOT_SUPPORTED;
+        tinyBLAS<0, 8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, TC> tb{
+            k, (const ggml_fp16_t *)A, lda, (const ggml_fp16_t *)B, ldb, C, ldc, ith, nth};
+        tb.matmul(m, n);
+        return true;
 #elif defined(__ARM_NEON) && !defined(_MSC_VER)
         if (n < 2 && !FLAG_precise)
             // TODO(jart): Why is ggml_vec_dot_f16_unroll() so fast at matvec?
@@ -249,7 +237,6 @@ bool llamafile_sgemm_impl(long m, long n, long k, const void *A, long lda, const
     (void)nth;
     (void)Atype;
     (void)Btype;
-    (void)precision;
 }
 
 } // namespace
@@ -265,8 +252,7 @@ bool llamafile_sgemm_impl(long m, long n, long k, const void *A, long lda, const
  * For example, for single-threaded single-precision GEMM you can say
  *
  *     llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc, 0, 1,
- *                     GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32,
- *                     GGML_PREC_DEFAULT);
+ *                     GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32);
 *
 * @param m is rows in `A` and `C`
 * @param n is cols in `B` and `C`
@@ -286,8 +272,7 @@ bool llamafile_sgemm_impl(long m, long n, long k, const void *A, long lda, const
 * @return true if this function was able to service the matmul request
 */
 bool llamafile_sgemm(long m, long n, long k, const void *A, long lda, const void *B, long ldb,
-                     void *C, long ldc, int ith, int nth, int Atype, int Btype, int Ctype,
-                     int precision) {
+                     void *C, long ldc, int ith, int nth, int Atype, int Btype, int Ctype) {
 
     assert(m >= 0);
     assert(n >= 0);
@@ -339,7 +324,7 @@ bool llamafile_sgemm(long m, long n, long k, const void *A, long lda, const void
     switch (Ctype) {
     case GGML_TYPE_F32:
         return llamafile_sgemm_impl(m, n, k, A, lda, B, ldb, (float *)C, ldc, ith, nth, Atype,
-                                    Btype, Ctype, precision);
+                                    Btype, Ctype);
     default:
         return NOT_SUPPORTED;
     }
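
Concretely, "favor fp16 arithmetic" here means the ARM path now always instantiates tinyBLAS with float16x8_t accumulators when __ARM_FEATURE_FP16_VECTOR_ARITHMETIC is available, rather than switching to a float32x4_t upcast path for GGML_PREC_F32. The fragment below is a hypothetical illustration of the two inner-loop FMA flavors, not tinyBLAS code, and the function names are made up.

    #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
    #include <arm_neon.h>

    // New behavior: products and sums stay in half precision (8 lanes).
    float16x8_t fma_fp16(float16x8_t acc, float16x8_t a, float16x8_t b) {
        return vfmaq_f16(acc, a, b);  // acc + a * b, all in fp16
    }

    // Old GGML_PREC_F32 flavor: upcast to fp32 before accumulating (4 lanes).
    float32x4_t fma_fp32(float32x4_t acc, float16x4_t a, float16x4_t b) {
        return vfmaq_f32(acc, vcvt_f32_f16(a), vcvt_f32_f16(b));
    }
    #endif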

llamafile/tinyblas_cpu_unsupported.cpp

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@
 
 bool llamafile_sgemm_unsupported(long m, long n, long k, const void *A, long lda, const void *B,
                                  long ldb, void *C, long ldc, int ith, int nth, int Atype,
-                                 int Btype, int Ctype, int precision) {
+                                 int Btype, int Ctype) {
     return false;
 }
 
