
Commit fa4c4e7

Use bf16 kv cache when it's faster
1 parent 98eff09 commit fa4c4e7

4 files changed: +13 −6 lines

llama.cpp/common.cpp

Lines changed: 3 additions & 0 deletions

@@ -2159,6 +2159,9 @@ static ggml_type kv_cache_type_from_str(const std::string & s) {
     if (s == "f32") {
         return GGML_TYPE_F32;
     }
+    if (s == "bf16") {
+        return GGML_TYPE_BF16;
+    }
     if (s == "f16") {
         return GGML_TYPE_F16;
     }
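For context, kv_cache_type_from_str() maps the user-supplied cache_type_k / cache_type_v strings from gpt_params to a ggml_type, and this hunk teaches it to accept "bf16". Below is a minimal standalone sketch of that parsing pattern, not the project's code: CacheType is a hypothetical stand-in for ggml_type so the example compiles on its own, and the error message is illustrative.

// Standalone sketch of the string -> cache-type mapping shown above.
#include <stdexcept>
#include <string>

enum class CacheType { F32, BF16, F16, Q8_0 };

static CacheType cache_type_from_str(const std::string &s) {
    if (s == "f32")  return CacheType::F32;
    if (s == "bf16") return CacheType::BF16;  // accepted after this change
    if (s == "f16")  return CacheType::F16;
    if (s == "q8_0") return CacheType::Q8_0;
    throw std::runtime_error("unsupported KV cache type: " + s);
}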

llama.cpp/common.h

Lines changed: 3 additions & 2 deletions

@@ -22,6 +22,7 @@
 #include <thread>
 #include <unordered_map>
 #include <tuple>
+#include <cosmo.h>
 
 #ifdef _WIN32
 #define DIRECTORY_SEPARATOR '\\'
@@ -192,8 +193,8 @@ struct gpt_params {
     bool warmup = true; // warmup run
     bool check_tensors = false; // validate tensor data
 
-    std::string cache_type_k = "f16"; // KV cache data type for the K
-    std::string cache_type_v = "f16"; // KV cache data type for the V
+    std::string cache_type_k = X86_HAVE(AVX512_BF16) ? "bf16" : "f16"; // KV cache data type for the K [jart]
+    std::string cache_type_v = X86_HAVE(AVX512_BF16) ? "bf16" : "f16"; // KV cache data type for the V [jart]
 
     // multimodal models (see examples/llava)
     std::string mmproj = ""; // path to multimodal projector
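The new defaults lean on cosmo.h's X86_HAVE() macro to probe the CPU at runtime, so machines without AVX512_BF16 keep the plain f16 cache. As a hedged sketch only, an equivalent probe without Cosmopolitan could use GCC/Clang's <cpuid.h>; AVX512_BF16 is reported in CPUID leaf 7, subleaf 1, EAX bit 5, and both function names below are illustrative.

// Illustrative runtime probe roughly equivalent to X86_HAVE(AVX512_BF16).
#include <cpuid.h>
#include <string>

static bool cpu_has_avx512_bf16() {
    unsigned eax = 0, ebx = 0, ecx = 0, edx = 0;
    if (!__get_cpuid_count(7, 1, &eax, &ebx, &ecx, &edx))
        return false;                 // leaf 7 / subleaf 1 not available
    return (eax >> 5) & 1;            // EAX bit 5 = AVX512_BF16
}

static std::string default_kv_cache_type() {
    return cpu_has_avx512_bf16() ? "bf16" : "f16";
}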

llama.cpp/llama.cpp

Lines changed: 4 additions & 1 deletion

@@ -16766,7 +16766,10 @@ struct llama_context * llama_new_context_with_model(
         params.flash_attn = false;
     }
 
-    if (params.type_v != GGML_TYPE_F16 && !params.flash_attn) {
+    // [jart] allow bf16
+    if (params.type_v != GGML_TYPE_F16 &&
+        params.type_v != GGML_TYPE_BF16 &&
+        !params.flash_attn) {
         LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__);
         return nullptr;
     }
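In other words, only a genuinely quantized V cache still requires flash attention; f16 and, after this hunk, bf16 pass the check unchanged. A small restatement of the relaxed guard, factored into a helper whose name is mine rather than the project's, with ggml_type assumed from ggml.h:

// Sketch of the guard above: non-f16/bf16 (i.e. quantized) V caches
// still need flash attention.
#include "ggml.h"

static bool v_cache_requires_flash_attn(ggml_type type_v) {
    return type_v != GGML_TYPE_F16 && type_v != GGML_TYPE_BF16;
}

// Usage mirroring the diff:
//   if (v_cache_requires_flash_attn(params.type_v) && !params.flash_attn) {
//       LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__);
//       return nullptr;
//   }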

llamafile/tinyblas_cpu_sgemm.inc

Lines changed: 3 additions & 3 deletions

@@ -73,7 +73,7 @@ bool llamafile_sgemm_impl(long m, long n, long k, const void *A, long lda, const
 
     case GGML_TYPE_BF16: {
 #if defined(__AVX512BF16__)
-        if (Btype == GGML_TYPE_F32 && n < 2) {
+        if (Btype == GGML_TYPE_F32 && n <= 2) {
             tinyBLAS<0, 16, __m512, __m512, ggml_bf16_t, float, TC> tb{
                 k, (const ggml_bf16_t *)A, lda, (const float *)B, ldb, C, ldc, ith, nth};
             tb.matmul(m, n);
@@ -120,7 +120,7 @@ bool llamafile_sgemm_impl(long m, long n, long k, const void *A, long lda, const
 
     case GGML_TYPE_F16: {
 #if defined(__AVX512F__)
-        if (Btype == GGML_TYPE_F32 && n < 2) {
+        if (Btype == GGML_TYPE_F32 && n <= 2) {
             tinyBLAS<0, 16, __m512, __m512, ggml_fp16_t, float, TC> tb{
                 k, (const ggml_fp16_t *)A, lda, (const float *)B, ldb, C, ldc, ith, nth};
             tb.matmul(m, n);
@@ -136,7 +136,7 @@ bool llamafile_sgemm_impl(long m, long n, long k, const void *A, long lda, const
         return true;
 #elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__)
     if (X86_CHECK(F16C)) {
-        if (Btype == GGML_TYPE_F32 && n < 2) {
+        if (Btype == GGML_TYPE_F32 && n <= 2) {
             tinyBLAS<0, 8, __m256, __m256, ggml_fp16_t, float, TC> tb{
                 k, (const ggml_fp16_t *)A, lda, (const float *)B, ldb, C, ldc, ith, nth};
             tb.matmul(m, n);
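These three hunks widen tinyBLAS's mixed-precision path (bf16/f16 A with an f32 B operand) from n < 2, i.e. only the matrix-vector case, to n <= 2, so two-column f32 B operands also take these kernels instead of a fallback; this is the skinny shape the bf16 KV cache tends to hit during decoding. A hedged restatement of the changed dispatch condition, where use_skinny_mixed_path() is an illustrative name and the type constants are assumed from ggml.h:

// Illustrative restatement of the condition changed above; not project code.
#include "ggml.h"

static bool use_skinny_mixed_path(long n, ggml_type Atype, ggml_type Btype) {
    const bool mixed_precision =
        (Atype == GGML_TYPE_BF16 || Atype == GGML_TYPE_F16) &&
        Btype == GGML_TYPE_F32;
    return mixed_precision && n <= 2;  // previously n < 2
}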
