From 0f301124b11da9bbd93a8655b7ad468851346d1c Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Tue, 27 Aug 2024 19:08:31 +0300 Subject: [PATCH 01/13] WIP: play with KQ mask - make it fp16 --- src/llama.cpp | 146 ++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 106 insertions(+), 40 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 8a85144ef..76aa3fb8a 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -8687,25 +8687,40 @@ struct llm_build_context { } struct ggml_tensor * build_inp_KQ_mask(bool causal = true) { + auto type = hparams.use_alibi ? GGML_TYPE_F32 : GGML_TYPE_F16; lctx.inp_KQ_mask = causal - ? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)) - : ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); + ? ggml_new_tensor_2d(ctx0, type, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)) + : ggml_new_tensor_2d(ctx0, type, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); cb(lctx.inp_KQ_mask, "KQ_mask", -1); ggml_set_input(lctx.inp_KQ_mask); + return flash_attn && type == GGML_TYPE_F32 ? ggml_cast(ctx0, lctx.inp_KQ_mask, GGML_TYPE_F16) : lctx.inp_KQ_mask; + //lctx.inp_KQ_mask = causal + // ? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)) + // : ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); + //cb(lctx.inp_KQ_mask, "KQ_mask", -1); + //ggml_set_input(lctx.inp_KQ_mask); - return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_mask, GGML_TYPE_F16) : lctx.inp_KQ_mask; + //return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_mask, GGML_TYPE_F16) : lctx.inp_KQ_mask; } struct ggml_tensor * build_inp_KQ_mask_swa(bool causal = true) { GGML_ASSERT(hparams.n_swa > 0); + auto type = hparams.use_alibi ? GGML_TYPE_F32 : GGML_TYPE_F16; lctx.inp_KQ_mask_swa = causal - ? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)) - : ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); + ? ggml_new_tensor_2d(ctx0, type, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)) + : ggml_new_tensor_2d(ctx0, type, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); cb(lctx.inp_KQ_mask_swa, "KQ_mask_swa", -1); ggml_set_input(lctx.inp_KQ_mask_swa); - return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_mask_swa, GGML_TYPE_F16) : lctx.inp_KQ_mask_swa; + return flash_attn && type == GGML_TYPE_F32 ? ggml_cast(ctx0, lctx.inp_KQ_mask_swa, GGML_TYPE_F16) : lctx.inp_KQ_mask_swa; + //lctx.inp_KQ_mask_swa = causal + // ? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)) + // : ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); + //cb(lctx.inp_KQ_mask_swa, "KQ_mask_swa", -1); + //ggml_set_input(lctx.inp_KQ_mask_swa); + + //return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_mask_swa, GGML_TYPE_F16) : lctx.inp_KQ_mask_swa; } struct ggml_tensor * build_inp_mean() { @@ -14259,71 +14274,122 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { const int64_t n_kv = kv_self.n; const int64_t n_tokens = batch.n_tokens; - float * data = nullptr; float * data_swa = nullptr; + if (lctx.inp_KQ_mask && lctx.inp_KQ_mask_swa) { + GGML_ASSERT(lctx.inp_KQ_mask->type == lctx.inp_KQ_mask_swa->type); + } + if (lctx.inp_KQ_mask) { GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer)); - data = (float *) lctx.inp_KQ_mask->data; } - if (lctx.inp_KQ_mask_swa) { GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask_swa->buffer)); - data_swa = (float *) lctx.inp_KQ_mask_swa->data; } - // For causal attention, use only the previous KV cells - // of the correct sequence for each token of the batch. - // It's assumed that if a token in the batch has multiple sequences, they are equivalent. - for (int h = 0; h < 1; ++h) { + auto float_type = lctx.inp_KQ_mask ? lctx.inp_KQ_mask->type : lctx.inp_KQ_mask_swa->type; + GGML_ASSERT(float_type == GGML_TYPE_F16 || float_type == GGML_TYPE_F32); + + if (float_type == GGML_TYPE_F16) { + // in order this to be true, we are not using alibi + GGML_ASSERT(!hparams.use_alibi); + auto h_zero = ggml_fp32_to_fp16(0.0f); + auto h_inf = ggml_fp32_to_fp16(-INFINITY); + ggml_fp16_t * h_data = lctx.inp_KQ_mask ? (ggml_fp16_t *)lctx.inp_KQ_mask->data : nullptr; + ggml_fp16_t * h_data_swa = lctx.inp_KQ_mask_swa ? (ggml_fp16_t *)lctx.inp_KQ_mask_swa->data : nullptr; for (int j = 0; j < n_tokens; ++j) { const llama_pos pos = batch.pos[j]; const llama_seq_id seq_id = batch.seq_id[j][0]; for (int i = 0; i < n_kv; ++i) { - float f; - if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) { - f = -INFINITY; - } else { - if (hparams.use_alibi) { - f = -std::abs(lctx.kv_self.cells[i].pos - pos); - } else { - f = 0.0f; - } + auto f = lctx.kv_self.cells[i].pos <= pos && lctx.kv_self.cells[i].has_seq_id(seq_id) ? h_zero : h_inf; + if (h_data) h_data[j*n_kv + i] = f; + if (h_data_swa) { + if (pos - lctx.kv_self.cells[i].pos >= (int32_t)hparams.n_swa) f = h_inf; + h_data_swa[j*n_kv + i] = f; } + } + } - if (data) { - data[h*(n_kv*n_tokens) + j*n_kv + i] = f; - } + if (h_data) { + for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { + for (int j = 0; j < n_kv; ++j) h_data[i*n_kv + j] = h_inf; + } + } + + if (h_data_swa) { + for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { + for (int j = 0; j < n_kv; ++j) h_data_swa[i*n_kv + j] = h_inf; + } + } + } + + else { - // may need to cut off old tokens for sliding window - if (data_swa) { - if (pos - lctx.kv_self.cells[i].pos >= (int32_t)hparams.n_swa) { + if (lctx.inp_KQ_mask) { + data = (float *) lctx.inp_KQ_mask->data; + } + + if (lctx.inp_KQ_mask_swa) { + data_swa = (float *) lctx.inp_KQ_mask_swa->data; + } + + // For causal attention, use only the previous KV cells + // of the correct sequence for each token of the batch. + // It's assumed that if a token in the batch has multiple sequences, they are equivalent. + for (int h = 0; h < 1; ++h) { + for (int j = 0; j < n_tokens; ++j) { + const llama_pos pos = batch.pos[j]; + const llama_seq_id seq_id = batch.seq_id[j][0]; + + for (int i = 0; i < n_kv; ++i) { + float f; + if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) { f = -INFINITY; + } else { + if (hparams.use_alibi) { + f = -std::abs(lctx.kv_self.cells[i].pos - pos); + } else { + f = 0.0f; + } + } + + if (data) { + data[h*(n_kv*n_tokens) + j*n_kv + i] = f; + } + + // may need to cut off old tokens for sliding window + if (data_swa) { + if (pos - lctx.kv_self.cells[i].pos >= (int32_t)hparams.n_swa) { + f = -INFINITY; + } + data_swa[h*(n_kv*n_tokens) + j*n_kv + i] = f; } - data_swa[h*(n_kv*n_tokens) + j*n_kv + i] = f; } } - } - if (data) { - for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { - for (int j = 0; j < n_kv; ++j) { - data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY; + if (data) { + for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { + for (int j = 0; j < n_kv; ++j) { + data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY; + } } } - } - if (data_swa) { - for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { - for (int j = 0; j < n_kv; ++j) { - data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY; + if (data_swa) { + for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { + for (int j = 0; j < n_kv; ++j) { + data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY; + } } } } } } else { + // TODO + GGML_ASSERT(false); + // when using kv cache, the mask needs to match the kv cache size const int64_t n_tokens = batch.n_tokens; const int64_t n_stride = hparams.causal_attn && !lctx.is_encoding ? kv_self.n : n_tokens; From 16b8d3d229446dd52b6e48bf77bb47035aebfb43 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Tue, 27 Aug 2024 19:45:57 +0300 Subject: [PATCH 02/13] WIP --- ggml/src/ggml.c | 54 +++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 46 insertions(+), 8 deletions(-) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index cebac5845..adff76699 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -2043,6 +2043,38 @@ inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x) inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]*y[i]; } inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]/y[i]; } +static inline float ggml_vec_add_f32_f16(const int n, const ggml_half * x, float * y, float slope) { + __m512 vslope = _mm512_set1_ps(slope); + __m512 vmax = _mm512_set1_ps(-INFINITY); + for (int j = 0; j < n/16; ++j) { + __m512 v = _mm512_fmadd_ps(vslope, _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)x + j)), _mm512_loadu_ps(y + 16*j)); + _mm512_storeu_ps(y + 16*j, v); + vmax = _mm512_max_ps(vmax, v); + } + float max = _mm512_reduce_max_ps(vmax); + for (int i = 16*(n/16); i < n; ++i) { + y[i] += slope*GGML_FP16_TO_FP32(x[i]); + max = MAX(max, y[i]); + } + return max; +} + +static inline float ggml_vec_add_f32_f32(const int n, const float * x, float * y, float slope) { + __m512 vslope = _mm512_set1_ps(slope); + __m512 vmax = _mm512_set1_ps(-INFINITY); + for (int j = 0; j < n/16; ++j) { + __m512 v = _mm512_fmadd_ps(vslope, _mm512_loadu_ps(x + 16*j), _mm512_loadu_ps(y + 16*j)); + _mm512_storeu_ps(y + 16*j, v); + vmax = _mm512_max_ps(vmax, v); + } + float max = _mm512_reduce_max_ps(vmax); + for (int i = 16*(n/16); i < n; ++i) { + y[i] += slope*x[i]; + max = MAX(max, y[i]); + } + return max; +} + static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); @@ -13782,17 +13814,23 @@ static void ggml_compute_forward_softcap_max_f32( ggml_vec_cpy_softcap_f32(nc, sp, wp, values[2], values[0]*values[3]); + float max = -INFINITY; if (mp_f32) { if (use_f16) { - for (int i = 0; i < nc; ++i) { - wp[i] += slope*GGML_FP16_TO_FP32(mp_f16[i]); - } + max = ggml_vec_add_f32_f16(nc, mp_f16, wp, slope); + //for (int i = 0; i < nc; ++i) { + // wp[i] += slope*GGML_FP16_TO_FP32(mp_f16[i]); + //} } else { - for (int i = 0; i < nc; ++i) { - wp[i] += slope*mp_f32[i]; - } + max = ggml_vec_add_f32_f32(nc, mp_f32, wp, slope); + //for (int i = 0; i < nc; ++i) { + // wp[i] += slope*mp_f32[i]; + //} } } + else { + ggml_vec_max_f32(nc, &max, wp); + } #ifndef NDEBUG for (int i = 0; i < nc; ++i) { @@ -13801,8 +13839,8 @@ static void ggml_compute_forward_softcap_max_f32( } #endif - float max = -INFINITY; - ggml_vec_max_f32(nc, &max, wp); + //float max = -INFINITY; + //ggml_vec_max_f32(nc, &max, wp); ggml_float sum = ggml_vec_soft_max_f32(nc, dp, wp, max); assert(sum > 0.0); From 511c4592320270bd04babab9642e5bc7af96326b Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Wed, 28 Aug 2024 09:08:49 +0300 Subject: [PATCH 03/13] WIP: play with KQ mask - make it binary Here we get a small speedup: Gemma-2-2b and 32k context is ~4% faster on Zen4. But on Zen4 we can use _mm512_mask_mul_ps(-inifnity, mask, s_after, tanh(x*s_before)) to scale and apply mask in a single op that has the same latency and throughput as _mm512_mul_ps. Combined with reducing memory loads for the mask represented as fp32 (or fp16), this gives us some performance improvement for very large masks (contexts). It will be much more tricky on the other platforms that do not have masked instructions. --- ggml/src/ggml.c | 84 +++++++++++++++++++++++++++++++++++++------------ src/llama.cpp | 64 +++++++++++++++++++------------------ 2 files changed, 98 insertions(+), 50 deletions(-) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index adff76699..755d34cec 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -2075,6 +2075,19 @@ static inline float ggml_vec_add_f32_f32(const int n, const float * x, float * y return max; } +static inline float ggml_vec_add_f32_infmask(const int n, const uint32_t * x, float * y) { + GGML_ASSERT(n%16 == 0); + __m512 vmax = _mm512_set1_ps(-INFINITY); + __m512 vinf = _mm512_set1_ps(-INFINITY); + const __mmask16 * mm16 = (const __mmask16 *)x; + for (int j = 0; j < n/16; ++j) { + __m512 v = _mm512_mask_blend_ps(mm16[j], _mm512_loadu_ps(y + 16*j), vinf); + _mm512_storeu_ps(y + 16*j, v); + vmax = _mm512_max_ps(vmax, v); + } + return _mm512_reduce_max_ps(vmax); +} + static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); @@ -2646,6 +2659,13 @@ inline static __m512 ggml_v_softcap(__m512 x, __m512 s_before, __m512 s_after) { return _mm512_mul_ps(th, s_after); } +inline static __m512 ggml_v_softcap_mask(__m512 x, __m512 s_before, __m512 s_after, __m512 src, __mmask16 mask) { + const __m512 one = _mm512_set1_ps(1.0f); + const __m512 exp_two_x = ggml_v_expf(_mm512_mul_ps(x, s_before)); + const __m512 th = _mm512_div_ps(_mm512_sub_ps(exp_two_x, one), _mm512_add_ps(exp_two_x, one)); + return _mm512_mask_mul_ps(src, mask, th, s_after); +} + inline static __m512 ggml_v_gelu(__m512 x, __m512 c1, __m512 c2) { const __m512 one = _mm512_set1_ps(1.0f); __m512 arg = _mm512_fmadd_ps(x, _mm512_mul_ps(c1, x), one); @@ -2883,6 +2903,20 @@ static void ggml_vec_cpy_softcap_f32(const int n, const float * x, float * y, fl } } +static float ggml_vec_cpy_softcap_mask_f32(const int n, const float * x, float * y, const uint32_t * mask, float s_before, float s_after) { + const __mmask16 * m16 = (const __mmask16 *)mask; + __m512 vinf = _mm512_set1_ps(-INFINITY); + __m512 vmax = vinf; + __m512 vs_before = _mm512_set1_ps(2.f*s_before); + __m512 vs_after = _mm512_set1_ps(s_after); + for (int i = 0; i < n/16; ++i) { + __m512 v = ggml_v_softcap_mask(_mm512_loadu_ps(x + 16*i), vs_before, vs_after, vinf, m16[i]); + _mm512_storeu_ps(y + 16*i, v); + vmax = _mm512_max_ps(vmax, v); + } + return _mm512_reduce_max_ps(vmax); +} + static void ggml_vec_softcap_f32(const int n, float * x, float s_before, float s_after) { int i = 0; #if defined(__AVX512F__) && defined(__AVX512DQ__) @@ -6045,10 +6079,10 @@ static struct ggml_tensor * ggml_softcap_max_impl( GGML_ASSERT(ggml_is_padded_1d(a)); if (mask) { - GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32); + GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32 || mask->type == GGML_TYPE_I32); GGML_ASSERT(ggml_is_contiguous(mask)); GGML_ASSERT(ggml_is_matrix(mask)); - GGML_ASSERT(mask->ne[0] == a->ne[0]); + //GGML_ASSERT(mask->ne[0] == a->ne[0]); GGML_ASSERT(mask->ne[1] >= a->ne[1]); } @@ -13799,6 +13833,7 @@ static void ggml_compute_forward_softcap_max_f32( float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith; const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); + const bool use_i32 = (src1 && src1->type == GGML_TYPE_I32); for (int i1 = ir0; i1 < ir1; i1++) { // ALiBi @@ -13809,27 +13844,36 @@ static void ggml_compute_forward_softcap_max_f32( float * dp = (float *)((char *) dst->data + i1*dst->nb[1]); // broadcast the mask across rows - ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL; - float * mp_f32 = src1 ? (float *)((char *) src1->data) + (i1%ne01)*ne00 : NULL; - - ggml_vec_cpy_softcap_f32(nc, sp, wp, values[2], values[0]*values[3]); + const int mask_row = i1%ne01; float max = -INFINITY; - if (mp_f32) { - if (use_f16) { - max = ggml_vec_add_f32_f16(nc, mp_f16, wp, slope); - //for (int i = 0; i < nc; ++i) { - // wp[i] += slope*GGML_FP16_TO_FP32(mp_f16[i]); - //} - } else { - max = ggml_vec_add_f32_f32(nc, mp_f32, wp, slope); - //for (int i = 0; i < nc; ++i) { - // wp[i] += slope*mp_f32[i]; - //} + if (use_i32) { + int n32 = (ne00 + 31)/32; + const uint32_t * mp_u32 = (const uint32_t *)src1->data + mask_row*n32; + max = ggml_vec_cpy_softcap_mask_f32(nc, sp, wp, mp_u32, values[2], values[0]*values[3]); + } else { + + ggml_vec_cpy_softcap_f32(nc, sp, wp, values[2], values[0]*values[3]); + + if (src1) { + if (use_f16) { + ggml_fp16_t * mp_f16 = (ggml_fp16_t *)((char *) src1->data) + mask_row*ne00; + max = ggml_vec_add_f32_f16(nc, mp_f16, wp, slope); + } else if (use_i32) { + int n32 = (ne00 + 31)/32; + const uint32_t * mp_u32 = (const uint32_t *)src1->data + mask_row*n32; + max = ggml_vec_add_f32_infmask(nc, mp_u32, wp); + } else { + float * mp_f32 = (float *)((char *) src1->data) + mask_row*ne00; + max = ggml_vec_add_f32_f32(nc, mp_f32, wp, slope); + //for (int i = 0; i < nc; ++i) { + // wp[i] += slope*mp_f32[i]; + //} + } + } + else { + ggml_vec_max_f32(nc, &max, wp); } - } - else { - ggml_vec_max_f32(nc, &max, wp); } #ifndef NDEBUG diff --git a/src/llama.cpp b/src/llama.cpp index 76aa3fb8a..51b5bbaf2 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -8687,10 +8687,10 @@ struct llm_build_context { } struct ggml_tensor * build_inp_KQ_mask(bool causal = true) { - auto type = hparams.use_alibi ? GGML_TYPE_F32 : GGML_TYPE_F16; - lctx.inp_KQ_mask = causal - ? ggml_new_tensor_2d(ctx0, type, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)) - : ggml_new_tensor_2d(ctx0, type, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); + auto type = hparams.use_alibi ? GGML_TYPE_F32 : GGML_TYPE_I32; + auto nx = causal ? n_kv : n_tokens; + if (type == GGML_TYPE_I32) nx = (nx + 31)/32; + lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, type, nx, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); cb(lctx.inp_KQ_mask, "KQ_mask", -1); ggml_set_input(lctx.inp_KQ_mask); return flash_attn && type == GGML_TYPE_F32 ? ggml_cast(ctx0, lctx.inp_KQ_mask, GGML_TYPE_F16) : lctx.inp_KQ_mask; @@ -8705,11 +8705,10 @@ struct llm_build_context { struct ggml_tensor * build_inp_KQ_mask_swa(bool causal = true) { GGML_ASSERT(hparams.n_swa > 0); - - auto type = hparams.use_alibi ? GGML_TYPE_F32 : GGML_TYPE_F16; - lctx.inp_KQ_mask_swa = causal - ? ggml_new_tensor_2d(ctx0, type, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)) - : ggml_new_tensor_2d(ctx0, type, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); + auto type = hparams.use_alibi ? GGML_TYPE_F32 : GGML_TYPE_I32; + auto nx = causal ? n_kv : n_tokens; + if (type == GGML_TYPE_I32) nx = (nx + 31)/32; + lctx.inp_KQ_mask_swa = ggml_new_tensor_2d(ctx0, type, nx, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); cb(lctx.inp_KQ_mask_swa, "KQ_mask_swa", -1); ggml_set_input(lctx.inp_KQ_mask_swa); @@ -14288,40 +14287,45 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask_swa->buffer)); } - auto float_type = lctx.inp_KQ_mask ? lctx.inp_KQ_mask->type : lctx.inp_KQ_mask_swa->type; - GGML_ASSERT(float_type == GGML_TYPE_F16 || float_type == GGML_TYPE_F32); + auto mask_type = lctx.inp_KQ_mask ? lctx.inp_KQ_mask->type : lctx.inp_KQ_mask_swa->type; + GGML_ASSERT(mask_type == GGML_TYPE_I32 || mask_type == GGML_TYPE_F32); - if (float_type == GGML_TYPE_F16) { + if (mask_type == GGML_TYPE_I32) { // in order this to be true, we are not using alibi GGML_ASSERT(!hparams.use_alibi); - auto h_zero = ggml_fp32_to_fp16(0.0f); - auto h_inf = ggml_fp32_to_fp16(-INFINITY); - ggml_fp16_t * h_data = lctx.inp_KQ_mask ? (ggml_fp16_t *)lctx.inp_KQ_mask->data : nullptr; - ggml_fp16_t * h_data_swa = lctx.inp_KQ_mask_swa ? (ggml_fp16_t *)lctx.inp_KQ_mask_swa->data : nullptr; + uint32_t * h_data = lctx.inp_KQ_mask ? (uint32_t *)lctx.inp_KQ_mask->data : nullptr; + uint32_t * h_data_swa = lctx.inp_KQ_mask_swa ? (uint32_t *)lctx.inp_KQ_mask_swa->data : nullptr; for (int j = 0; j < n_tokens; ++j) { const llama_pos pos = batch.pos[j]; const llama_seq_id seq_id = batch.seq_id[j][0]; + uint32_t u = 0, u_swa = 0; + uint32_t m = 1; + for (int i = 0; i < n_kv; ++i) { - auto f = lctx.kv_self.cells[i].pos <= pos && lctx.kv_self.cells[i].has_seq_id(seq_id) ? h_zero : h_inf; - if (h_data) h_data[j*n_kv + i] = f; - if (h_data_swa) { - if (pos - lctx.kv_self.cells[i].pos >= (int32_t)hparams.n_swa) f = h_inf; - h_data_swa[j*n_kv + i] = f; + if (lctx.kv_self.cells[i].pos > pos || !lctx.kv_self.cells[i].has_seq_id(seq_id)) { + u |= m; u_swa |= m; + } + if (pos - lctx.kv_self.cells[i].pos >= (int32_t)hparams.n_swa) u_swa |= m; + m <<= 1; + if (!m) { + if (h_data) *h_data++ = ~u; + if (h_data_swa) *h_data_swa++ = ~u_swa; + u = u_swa = 0; m = 1; } } - } - - if (h_data) { - for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { - for (int j = 0; j < n_kv; ++j) h_data[i*n_kv + j] = h_inf; + if (m > 1) { + if (h_data) *h_data++ = ~u; + if (h_data_swa) *h_data_swa++ = ~u_swa; } + } - if (h_data_swa) { - for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { - for (int j = 0; j < n_kv; ++j) h_data_swa[i*n_kv + j] = h_inf; - } + auto n_pad = GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); + if (n_pad > n_tokens) { + auto n_kv_32 = (n_kv + 31)/32; + if (h_data) std::memset(h_data, 0, (n_pad - n_tokens)*n_kv_32*sizeof(uint32_t)); + if (h_data_swa) std::memset(h_data_swa, 0, (n_pad - n_tokens)*n_kv_32*sizeof(uint32_t)); } } From 1216a437194b3ee1d278b3203b06b183ed44da36 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Wed, 28 Aug 2024 10:03:10 +0300 Subject: [PATCH 04/13] WIP KQ binary mask: CUDA Relatively painless to implement for soft_max and soft_cap_max. We gain 11.5% for LLaMA-8B and ~14% for Gemma-2-2b at 32k tokens. The KQ mask is prepared on the CPU and copied to the GPU, so my guess is that most of it comes from the 32X reduction in the amount of data being copied to the GPU. TODO: flash attention --- ggml/src/ggml-cuda/softmax.cu | 40 ++++++++++++++++++++++------------- ggml/src/ggml.c | 8 +++++-- 2 files changed, 31 insertions(+), 17 deletions(-) diff --git a/ggml/src/ggml-cuda/softmax.cu b/ggml/src/ggml-cuda/softmax.cu index 6f3056e6d..e4a31fa26 100644 --- a/ggml/src/ggml-cuda/softmax.cu +++ b/ggml/src/ggml-cuda/softmax.cu @@ -2,13 +2,18 @@ #include "softmax.cuh" template -static __device__ __forceinline__ float t2f32(T val) { - return (float) val; +static __device__ __forceinline__ float mask_value(float slope, const T * mask, int iy) { + return mask ? slope * (float)mask[iy] : 0.0f; } template <> -__device__ float __forceinline__ t2f32(half val) { - return __half2float(val); +__device__ __forceinline__ float mask_value(float slope, const half * mask, int iy) { + return mask ? slope * __half2float(mask[iy]) : 0.0f; +} + +template <> +__device__ __forceinline__ float mask_value(float, const uint32_t * mask, int iy) { + return mask[iy >> 5] & (1u << (iy & 31)) ? 0.0f : -INFINITY; } template @@ -44,8 +49,8 @@ static __global__ void soft_max_f32(const float * x, const T * mask, float * dst const int64_t ix = (int64_t)rowx*ncols + col; const int64_t iy = (int64_t)rowy*ncols + col; - const float val = do_softcap ? scale*cap_params1*tanhf(cap_params0*x[ix]) + (mask ? slope*t2f32(mask[iy]) : 0.0f) : - scale*x[ix] + (mask ? slope*t2f32(mask[iy]) : 0.0f); + const float val = do_softcap ? scale*cap_params1*tanhf(cap_params0*x[ix]) + mask_value(slope, mask, iy) : + scale*x[ix] + mask_value(slope, mask, iy); vals[col] = val; max_val = max(max_val, val); @@ -181,7 +186,7 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); - GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional + GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_I32); // src1 contains mask and it is optional const int64_t ne00 = src0->ne[0]; const int64_t nrows_x = ggml_nrows(src0); @@ -194,14 +199,17 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); + const bool use_i32 = (src1 && src1->type == GGML_TYPE_I32); - if (use_f16) { + if (use_i32) { + const uint32_t * mask = (const uint32_t *)src1_d; + soft_max_f32_cuda(src0_d, mask, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, 0, 0, false, stream); + } + else if (use_f16) { const half * src1_dd = (const half *)src1_d; - soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, 0, 0, false, stream); } else { const float * src1_dd = (const float *)src1_d; - soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, 0, 0, false, stream); } } @@ -219,7 +227,7 @@ void ggml_cuda_op_soft_cap_max(ggml_backend_cuda_context & ctx, ggml_tensor * ds GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); - GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional + GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_I32); // src1 contains mask and it is optional const int64_t ne00 = src0->ne[0]; const int64_t nrows_x = ggml_nrows(src0); @@ -229,15 +237,17 @@ void ggml_cuda_op_soft_cap_max(ggml_backend_cuda_context & ctx, ggml_tensor * ds memcpy(params, dst->op_params, sizeof(params)); const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); - //printf("%s: %g, %g, %g, %g, %p, %d\n", __func__, params[0], params[1], params[2], params[3], (const void *)src1, use_f16); + const bool use_i32 = (src1 && src1->type == GGML_TYPE_I32); - if (use_f16) { + if (use_i32) { + const uint32_t * mask = (const uint32_t *)src1_d; + soft_max_f32_cuda(src0_d, mask, dst_d, ne00, nrows_x, nrows_y, params[0], params[1], params[2], params[3], true, stream); + } + else if (use_f16) { const half * src1_dd = (const half *)src1_d; - soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, params[0], params[1], params[2], params[3], true, stream); } else { const float * src1_dd = (const float *)src1_d; - soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, params[0], params[1], params[2], params[3], true, stream); } } diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 755d34cec..2c740f9d0 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -6850,10 +6850,14 @@ static struct ggml_tensor * ggml_soft_max_impl( GGML_ASSERT(ggml_is_contiguous(a)); if (mask) { - GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32); + GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32 || mask->type == GGML_TYPE_I32); GGML_ASSERT(ggml_is_contiguous(mask)); GGML_ASSERT(ggml_is_matrix(mask)); - GGML_ASSERT(mask->ne[0] == a->ne[0]); + if (mask->type == GGML_TYPE_I32) { + GGML_ASSERT(mask->ne[0] == (a->ne[0] + 31)/32); + } else { + GGML_ASSERT(mask->ne[0] == a->ne[0]); + } GGML_ASSERT(mask->ne[1] >= a->ne[1]); } From 62d6ef28928fe53cac9c0e50c7234044a43a1022 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Wed, 28 Aug 2024 10:51:59 +0300 Subject: [PATCH 05/13] WIP KQ binary mask: for now, just use fp16 when flash attention is on --- src/llama.cpp | 71 ++++++++++++++++++++++++++++++++++----------------- 1 file changed, 47 insertions(+), 24 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 51b5bbaf2..2a6edf36f 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -8687,39 +8687,28 @@ struct llm_build_context { } struct ggml_tensor * build_inp_KQ_mask(bool causal = true) { - auto type = hparams.use_alibi ? GGML_TYPE_F32 : GGML_TYPE_I32; auto nx = causal ? n_kv : n_tokens; - if (type == GGML_TYPE_I32) nx = (nx + 31)/32; + // Note: we only use a binary mask when nx%32 == 0 because otherwise the CUDA implementation becomes way more messy + auto type = !lctx.is_encoding ? flash_attn || hparams.use_alibi || (nx%32 != 0) ? GGML_TYPE_F16 : GGML_TYPE_I32 : GGML_TYPE_F32; + //auto type = flash_attn || hparams.use_alibi || (nx%32 != 0) ? GGML_TYPE_F16 : GGML_TYPE_I32; + if (type == GGML_TYPE_I32) nx /= 32; lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, type, nx, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); cb(lctx.inp_KQ_mask, "KQ_mask", -1); ggml_set_input(lctx.inp_KQ_mask); - return flash_attn && type == GGML_TYPE_F32 ? ggml_cast(ctx0, lctx.inp_KQ_mask, GGML_TYPE_F16) : lctx.inp_KQ_mask; - //lctx.inp_KQ_mask = causal - // ? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)) - // : ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); - //cb(lctx.inp_KQ_mask, "KQ_mask", -1); - //ggml_set_input(lctx.inp_KQ_mask); - - //return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_mask, GGML_TYPE_F16) : lctx.inp_KQ_mask; + return lctx.inp_KQ_mask; } struct ggml_tensor * build_inp_KQ_mask_swa(bool causal = true) { GGML_ASSERT(hparams.n_swa > 0); - auto type = hparams.use_alibi ? GGML_TYPE_F32 : GGML_TYPE_I32; auto nx = causal ? n_kv : n_tokens; - if (type == GGML_TYPE_I32) nx = (nx + 31)/32; + // Note: we only use a binary mask when nx%32 == 0 because otherwise the CUDA implementation becomes way more messy + auto type = !lctx.is_encoding ? flash_attn || hparams.use_alibi || (nx%32 != 0) ? GGML_TYPE_F16 : GGML_TYPE_I32 : GGML_TYPE_F32; + if (type == GGML_TYPE_I32) nx /= 32; lctx.inp_KQ_mask_swa = ggml_new_tensor_2d(ctx0, type, nx, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); cb(lctx.inp_KQ_mask_swa, "KQ_mask_swa", -1); ggml_set_input(lctx.inp_KQ_mask_swa); - return flash_attn && type == GGML_TYPE_F32 ? ggml_cast(ctx0, lctx.inp_KQ_mask_swa, GGML_TYPE_F16) : lctx.inp_KQ_mask_swa; - //lctx.inp_KQ_mask_swa = causal - // ? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)) - // : ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); - //cb(lctx.inp_KQ_mask_swa, "KQ_mask_swa", -1); - //ggml_set_input(lctx.inp_KQ_mask_swa); - - //return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_mask_swa, GGML_TYPE_F16) : lctx.inp_KQ_mask_swa; + return lctx.inp_KQ_mask_swa; } struct ggml_tensor * build_inp_mean() { @@ -14273,9 +14262,6 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { const int64_t n_kv = kv_self.n; const int64_t n_tokens = batch.n_tokens; - float * data = nullptr; - float * data_swa = nullptr; - if (lctx.inp_KQ_mask && lctx.inp_KQ_mask_swa) { GGML_ASSERT(lctx.inp_KQ_mask->type == lctx.inp_KQ_mask_swa->type); } @@ -14288,7 +14274,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { } auto mask_type = lctx.inp_KQ_mask ? lctx.inp_KQ_mask->type : lctx.inp_KQ_mask_swa->type; - GGML_ASSERT(mask_type == GGML_TYPE_I32 || mask_type == GGML_TYPE_F32); + GGML_ASSERT(mask_type == GGML_TYPE_I32 || mask_type == GGML_TYPE_F32 || mask_type == GGML_TYPE_F16); if (mask_type == GGML_TYPE_I32) { // in order this to be true, we are not using alibi @@ -14329,8 +14315,45 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { } } + else if (mask_type == GGML_TYPE_F16) { + ggml_fp16_t * h_data = lctx.inp_KQ_mask ? (ggml_fp16_t *)lctx.inp_KQ_mask->data : nullptr; + ggml_fp16_t * h_data_swa = lctx.inp_KQ_mask_swa ? (ggml_fp16_t *)lctx.inp_KQ_mask_swa->data : nullptr; + ggml_fp16_t h_zero = ggml_fp32_to_fp16(0.0f); + ggml_fp16_t h_inf = ggml_fp32_to_fp16(-INFINITY); + for (int j = 0; j < n_tokens; ++j) { + const llama_pos pos = batch.pos[j]; + const llama_seq_id seq_id = batch.seq_id[j][0]; + + for (int i = 0; i < n_kv; ++i) { + ggml_fp16_t f; + if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) f = h_inf; + else f = hparams.use_alibi ? ggml_fp32_to_fp16(-std::abs(lctx.kv_self.cells[i].pos - pos)) : h_zero; + if (h_data) h_data[j*n_kv + i] = f; + if (h_data_swa) h_data_swa[j*n_kv + i] = pos - lctx.kv_self.cells[i].pos >= (int32_t)hparams.n_swa ? h_inf : f; + } + } + auto n_pad = GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); + if (n_pad > n_tokens) { + if (h_data) { + for (int j = 0; j < n_kv; ++j) h_data[n_tokens*n_kv + j] = h_inf; + for (int i = n_tokens+1; i < n_pad; ++i) { + std::memcpy(h_data + i*n_kv, h_data + n_tokens*n_kv, n_kv*sizeof(ggml_fp16_t)); + } + } + if (h_data_swa) { + for (int j = 0; j < n_kv; ++j) h_data_swa[n_tokens*n_kv + j] = h_inf; + for (int i = n_tokens+1; i < n_pad; ++i) { + std::memcpy(h_data_swa + i*n_kv, h_data_swa + n_tokens*n_kv, n_kv*sizeof(ggml_fp16_t)); + } + } + } + } + else { + float * data = nullptr; + float * data_swa = nullptr; + if (lctx.inp_KQ_mask) { data = (float *) lctx.inp_KQ_mask->data; } From 900a39bec91a7aa8d21980957f306b5b70dd55fa Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Wed, 28 Aug 2024 11:40:26 +0200 Subject: [PATCH 06/13] WIP KQ binary mask: Metal For now just soft_cap_max. On Gemma2-9b I'm observing a ~2% speedup for context of 16k tokens. --- ggml/src/ggml-metal.m | 15 ++- ggml/src/ggml-metal.metal | 205 ++++++++++++++++++++++++++++++++++++++ ggml/src/ggml.c | 42 +++++++- 3 files changed, 255 insertions(+), 7 deletions(-) diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index 83bd76f9c..aa7b043e5 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -71,6 +71,8 @@ GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F16_4, GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F32, GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F32_4, + GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_U32, + GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_U32_4, GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, @@ -580,6 +582,8 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F16_4, soft_cap_max_f16_4, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F32, soft_cap_max_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F32_4, soft_cap_max_f32_4, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_U32, soft_cap_max_u32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_U32_4, soft_cap_max_u32_4, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true); @@ -1694,19 +1698,22 @@ static enum ggml_status ggml_metal_graph_compute( } break; case GGML_OP_SOFT_CAP_MAX: { - GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); + GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_I32); int nth = 32; // SIMD width id pipeline = nil; const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); + const bool use_u32 = (src1 && src1->type == GGML_TYPE_I32); if (ne00%4 == 0) { while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) { nth *= 2; } - if (use_f16) { + if (use_u32) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_U32_4].pipeline; + } else if (use_f16) { pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F16_4].pipeline; } else { pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F32_4].pipeline; @@ -1715,7 +1722,9 @@ static enum ggml_status ggml_metal_graph_compute( while (nth < ne00 && nth*ne01*ne02*ne03 < 256) { nth *= 2; } - if (use_f16) { + if (use_u32) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_U32].pipeline; + } else if (use_f16) { pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F16].pipeline; } else { pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F32].pipeline; diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index f9c88a37d..58b3e6bce 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -661,6 +661,101 @@ kernel void kernel_soft_max_4( } } +kernel void kernel_soft_cap_max_u32( + device const char * src0, + device const char * src1, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant float & scale, + constant float & max_bias, + constant float & m0, + constant float & m1, + constant float & s_before, + constant float & s_after, + constant uint32_t & n_head_log2, + threadgroup float * buf [[threadgroup(0)]], + uint tgpig[[threadgroup_position_in_grid]], + uint tpitg[[thread_position_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint ntg[[threads_per_threadgroup]]) { + const int64_t i03 = (tgpig) / (ne02*ne01); + const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01; + const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); + + device const float * psrc0 = (device const float * ) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + device const uint32_t * pmask = (device const uint32_t *) src1 + i01*ne00/32; + device float * pdst = (device float * ) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + + // parallel max + float lmax = -INFINITY; + + const float tot_scale = scale * s_after; + for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + float val = pmask[i00 >> 5] & (1u << (i00 & 31)) ? precise::tanh(s_before*psrc0[i00])*tot_scale : -INFINITY; + lmax = MAX(lmax, val); + pdst[i00] = val; + } + + // find the max value in the block + float max_val = simd_max(lmax); + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = -INFINITY; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = max_val; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + max_val = buf[tiisg]; + max_val = simd_max(max_val); + } + + // parallel sum + float lsum = 0.0f; + for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + const float exp_psrc0 = exp(pdst[i00] - max_val); + lsum += exp_psrc0; + pdst[i00] = exp_psrc0; + } + + // This barrier fixes a failing test + // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335 + threadgroup_barrier(mem_flags::mem_none); + + float sum = simd_sum(lsum); + + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = 0.0f; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = sum; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + sum = buf[tiisg]; + sum = simd_sum(sum); + } + + const float inv_sum = 1.0f/sum; + + for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + pdst[i00] *= inv_sum; + } +} + template kernel void kernel_soft_cap_max( device const char * src0, @@ -767,6 +862,116 @@ kernel void kernel_soft_cap_max( } } +kernel void kernel_soft_cap_max_u32_4( + device const char * src0, + device const char * src1, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant float & scale, + constant float & max_bias, + constant float & m0, + constant float & m1, + constant float & s_before, + constant float & s_after, + constant uint32_t & n_head_log2, + threadgroup float * buf [[threadgroup(0)]], + uint tgpig[[threadgroup_position_in_grid]], + uint tpitg[[thread_position_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint ntg[[threads_per_threadgroup]]) { + const int64_t i03 = (tgpig) / (ne02*ne01); + const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01; + const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); + + device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4; + device const uint32_t * pmask = (device const uint32_t *) src1 + i01*ne00/32; + device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4; + + const float tot_scale = scale * s_after; + + // parallel max + float4 lmax4 = -INFINITY; + float4 vinf = lmax4; + + for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { + float4 val = precise::tanh(s_before*psrc4[i00])*tot_scale; + int idx = 4*i00; + uint8_t m = pmask[idx >> 5] >> (idx & 31); + bool4 m4 = { m & 1 ? true : false, m & 2 ? true : false, m & 4 ? true : false, m & 8 ? true : false }; + //bool4 m4 = ((pmask[idx >> 5] >> (idx & 31)) & 0xf) * 0x01010101; + val = select(vinf, val, m4); + //uint32_t m = pmask[idx >> 5] >> (idx & 31); + //val[0] = m & 1 ? val[0] : -INFINITY; + //val[1] = m & 2 ? val[1] : -INFINITY; + //val[2] = m & 4 ? val[2] : -INFINITY; + //val[3] = m & 8 ? val[3] : -INFINITY; + lmax4 = fmax(lmax4, val); + pdst4[i00] = val; + } + + const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3])); + + float max_val = simd_max(lmax); + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = -INFINITY; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = max_val; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + max_val = buf[tiisg]; + max_val = simd_max(max_val); + } + + // parallel sum + float4 lsum4 = 0.0f; + for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { + const float4 exp_psrc4 = exp(pdst4[i00] - max_val); + lsum4 += exp_psrc4; + pdst4[i00] = exp_psrc4; + } + + const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3]; + + // This barrier fixes a failing test + // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335 + threadgroup_barrier(mem_flags::mem_none); + + float sum = simd_sum(lsum); + + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = 0.0f; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = sum; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + sum = buf[tiisg]; + sum = simd_sum(sum); + } + + const float inv_sum = 1.0f/sum; + + for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { + pdst4[i00] *= inv_sum; + } +} + template kernel void kernel_soft_cap_max_4( device const char * src0, diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 2c740f9d0..09b6c0b40 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -2043,6 +2043,7 @@ inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x) inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]*y[i]; } inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]/y[i]; } +#ifdef __AVX512F__ static inline float ggml_vec_add_f32_f16(const int n, const ggml_half * x, float * y, float slope) { __m512 vslope = _mm512_set1_ps(slope); __m512 vmax = _mm512_set1_ps(-INFINITY); @@ -2058,7 +2059,6 @@ static inline float ggml_vec_add_f32_f16(const int n, const ggml_half * x, float } return max; } - static inline float ggml_vec_add_f32_f32(const int n, const float * x, float * y, float slope) { __m512 vslope = _mm512_set1_ps(slope); __m512 vmax = _mm512_set1_ps(-INFINITY); @@ -2074,7 +2074,6 @@ static inline float ggml_vec_add_f32_f32(const int n, const float * x, float * y } return max; } - static inline float ggml_vec_add_f32_infmask(const int n, const uint32_t * x, float * y) { GGML_ASSERT(n%16 == 0); __m512 vmax = _mm512_set1_ps(-INFINITY); @@ -2087,6 +2086,29 @@ static inline float ggml_vec_add_f32_infmask(const int n, const uint32_t * x, fl } return _mm512_reduce_max_ps(vmax); } +#else +// TODO +static inline float ggml_vec_add_f32_f16(const int n, const ggml_half * x, float * y, float slope) { + GGML_UNUSED(n); + GGML_UNUSED(x); + GGML_UNUSED(y); + GGML_UNUSED(slope); + return 0.f; +} +static inline float ggml_vec_add_f32_f32(const int n, const float * x, float * y, float slope) { + GGML_UNUSED(n); + GGML_UNUSED(x); + GGML_UNUSED(y); + GGML_UNUSED(slope); + return 0.f; +} +static inline float ggml_vec_add_f32_infmask(const int n, const uint32_t * x, float * y) { + GGML_UNUSED(n); + GGML_UNUSED(x); + GGML_UNUSED(y); + return 0.f; +} +#endif static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc) { assert(nrc == 1); @@ -2903,6 +2925,7 @@ static void ggml_vec_cpy_softcap_f32(const int n, const float * x, float * y, fl } } +#ifdef __AVX512__ static float ggml_vec_cpy_softcap_mask_f32(const int n, const float * x, float * y, const uint32_t * mask, float s_before, float s_after) { const __mmask16 * m16 = (const __mmask16 *)mask; __m512 vinf = _mm512_set1_ps(-INFINITY); @@ -2916,6 +2939,17 @@ static float ggml_vec_cpy_softcap_mask_f32(const int n, const float * x, float * } return _mm512_reduce_max_ps(vmax); } +#else +static float ggml_vec_cpy_softcap_mask_f32(const int n, const float * x, float * y, const uint32_t * mask, float s_before, float s_after) { + GGML_UNUSED(n); + GGML_UNUSED(x); + GGML_UNUSED(y); + GGML_UNUSED(mask); + GGML_UNUSED(s_before); + GGML_UNUSED(s_after); + return 0.f; +} +#endif static void ggml_vec_softcap_f32(const int n, float * x, float s_before, float s_after) { int i = 0; @@ -13788,7 +13822,7 @@ static void ggml_compute_forward_softcap( default: { GGML_ASSERT(false); - } break; + } } } @@ -13920,7 +13954,7 @@ static void ggml_compute_forward_softcap_max( default: { GGML_ASSERT(false); - } break; + } } } From fe825ecbe4f140a332de07862cafc715c2d924f7 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Wed, 28 Aug 2024 13:04:23 +0200 Subject: [PATCH 07/13] WIP KQ binary mask: Metal soft_max I need to redo this with better templates. --- ggml/src/ggml-metal.m | 15 ++- ggml/src/ggml-metal.metal | 192 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 204 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index aa7b043e5..51b223c71 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -67,6 +67,8 @@ GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, + GGML_METAL_KERNEL_TYPE_SOFT_MAX_U32, + GGML_METAL_KERNEL_TYPE_SOFT_MAX_U32_4, GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F16, GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F16_4, GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F32, @@ -578,6 +580,8 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, soft_max_f32_4, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_U32, soft_max_u32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_U32_4, soft_max_u32_4, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F16, soft_cap_max_f16, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F16_4, soft_cap_max_f16_4, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F32, soft_cap_max_f32, ctx->support_simdgroup_reduction); @@ -1633,19 +1637,22 @@ static enum ggml_status ggml_metal_graph_compute( } break; case GGML_OP_SOFT_MAX: { - GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); + GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_I32); int nth = 32; // SIMD width id pipeline = nil; const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); + const bool use_u32 = (src1 && src1->type == GGML_TYPE_I32); if (ne00%4 == 0) { while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) { nth *= 2; } - if (use_f16) { + if (use_u32) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_U32_4].pipeline; + } else if (use_f16) { pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4].pipeline; } else { pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4].pipeline; @@ -1654,7 +1661,9 @@ static enum ggml_status ggml_metal_graph_compute( while (nth < ne00 && nth*ne01*ne02*ne03 < 256) { nth *= 2; } - if (use_f16) { + if (use_u32) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_U32].pipeline; + } else if (use_f16) { pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16].pipeline; } else { pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32].pipeline; diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 58b3e6bce..8bd4e5c23 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -453,6 +453,198 @@ kernel void kernel_sum_rows( dst_row[0] = row_sum; } +kernel void kernel_soft_max_u32( + device const char * src0, + device const char * src1, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant float & scale, + constant float & max_bias, + constant float & m0, + constant float & m1, + constant uint32_t & n_head_log2, + threadgroup float * buf [[threadgroup(0)]], + uint tgpig[[threadgroup_position_in_grid]], + uint tpitg[[thread_position_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint ntg[[threads_per_threadgroup]]) { + const int64_t i03 = (tgpig) / (ne02*ne01); + const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01; + const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); + + device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + device const uint32_t * pmask = (device const uint32_t *) src1 + i01*ne00/32; + device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + + // parallel max + float lmax = -INFINITY; + + for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + pdst[i00] = pmask[i00 >> 5] & (1u << (i00 & 31)) ? psrc0[i00]*scale : -INFINITY; + lmax = MAX(lmax, pdst[i00]); + } + + // find the max value in the block + float max_val = simd_max(lmax); + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = -INFINITY; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = max_val; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + max_val = buf[tiisg]; + max_val = simd_max(max_val); + } + + // parallel sum + float lsum = 0.0f; + for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + const float exp_psrc0 = exp(pdst[i00] - max_val); + lsum += exp_psrc0; + pdst[i00] = exp_psrc0; + } + + // This barrier fixes a failing test + // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335 + threadgroup_barrier(mem_flags::mem_none); + + float sum = simd_sum(lsum); + + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = 0.0f; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = sum; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + sum = buf[tiisg]; + sum = simd_sum(sum); + } + + const float inv_sum = 1.0f/sum; + + for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + pdst[i00] *= inv_sum; + } +} + +kernel void kernel_soft_max_u32_4( + device const char * src0, + device const char * src1, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant float & scale, + constant float & max_bias, + constant float & m0, + constant float & m1, + constant uint32_t & n_head_log2, + threadgroup float * buf [[threadgroup(0)]], + uint tgpig[[threadgroup_position_in_grid]], + uint tpitg[[thread_position_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint ntg[[threads_per_threadgroup]]) { + const int64_t i03 = (tgpig) / (ne02*ne01); + const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01; + const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); + + device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4; + device const uint32_t * pmask = (device const uint32_t *) src1 + i01*ne00/32; + device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4; + + // parallel max + float4 lmax4 = -INFINITY; + + for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { + int idx = 4*i00; + uint8_t m4 = pmask[idx >> 5] >> (idx & 31); + float4 val = psrc4[i00]*scale; + val[0] = m4 & 1 ? val[0] : -INFINITY; + val[1] = m4 & 2 ? val[1] : -INFINITY; + val[2] = m4 & 4 ? val[2] : -INFINITY; + val[3] = m4 & 8 ? val[3] : -INFINITY; + lmax4 = fmax(lmax4, val); + pdst4[i00] = val; + } + + const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3])); + + float max_val = simd_max(lmax); + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = -INFINITY; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = max_val; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + max_val = buf[tiisg]; + max_val = simd_max(max_val); + } + + // parallel sum + float4 lsum4 = 0.0f; + for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { + const float4 exp_psrc4 = exp(pdst4[i00] - max_val); + lsum4 += exp_psrc4; + pdst4[i00] = exp_psrc4; + } + + const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3]; + + // This barrier fixes a failing test + // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335 + threadgroup_barrier(mem_flags::mem_none); + + float sum = simd_sum(lsum); + + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = 0.0f; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = sum; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + sum = buf[tiisg]; + sum = simd_sum(sum); + } + + const float inv_sum = 1.0f/sum; + + for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { + pdst4[i00] *= inv_sum; + } +} + template kernel void kernel_soft_max( device const char * src0, From 05f95229a7eb1a0625daf87ef2ba2da7d3e8a915 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Wed, 28 Aug 2024 15:01:02 +0200 Subject: [PATCH 08/13] WIP KQ binary mask: make it a parameter, turn on via command line It is a pain to implement binary mask to 32-bit value conversion on NEON and AVX2, so I decided to make the binary mask optional There is also a commented out (and not working) attempt for NEON in this commit. --- common/common.cpp | 7 ++++++ common/common.h | 1 + ggml/src/ggml.c | 55 +++++++++++++++++++++++++++++++++++++++++++++++ include/llama.h | 1 + src/llama.cpp | 10 +++++++-- 5 files changed, 72 insertions(+), 2 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 3b45d0664..85baa5e2e 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -808,6 +808,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.flash_attn = true; return true; } + if (arg == "-bkq" || arg == "--binary-kq") { + params.binary_kq = true; + return true; + } if (arg == "-co" || arg == "--color") { params.use_color = true; return true; @@ -1442,6 +1446,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "*", " --keep N", "number of tokens to keep from the initial prompt (default: %d, -1 = all)", params.n_keep }); options.push_back({ "*", " --chunks N", "max number of chunks to process (default: %d, -1 = all)", params.n_chunks }); options.push_back({ "*", "-fa, --flash-attn", "enable Flash Attention (default: %s)", params.flash_attn ? "enabled" : "disabled" }); + options.push_back({ "*", "-bkq, --binary-kq", "enable binary KQ mask (default: %s)", params.binary_kq ? "enabled" : "disabled" }); options.push_back({ "*", "-p, --prompt PROMPT", "prompt to start generation with\n" "in conversation mode, this will be used as system prompt\n" "(default: '%s')", params.prompt.c_str() }); @@ -2265,6 +2270,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param cparams.cb_eval_user_data = params.cb_eval_user_data; cparams.offload_kqv = !params.no_kv_offload; cparams.flash_attn = params.flash_attn; + cparams.binary_kq = params.binary_kq; cparams.type_k = kv_cache_type_from_str(params.cache_type_k); cparams.type_v = kv_cache_type_from_str(params.cache_type_v); @@ -3261,6 +3267,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false"); fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false"); fprintf(stream, "flash_attn: %s # default: false\n", params.flash_attn ? "true" : "false"); + fprintf(stream, "binary_kq: %s # default: false\n", params.binary_kq ? "true" : "false"); fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp); const std::vector tensor_split_vector(params.tensor_split, params.tensor_split + llama_max_devices()); diff --git a/common/common.h b/common/common.h index 50035897a..28b564717 100644 --- a/common/common.h +++ b/common/common.h @@ -173,6 +173,7 @@ struct gpt_params { bool simple_io = false; // improves compatibility with subprocesses and limited consoles bool cont_batching = true; // insert new sequences for decoding on-the-fly bool flash_attn = false; // flash attention + bool binary_kq = false; // use binary KQ mask (if allowed in the given context) bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix bool ignore_eos = false; // ignore generated EOS tokens diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 09b6c0b40..e45732aac 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -2939,6 +2939,60 @@ static float ggml_vec_cpy_softcap_mask_f32(const int n, const float * x, float * } return _mm512_reduce_max_ps(vmax); } +//#elif __ARM_NEON +//static float ggml_vec_cpy_softcap_mask_f32(const int n, const float * x, float * y, const uint32_t * mask, float s_before, float s_after) { +// //const uint16_t * mask16 = (const uint16_t *)mask; +// const uint8_t * mask8 = (const uint8_t *)mask; +// float32x4_t vinf = vdupq_n_f32(-INFINITY); +// float32x4x4_t vmax = { vinf, vinf, vinf, vinf }; +// float32x4_t vs_before = vdupq_n_f32(s_before); +// float32x4_t vs_after = vdupq_n_f32(s_after ); +// const uint8x16_t vmask = vreinterpretq_u8_u64(vdupq_n_u64(0x8040201008040201)); +// //const uint8x8_t vmask = vreinterpret_u8_u64(vdup_n_u64(0x8040201008040201)); +// //static const uint32_t k_shuffle[8] = { 0x00000000, 0x01010101, 0x02020202, 0x03030303, +// // 0x04040404, 0x05050505, 0x06060606, 0x07070707 }; +// //const uint8x8x4_t vtab = vld1_u8_x4((const uint8_t *)k_shuffle); +// //for (int i = 0; i < n/16; ++i) { +// // float32x4x4_t vx = vld1q_f32_x4(x + 16*i); +// // uint8x8_t m1 = vceq_u8(vand_u8(vdup_n_u8(mask8[2*i+0]), vmask), vmask); +// // uint8x8_t m2 = vceq_u8(vand_u8(vdup_n_u8(mask8[2*i+1]), vmask), vmask); +// // uint8x16x4_t mk = { vcombine_u8(vqtbl1_u8(m1, vtab.val[0]), vqtbl1_u8(m1, vtab.val[1])), +// // for (int k = 0; k < 4; ++k) { +// // vx.val[k] = ggml_v_softcap(vx.val[k], vs_before, vs_after); +// // //uint8x16_t mk = vcombine(vqtbl1_u8(m1, vtab.val[k]), +// // uint8x16_t v_on = vandq_u8(vreinterpretq_u8_f32(vx.val[k]), mk); +// // uint8x16_t v_off = vandq_u8(vreinterpretq_u8_f32(vinf), mk); +// // vx.val[k] = vreinterpretq_f32_u8(vorrq_u8(v_on, v_off)); +// // vmax.val[k] = vmaxq_f32(vmax.val[k], vx.val[k]); +// // vst1q_f32(y + 16*i + 4*k, vx.val[k]); +// // } +// //} +// static const uint32_t k_shuffle[16] = { 0x00000000, 0x01010101, 0x02020202, 0x03030303, +// 0x04040404, 0x05050505, 0x06060606, 0x07070707, +// 0x08080808, 0x09090909, 0x0a0a0a0a, 0x0b0b0b0b, +// 0x0c0c0c0c, 0x0d0d0d0d, 0x0e0e0e0e, 0x0f0f0f0f}; +// const uint8x16x4_t vtab = vld1q_u8_x4((const uint8_t *)k_shuffle); +// for (int i = 0; i < n/16; ++i) { +// float32x4x4_t vx = vld1q_f32_x4(x + 16*i); +// uint8x16_t m = vcombine_u8(vdup_n_u8(mask8[2*i+0]), vdup_n_u8(mask8[2*i+1])); +// m = vceqq_u8(vandq_u8(m, vmask), vmask); +// for (int k = 0; k < 4; ++k) { +// vx.val[k] = ggml_v_softcap(vx.val[k], vs_before, vs_after); +// uint8x16_t mk = vqtbl1q_u8(m, vtab.val[k]); +// uint8x16_t v_on = vandq_u8(vreinterpretq_u8_f32(vx.val[k]), mk); +// uint8x16_t v_off = vandq_u8(vreinterpretq_u8_f32(vinf), mk); +// vx.val[k] = vreinterpretq_f32_u8(vorrq_u8(v_on, v_off)); +// vmax.val[k] = vmaxq_f32(vmax.val[k], vx.val[k]); +// vst1q_f32(y + 16*i + 4*k, vx.val[k]); +// } +// } +// float max = vmaxvq_f32(vmax.val[0]); +// for (int k = 1; k < 4; ++k) { +// float maxk = vmaxvq_f32(vmax.val[k]); +// max = MAX(max, maxk); +// } +// return max; +//} #else static float ggml_vec_cpy_softcap_mask_f32(const int n, const float * x, float * y, const uint32_t * mask, float s_before, float s_after) { GGML_UNUSED(n); @@ -2947,6 +3001,7 @@ static float ggml_vec_cpy_softcap_mask_f32(const int n, const float * x, float * GGML_UNUSED(mask); GGML_UNUSED(s_before); GGML_UNUSED(s_after); + GGML_ASSERT(false); return 0.f; } #endif diff --git a/include/llama.h b/include/llama.h index a9af4c48c..dd13d657a 100644 --- a/include/llama.h +++ b/include/llama.h @@ -340,6 +340,7 @@ extern "C" { bool embeddings; // if true, extract embeddings (together with logits) bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU bool flash_attn; // whether to use flash attention [EXPERIMENTAL] + bool binary_kq; // whether to use binary KQ mask [EXPERIMENTAL] // Abort callback // if it returns true, execution of llama_decode() will be aborted diff --git a/src/llama.cpp b/src/llama.cpp index 2a6edf36f..83aac3da5 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -2348,6 +2348,7 @@ struct llama_cparams { bool causal_attn; bool offload_kqv; bool flash_attn; + bool binary_kq; enum llama_pooling_type pooling_type; @@ -8446,6 +8447,7 @@ struct llm_build_context { const int32_t n_ctx_orig; const bool flash_attn; + const bool binary_kq; const enum llama_pooling_type pooling_type; const enum llama_rope_type rope_type; @@ -8495,6 +8497,7 @@ struct llm_build_context { kv_head (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head), n_ctx_orig (cparams.n_ctx_orig_yarn), flash_attn (cparams.flash_attn), + binary_kq (cparams.binary_kq), pooling_type (cparams.pooling_type), rope_type (hparams.rope_type), cb (cb), @@ -8689,7 +8692,7 @@ struct llm_build_context { struct ggml_tensor * build_inp_KQ_mask(bool causal = true) { auto nx = causal ? n_kv : n_tokens; // Note: we only use a binary mask when nx%32 == 0 because otherwise the CUDA implementation becomes way more messy - auto type = !lctx.is_encoding ? flash_attn || hparams.use_alibi || (nx%32 != 0) ? GGML_TYPE_F16 : GGML_TYPE_I32 : GGML_TYPE_F32; + auto type = !lctx.is_encoding ? !binary_kq || flash_attn || hparams.use_alibi || (nx%32 != 0) ? GGML_TYPE_F16 : GGML_TYPE_I32 : GGML_TYPE_F32; //auto type = flash_attn || hparams.use_alibi || (nx%32 != 0) ? GGML_TYPE_F16 : GGML_TYPE_I32; if (type == GGML_TYPE_I32) nx /= 32; lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, type, nx, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); @@ -8702,7 +8705,7 @@ struct llm_build_context { GGML_ASSERT(hparams.n_swa > 0); auto nx = causal ? n_kv : n_tokens; // Note: we only use a binary mask when nx%32 == 0 because otherwise the CUDA implementation becomes way more messy - auto type = !lctx.is_encoding ? flash_attn || hparams.use_alibi || (nx%32 != 0) ? GGML_TYPE_F16 : GGML_TYPE_I32 : GGML_TYPE_F32; + auto type = !lctx.is_encoding ? !binary_kq || flash_attn || hparams.use_alibi || (nx%32 != 0) ? GGML_TYPE_F16 : GGML_TYPE_I32 : GGML_TYPE_F32; if (type == GGML_TYPE_I32) nx /= 32; lctx.inp_KQ_mask_swa = ggml_new_tensor_2d(ctx0, type, nx, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); cb(lctx.inp_KQ_mask_swa, "KQ_mask_swa", -1); @@ -16727,6 +16730,7 @@ struct llama_context_params llama_context_default_params() { /*.embeddings =*/ false, /*.offload_kqv =*/ true, /*.flash_attn =*/ false, + /*.binary_kq =*/ false, /*.abort_callback =*/ nullptr, /*.abort_callback_data =*/ nullptr, }; @@ -16917,6 +16921,7 @@ struct llama_context * llama_new_context_with_model( cparams.embeddings = params.embeddings; cparams.offload_kqv = params.offload_kqv; cparams.flash_attn = params.flash_attn; + cparams.binary_kq = params.binary_kq; cparams.pooling_type = params.pooling_type; cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx; @@ -16983,6 +16988,7 @@ struct llama_context * llama_new_context_with_model( LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch); LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch); LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn); + LLAMA_LOG_INFO("%s: binary_kq = %d\n", __func__, cparams.binary_kq); LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base); LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale); From a8b762ddd9510c8ee98e7c5040a7845a82de9e6d Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Wed, 28 Aug 2024 15:07:09 +0200 Subject: [PATCH 09/13] Minor --- src/llama.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/llama.cpp b/src/llama.cpp index 83aac3da5..593706002 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -16905,6 +16905,10 @@ struct llama_context * llama_new_context_with_model( return nullptr; } + if (params.binary_kq && params.flash_attn) { + LLAMA_LOG_WARN("%s: binary-KQ mask is currently not used in flash_attn\n", __func__); + } + llama_context * ctx = new llama_context(*model); const auto & hparams = model->hparams; From 316345c5353dded20f0ffa3371d02c5277e8f323 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Wed, 28 Aug 2024 17:27:48 +0300 Subject: [PATCH 10/13] WIP KQ binary mask --- examples/llama-bench/llama-bench.cpp | 34 +++++++- ggml/src/ggml.c | 114 ++++++++++++++++----------- 2 files changed, 98 insertions(+), 50 deletions(-) diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 813d7baeb..0736e3933 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -231,6 +231,7 @@ struct cmd_params { std::vector main_gpu; std::vector no_kv_offload; std::vector flash_attn; + std::vector binary_kq; std::vector> tensor_split; std::vector use_mmap; std::vector embeddings; @@ -258,6 +259,7 @@ static const cmd_params cmd_params_defaults = { /* main_gpu */ {0}, /* no_kv_offload */ {false}, /* flash_attn */ {false}, + /* binary_kq */ {false}, /* tensor_split */ {std::vector(llama_max_devices(), 0.0f)}, /* use_mmap */ {true}, /* embeddings */ {false}, @@ -289,6 +291,7 @@ static void print_usage(int /* argc */, char ** argv) { printf(" -mg, --main-gpu (default: %s)\n", join(cmd_params_defaults.main_gpu, ",").c_str()); printf(" -nkvo, --no-kv-offload <0|1> (default: %s)\n", join(cmd_params_defaults.no_kv_offload, ",").c_str()); printf(" -fa, --flash-attn <0|1> (default: %s)\n", join(cmd_params_defaults.flash_attn, ",").c_str()); + printf(" -bkq, --binary-kq <0|1> (default: %s)\n", join(cmd_params_defaults.binary_kq, ",").c_str()); printf(" -mmp, --mmap <0|1> (default: %s)\n", join(cmd_params_defaults.use_mmap, ",").c_str()); printf(" --numa (default: disabled)\n"); printf(" -embd, --embeddings <0|1> (default: %s)\n", join(cmd_params_defaults.embeddings, ",").c_str()); @@ -503,6 +506,13 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { } auto p = string_split(argv[i], split_delim); params.flash_attn.insert(params.flash_attn.end(), p.begin(), p.end()); + } else if (arg == "-bkq" || arg == "--binary-kq") { + if (++i >= argc) { + invalid_param = true; + break; + } + auto p = string_split(argv[i], split_delim); + params.binary_kq.insert(params.binary_kq.end(), p.begin(), p.end()); } else if (arg == "-mmp" || arg == "--mmap") { if (++i >= argc) { invalid_param = true; @@ -591,6 +601,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { if (params.main_gpu.empty()) { params.main_gpu = cmd_params_defaults.main_gpu; } if (params.no_kv_offload.empty()){ params.no_kv_offload = cmd_params_defaults.no_kv_offload; } if (params.flash_attn.empty()) { params.flash_attn = cmd_params_defaults.flash_attn; } + if (params.binary_kq.empty()) { params.binary_kq = cmd_params_defaults.binary_kq; } if (params.tensor_split.empty()) { params.tensor_split = cmd_params_defaults.tensor_split; } if (params.use_mmap.empty()) { params.use_mmap = cmd_params_defaults.use_mmap; } if (params.embeddings.empty()) { params.embeddings = cmd_params_defaults.embeddings; } @@ -614,6 +625,7 @@ struct cmd_params_instance { int main_gpu; bool no_kv_offload; bool flash_attn; + bool binary_kq; std::vector tensor_split; bool use_mmap; bool embeddings; @@ -653,6 +665,7 @@ struct cmd_params_instance { cparams.type_v = type_v; cparams.offload_kqv = !no_kv_offload; cparams.flash_attn = flash_attn; + cparams.binary_kq = binary_kq; cparams.embeddings = embeddings; return cparams; @@ -677,6 +690,7 @@ static std::vector get_cmd_params_instances(const cmd_param for (const auto & tv : params.type_v) for (const auto & nkvo : params.no_kv_offload) for (const auto & fa : params.flash_attn) + for (const auto & bkq : params.binary_kq) for (const auto & nt : params.n_threads) { for (const auto & n_prompt : params.n_prompt) { if (n_prompt == 0) { @@ -697,6 +711,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .main_gpu = */ mg, /* .no_kv_offload= */ nkvo, /* .flash_attn = */ fa, + /* .binary_kq = */ bkq, /* .tensor_split = */ ts, /* .use_mmap = */ mmp, /* .embeddings = */ embd, @@ -723,6 +738,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .main_gpu = */ mg, /* .no_kv_offload= */ nkvo, /* .flash_attn = */ fa, + /* .binary_kq = */ bkq, /* .tensor_split = */ ts, /* .use_mmap = */ mmp, /* .embeddings = */ embd, @@ -749,6 +765,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .main_gpu = */ mg, /* .no_kv_offload= */ nkvo, /* .flash_attn = */ fa, + /* .binary_kq = */ bkq, /* .tensor_split = */ ts, /* .use_mmap = */ mmp, /* .embeddings = */ embd, @@ -787,6 +804,7 @@ struct test { int main_gpu; bool no_kv_offload; bool flash_attn; + bool binary_kq; std::vector tensor_split; bool use_mmap; bool embeddings; @@ -813,6 +831,7 @@ struct test { main_gpu = inst.main_gpu; no_kv_offload = inst.no_kv_offload; flash_attn = inst.flash_attn; + binary_kq = inst.binary_kq; tensor_split = inst.tensor_split; use_mmap = inst.use_mmap; embeddings = inst.embeddings; @@ -884,7 +903,7 @@ struct test { "n_batch", "n_ubatch", "n_threads", "type_k", "type_v", "n_gpu_layers", "split_mode", - "main_gpu", "no_kv_offload", "flash_attn", + "main_gpu", "no_kv_offload", "flash_attn", "binary-kq", "tensor_split", "use_mmap", "embeddings", "n_prompt", "n_gen", "test_time", "avg_ns", "stddev_ns", @@ -906,7 +925,7 @@ struct test { } if (field == "cuda" || field == "vulkan" || field == "kompute" || field == "metal" || field == "gpu_blas" || field == "blas" || field == "sycl" ||field == "f16_kv" || field == "no_kv_offload" || - field == "flash_attn" || field == "use_mmap" || field == "embeddings") { + field == "flash_attn" || field == "binary-kq" || field == "use_mmap" || field == "embeddings") { return BOOL; } if (field == "avg_ts" || field == "stddev_ts") { @@ -940,7 +959,7 @@ struct test { std::to_string(n_batch), std::to_string(n_ubatch), std::to_string(n_threads), ggml_type_name(type_k), ggml_type_name(type_v), std::to_string(n_gpu_layers), split_mode_str(split_mode), - std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn), + std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn), std::to_string(binary_kq), tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings), std::to_string(n_prompt), std::to_string(n_gen), test_time, std::to_string(avg_ns()), std::to_string(stdev_ns()), @@ -1103,6 +1122,9 @@ struct markdown_printer : public printer { if (field == "flash_attn") { return 2; } + if (field == "binary-kq") { + return 3; + } if (field == "use_mmap") { return 4; } @@ -1134,6 +1156,9 @@ struct markdown_printer : public printer { if (field == "flash_attn") { return "fa"; } + if (field == "binary-kq") { + return "bkq"; + } if (field == "use_mmap") { return "mmap"; } @@ -1183,6 +1208,9 @@ struct markdown_printer : public printer { if (params.flash_attn.size() > 1 || params.flash_attn != cmd_params_defaults.flash_attn) { fields.emplace_back("flash_attn"); } + if (params.binary_kq.size() > 1 || params.binary_kq != cmd_params_defaults.binary_kq) { + fields.emplace_back("binary-kq"); + } if (params.tensor_split.size() > 1 || params.tensor_split != cmd_params_defaults.tensor_split) { fields.emplace_back("tensor_split"); } diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index e45732aac..39987217d 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -2074,18 +2074,6 @@ static inline float ggml_vec_add_f32_f32(const int n, const float * x, float * y } return max; } -static inline float ggml_vec_add_f32_infmask(const int n, const uint32_t * x, float * y) { - GGML_ASSERT(n%16 == 0); - __m512 vmax = _mm512_set1_ps(-INFINITY); - __m512 vinf = _mm512_set1_ps(-INFINITY); - const __mmask16 * mm16 = (const __mmask16 *)x; - for (int j = 0; j < n/16; ++j) { - __m512 v = _mm512_mask_blend_ps(mm16[j], _mm512_loadu_ps(y + 16*j), vinf); - _mm512_storeu_ps(y + 16*j, v); - vmax = _mm512_max_ps(vmax, v); - } - return _mm512_reduce_max_ps(vmax); -} #else // TODO static inline float ggml_vec_add_f32_f16(const int n, const ggml_half * x, float * y, float slope) { @@ -2093,6 +2081,7 @@ static inline float ggml_vec_add_f32_f16(const int n, const ggml_half * x, float GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(slope); + GGML_ASSERT(false); return 0.f; } static inline float ggml_vec_add_f32_f32(const int n, const float * x, float * y, float slope) { @@ -2100,12 +2089,7 @@ static inline float ggml_vec_add_f32_f32(const int n, const float * x, float * y GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(slope); - return 0.f; -} -static inline float ggml_vec_add_f32_infmask(const int n, const uint32_t * x, float * y) { - GGML_UNUSED(n); - GGML_UNUSED(x); - GGML_UNUSED(y); + GGML_ASSERT(false); return 0.f; } #endif @@ -2925,7 +2909,7 @@ static void ggml_vec_cpy_softcap_f32(const int n, const float * x, float * y, fl } } -#ifdef __AVX512__ +#ifdef __AVX512F__ static float ggml_vec_cpy_softcap_mask_f32(const int n, const float * x, float * y, const uint32_t * mask, float s_before, float s_after) { const __mmask16 * m16 = (const __mmask16 *)mask; __m512 vinf = _mm512_set1_ps(-INFINITY); @@ -2939,6 +2923,18 @@ static float ggml_vec_cpy_softcap_mask_f32(const int n, const float * x, float * } return _mm512_reduce_max_ps(vmax); } +static float ggml_vec_cpy_soft_mask_f32(const int n, const float * x, float * y, const uint32_t * mask, float scale) { + const __mmask16 * m16 = (const __mmask16 *)mask; + __m512 vinf = _mm512_set1_ps(-INFINITY); + __m512 vmax = vinf; + __m512 vscale = _mm512_set1_ps(scale); + for (int i = 0; i < n/16; ++i) { + __m512 v = _mm512_mask_mul_ps(vinf, m16[i], vscale, _mm512_loadu_ps(x + 16*i)); + _mm512_storeu_ps(y + 16*i, v); + vmax = _mm512_max_ps(vmax, v); + } + return _mm512_reduce_max_ps(vmax); +} //#elif __ARM_NEON //static float ggml_vec_cpy_softcap_mask_f32(const int n, const float * x, float * y, const uint32_t * mask, float s_before, float s_after) { // //const uint16_t * mask16 = (const uint16_t *)mask; @@ -3004,6 +3000,15 @@ static float ggml_vec_cpy_softcap_mask_f32(const int n, const float * x, float * GGML_ASSERT(false); return 0.f; } +static float ggml_vec_cpy_soft_mask_f32(const int n, const float * x, float * y, const uint32_t * mask, float scale) { + GGML_UNUSED(n); + GGML_UNUSED(x); + GGML_UNUSED(y); + GGML_UNUSED(mask); + GGML_UNUSED(scale); + GGML_ASSERT(false); + return 0.f; +} #endif static void ggml_vec_softcap_f32(const int n, float * x, float s_before, float s_after) { @@ -13952,16 +13957,9 @@ static void ggml_compute_forward_softcap_max_f32( if (use_f16) { ggml_fp16_t * mp_f16 = (ggml_fp16_t *)((char *) src1->data) + mask_row*ne00; max = ggml_vec_add_f32_f16(nc, mp_f16, wp, slope); - } else if (use_i32) { - int n32 = (ne00 + 31)/32; - const uint32_t * mp_u32 = (const uint32_t *)src1->data + mask_row*n32; - max = ggml_vec_add_f32_infmask(nc, mp_u32, wp); } else { float * mp_f32 = (float *)((char *) src1->data) + mask_row*ne00; max = ggml_vec_add_f32_f32(nc, mp_f32, wp, slope); - //for (int i = 0; i < nc; ++i) { - // wp[i] += slope*mp_f32[i]; - //} } } else { @@ -14745,6 +14743,7 @@ static void ggml_compute_forward_soft_max_f32( float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith; const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); + const bool use_u32 = (src1 && src1->type == GGML_TYPE_I32); for (int i1 = ir0; i1 < ir1; i1++) { // ALiBi @@ -14754,33 +14753,54 @@ static void ggml_compute_forward_soft_max_f32( float * sp = (float *)((char *) src0->data + i1*src0->nb[1]); float * dp = (float *)((char *) dst->data + i1*dst->nb[1]); - // broadcast the mask across rows - ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL; - float * mp_f32 = src1 ? (float *)((char *) src1->data) + (i1%ne01)*ne00 : NULL; - - ggml_vec_cpy_f32 (nc, wp, sp); - ggml_vec_scale_f32(nc, wp, scale); - if (mp_f32) { - if (use_f16) { - for (int i = 0; i < nc; ++i) { - wp[i] += slope*GGML_FP16_TO_FP32(mp_f16[i]); - } - } else { - for (int i = 0; i < nc; ++i) { - wp[i] += slope*mp_f32[i]; + float max = -INFINITY; + if (use_u32) { + int n32 = ne00/32; + const uint32_t * mp_u32 = (const uint32_t *)src1->data + (i1%ne01)*n32; + max = ggml_vec_cpy_soft_mask_f32(nc, sp, wp, mp_u32, scale); + } else { + + ggml_vec_cpy_f32 (nc, wp, sp); + ggml_vec_scale_f32(nc, wp, scale); + if (src1) { + // broadcast the mask across rows + if (use_f16) { + ggml_fp16_t * mp_f16 = (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00; + max = ggml_vec_add_f32_f16(nc, mp_f16, wp, slope); + } else { + float * mp_f32 = (float *)((char *) src1->data) + (i1%ne01)*ne00; + max = ggml_vec_add_f32_f32(nc, mp_f32, wp, slope); } } - } + else { + ggml_vec_max_f32(nc, &max, wp); + } + + //// broadcast the mask across rows + //ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL; + //float * mp_f32 = src1 ? (float *)((char *) src1->data) + (i1%ne01)*ne00 : NULL; + + //if (mp_f32) { + // if (use_f16) { + // for (int i = 0; i < nc; ++i) { + // wp[i] += slope*GGML_FP16_TO_FP32(mp_f16[i]); + // } + // } else { + // for (int i = 0; i < nc; ++i) { + // wp[i] += slope*mp_f32[i]; + // } + // } + //} #ifndef NDEBUG - for (int i = 0; i < nc; ++i) { - //printf("p[%d] = %f\n", i, p[i]); - assert(!isnan(wp[i])); - } + for (int i = 0; i < nc; ++i) { + //printf("p[%d] = %f\n", i, p[i]); + assert(!isnan(wp[i])); + } #endif - float max = -INFINITY; - ggml_vec_max_f32(nc, &max, wp); + ggml_vec_max_f32(nc, &max, wp); + } ggml_float sum = ggml_vec_soft_max_f32(nc, dp, wp, max); assert(sum > 0.0); From 97dbc16e86eee9b39d95c96fa11fac57d995d15c Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Wed, 28 Aug 2024 16:42:49 +0200 Subject: [PATCH 11/13] WIP KQ binary mask --- ggml/src/ggml.c | 57 ++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 44 insertions(+), 13 deletions(-) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 39987217d..9eaf42ff8 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -2074,23 +2074,54 @@ static inline float ggml_vec_add_f32_f32(const int n, const float * x, float * y } return max; } +#elif __ARM_NEON +static inline float ggml_vec_add_f32_f16(const int n, const ggml_half * x, float * y, float slope) { + float32x4_t vslope = vdupq_n_f32(slope); + float32x4_t vmax = vdupq_n_f32(-INFINITY); + for (int j = 0; j < n/4; ++j) { + float32x4_t val = vmlaq_f32(vld1q_f32(y + 4*j), vslope, vcvt_f32_f16(vld1_f16((const float16_t *)x + 4*j))); + vmax = vmaxq_f32(vmax, val); + vst1q_f32(y + 4*j, val); + } + float max = vmaxvq_f32(vmax); + for (int i = 4*(n/4); i < n; ++i) { + y[i] += slope*x[i]; + max = MAX(max, y[i]); + } + return max; +} +static inline float ggml_vec_add_f32_f32(const int n, const float * x, float * y, float slope) { + float32x4_t vslope = vdupq_n_f32(slope); + float32x4_t vmax = vdupq_n_f32(-INFINITY); + for (int j = 0; j < n/4; ++j) { + float32x4_t val = vmlaq_f32(vld1q_f32(y + 4*j), vslope, vld1q_f32(x + 4*j)); + vmax = vmaxq_f32(vmax, val); + vst1q_f32(y + 4*j, val); + } + float max = vmaxvq_f32(vmax); + for (int i = 4*(n/4); i < n; ++i) { + y[i] += slope*x[i]; + max = MAX(max, y[i]); + } + return max; +} #else -// TODO +// TODO add AVX2 static inline float ggml_vec_add_f32_f16(const int n, const ggml_half * x, float * y, float slope) { - GGML_UNUSED(n); - GGML_UNUSED(x); - GGML_UNUSED(y); - GGML_UNUSED(slope); - GGML_ASSERT(false); - return 0.f; + float max = -INFINITY; + for (int i = 0; i < n; ++i) { + y[i] += slope * GGML_FP16_TO_FP32(x[i]); + max = MAX(max, y[i]); + } + return max; } static inline float ggml_vec_add_f32_f32(const int n, const float * x, float * y, float slope) { - GGML_UNUSED(n); - GGML_UNUSED(x); - GGML_UNUSED(y); - GGML_UNUSED(slope); - GGML_ASSERT(false); - return 0.f; + float max = -INFINITY; + for (int i = 0; i < n; ++i) { + y[i] += slope * x[i]; + max = MAX(max, y[i]); + } + return max; } #endif From 4d10f4e0ba09bb9cb9e077e3253e6896ef6ad5d9 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Wed, 28 Aug 2024 18:14:03 +0300 Subject: [PATCH 12/13] WIP KQ binary mask --- ggml/src/ggml.c | 77 ++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 76 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 9eaf42ff8..e1cf168b4 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -2074,7 +2074,82 @@ static inline float ggml_vec_add_f32_f32(const int n, const float * x, float * y } return max; } -#elif __ARM_NEON +#elif defined __AVX2__ +static inline float hmax_f32x8(__m256 v) { + __m128 max4 = _mm_max_ps(_mm256_extractf128_ps(v, 1), _mm256_castps256_ps128(v)); + max4 = _mm_max_ps(max4, _mm_movehl_ps(max4, max4)); + max4 = _mm_max_ss(max4, _mm_movehdup_ps(max4)); + return _mm_cvtss_f32( max4 ); +} +static inline float ggml_vec_add_f32_f16(const int n, const ggml_half * x, float * y, float slope) { + __m256 vmax = _mm256_set1_ps(-INFINITY); + if (fabsf(slope - 1.0f) < 1e-5f) { + for (int j = 0; j < n/8; ++j) { + __m256 vmask = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)x + j)); + __m256 v = _mm256_add_ps(vmask, _mm256_loadu_ps(y + 8*j)); + _mm256_storeu_ps(y + 8*j, v); + vmax = _mm256_max_ps(vmax, v); + } + float max = hmax_f32x8(vmax); + for (int i = 8*(n/8); i < n; ++i) { + y[i] += slope*GGML_FP16_TO_FP32(x[i]); + max = MAX(max, y[i]); + } + return max; + } + __m256 vslope = _mm256_set1_ps(slope); + for (int j = 0; j < n/8; ++j) { +#ifdef __FMA__ + __m256 v = _mm256_fmadd_ps(vslope, _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)x + j)), _mm256_loadu_ps(y + 8*j)); +#else + __m256 vmask = _mm256_mul_ps(vslope, _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)x + j))); + __m256 v = _mm256_add_ps(vmask, _mm256_loadu_ps(y + 8*j)); +#endif + _mm256_storeu_ps(y + 8*j, v); + vmax = _mm256_max_ps(vmax, v); + } + float max = hmax_f32x8(vmax); + for (int i = 8*(n/8); i < n; ++i) { + y[i] += slope*GGML_FP16_TO_FP32(x[i]); + max = MAX(max, y[i]); + } + return max; +} +static inline float ggml_vec_add_f32_f32(const int n, const float * x, float * y, float slope) { + __m256 vmax = _mm256_set1_ps(-INFINITY); + if (fabsf(slope - 1.0f) < 1e-5f) { + for (int j = 0; j < n/8; ++j) { + __m256 vmask = _mm256_loadu_ps(x + 8*j); + __m256 v = _mm256_add_ps(vmask, _mm256_loadu_ps(y + 8*j)); + _mm256_storeu_ps(y + 8*j, v); + vmax = _mm256_max_ps(vmax, v); + } + float max = hmax_f32x8(vmax); + for (int i = 8*(n/8); i < n; ++i) { + y[i] += slope*x[i]; + max = MAX(max, y[i]); + } + return max; + } + __m256 vslope = _mm256_set1_ps(slope); + for (int j = 0; j < n/8; ++j) { +#ifdef __FMA__ + __m256 v = _mm256_fmadd_ps(vslope, _mm256_loadu_ps(x + 8*j), _mm256_loadu_ps(y + 8*j)); +#else + __m256 vmask = _mm256_mul_ps(vslope, _mm256_loadu_ps(x + 8*j)); + __m256 v = _mm256_add_ps(vmask, _mm256_loadu_ps(y + 8*j)); +#endif + _mm256_storeu_ps(y + 8*j, v); + vmax = _mm256_max_ps(vmax, v); + } + float max = hmax_f32x8(vmax); + for (int i = 8*(n/8); i < n; ++i) { + y[i] += slope*x[i]; + max = MAX(max, y[i]); + } + return max; +} +#elif defined __ARM_NEON static inline float ggml_vec_add_f32_f16(const int n, const ggml_half * x, float * y, float slope) { float32x4_t vslope = vdupq_n_f32(slope); float32x4_t vmax = vdupq_n_f32(-INFINITY); From 3b4fe65e1cc1e7f8172cb957dcb01b207df373d8 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Wed, 28 Aug 2024 18:31:27 +0300 Subject: [PATCH 13/13] Minor --- ggml/src/ggml.c | 1 - src/llama.cpp | 5 ++++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index e1cf168b4..939f52155 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -2181,7 +2181,6 @@ static inline float ggml_vec_add_f32_f32(const int n, const float * x, float * y return max; } #else -// TODO add AVX2 static inline float ggml_vec_add_f32_f16(const int n, const ggml_half * x, float * y, float slope) { float max = -INFINITY; for (int i = 0; i < n; ++i) { diff --git a/src/llama.cpp b/src/llama.cpp index 593706002..ad3febef3 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -8692,8 +8692,9 @@ struct llm_build_context { struct ggml_tensor * build_inp_KQ_mask(bool causal = true) { auto nx = causal ? n_kv : n_tokens; // Note: we only use a binary mask when nx%32 == 0 because otherwise the CUDA implementation becomes way more messy + //bool can_be_binary = binary_kq && !lctx.is_encoding && !flash_attn && !hparams.use_alibi && nx%32 == 0; + //auto type = can_be_binary ? GGML_TYPE_I32 : flash_attn ? GGML_TYPE_F16 : GGML_TYPE_F32; auto type = !lctx.is_encoding ? !binary_kq || flash_attn || hparams.use_alibi || (nx%32 != 0) ? GGML_TYPE_F16 : GGML_TYPE_I32 : GGML_TYPE_F32; - //auto type = flash_attn || hparams.use_alibi || (nx%32 != 0) ? GGML_TYPE_F16 : GGML_TYPE_I32; if (type == GGML_TYPE_I32) nx /= 32; lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, type, nx, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); cb(lctx.inp_KQ_mask, "KQ_mask", -1); @@ -8705,6 +8706,8 @@ struct llm_build_context { GGML_ASSERT(hparams.n_swa > 0); auto nx = causal ? n_kv : n_tokens; // Note: we only use a binary mask when nx%32 == 0 because otherwise the CUDA implementation becomes way more messy + //bool can_be_binary = binary_kq && !lctx.is_encoding && !flash_attn && !hparams.use_alibi && nx%32 == 0; + //auto type = can_be_binary ? GGML_TYPE_I32 : flash_attn ? GGML_TYPE_F16 : GGML_TYPE_F32; auto type = !lctx.is_encoding ? !binary_kq || flash_attn || hparams.use_alibi || (nx%32 != 0) ? GGML_TYPE_F16 : GGML_TYPE_I32 : GGML_TYPE_F32; if (type == GGML_TYPE_I32) nx /= 32; lctx.inp_KQ_mask_swa = ggml_new_tensor_2d(ctx0, type, nx, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));