Commit 7c0f4ca

Merge branch 'ikawrakow:main' into main
2 parents 21354e8 + 41bdd86 · commit 7c0f4ca

4 files changed: 283 additions, 26 deletions


examples/server/server.cpp

Lines changed: 6 additions & 1 deletion
@@ -1968,7 +1968,9 @@ struct server_context {
                 slot.generated_text.erase(
                     slot.generated_text.begin() + pos + stop_pos,
                     slot.generated_text.end());
-                pos = std::min(slot.n_sent_text, slot.generated_text.size());
+                // Update n_sent_text to not exceed the new generated_text size
+                slot.n_sent_text = std::min(slot.n_sent_text, slot.generated_text.size());
+                pos = slot.n_sent_text;
             } else {
                 is_stop_full = false;
                 stop_pos = slot.find_stopping_strings(str_test, token_str.size(), false);
@@ -1980,6 +1982,9 @@ struct server_context {
             result.text_to_send = slot.generated_text.substr(pos, std::string::npos);
             slot.n_sent_text += result.text_to_send.size();
             // add the token to slot queue and cache
+        } else if (stop_pos != std::string::npos) {
+            // Handle partial stop - update n_sent_text to the end of the current text
+            slot.n_sent_text = slot.generated_text.size();
         }

         slot.add_token_string(result);
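
The removed line only clamped the local pos; after a full stop string truncated generated_text, slot.n_sent_text could still point past the end of the shortened string, and a later substr on it could throw std::out_of_range. The change clamps slot.n_sent_text itself and derives pos from it. Below is a minimal standalone sketch of that clamping idea, not the server code; the buffer contents and the "<|stop|>" marker are invented for the example.

    // Minimal sketch (hypothetical data): clamp a "bytes already sent" counter after
    // the generated text has been truncated at a stop string.
    #include <algorithm>
    #include <cassert>
    #include <string>

    int main() {
        std::string generated_text = "Hello world<|stop|> trailing text";
        std::size_t n_sent_text = 15;                  // bytes already streamed out
        std::size_t stop_pos    = generated_text.find("<|stop|>");

        // Truncate everything from the stop string onward.
        generated_text.erase(generated_text.begin() + stop_pos, generated_text.end());

        // Without this clamp, n_sent_text (15) exceeds generated_text.size() (11)
        // and the substr below would throw std::out_of_range.
        n_sent_text = std::min(n_sent_text, generated_text.size());

        std::string text_to_send = generated_text.substr(n_sent_text, std::string::npos);
        assert(text_to_send.empty());                  // nothing new left to send
        return 0;
    }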

ggml/include/ggml.h

Lines changed: 1 addition & 1 deletion
@@ -2235,7 +2235,7 @@ extern "C" {
             int   min_entries,
             float thresh);

-#define GGML_KQ_MASK_PAD 64
+#define GGML_KQ_MASK_PAD 16

    // q: [n_embd, n_batch, n_head, 1]
    // k: [n_embd, n_kv, n_head_kv, 1]
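
GGML_KQ_MASK_PAD is the multiple to which the mask's token dimension gets rounded up; the mask allocations in the next file all use GGML_PAD(n_tokens, GGML_KQ_MASK_PAD). Lowering the constant from 64 to 16 shrinks the smallest possible mask, which is most noticeable for short batches such as a single-token decode step. A small sketch of the arithmetic, assuming GGML_PAD's usual round-up-to-a-multiple definition from ggml.h:

    // Sketch only: reproduce the padding arithmetic with the assumed GGML_PAD
    // definition (round x up to the next multiple of n).
    #include <cstdio>

    #define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))

    int main() {
        const int n_tokens = 1;  // e.g. a single-token decode step
        std::printf("pad to 64: %d mask rows\n", GGML_PAD(n_tokens, 64));  // 64
        std::printf("pad to 16: %d mask rows\n", GGML_PAD(n_tokens, 16));  // 16
        return 0;
    }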

src/llama-build-context.cpp

Lines changed: 12 additions & 0 deletions
@@ -276,6 +276,12 @@ ggml_tensor * llm_build_context::build_inp_out_ids() {
 }

 ggml_tensor * llm_build_context::build_inp_KQ_mask(bool causal) {
+    if (causal && flash_attn) {
+        lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F16, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        cb(lctx.inp_KQ_mask, "KQ_mask", -1);
+        ggml_set_input(lctx.inp_KQ_mask);
+        return lctx.inp_KQ_mask;
+    }
     lctx.inp_KQ_mask = causal
         ? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv,     GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
         : ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
@@ -287,6 +293,12 @@ ggml_tensor * llm_build_context::build_inp_KQ_mask(bool causal) {

 ggml_tensor * llm_build_context::build_inp_KQ_mask_swa(bool causal) {
     GGML_ASSERT(hparams.n_swa > 0);
+    if (causal && flash_attn) {
+        lctx.inp_KQ_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F16, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        cb(lctx.inp_KQ_mask_swa, "KQ_mask_swa", -1);
+        ggml_set_input(lctx.inp_KQ_mask_swa);
+        return lctx.inp_KQ_mask_swa;
+    }

     lctx.inp_KQ_mask_swa = causal
         ? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
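
Both additions have the same shape: when the graph is causal and flash attention is enabled, the KQ mask is created directly as GGML_TYPE_F16 instead of F32, with the same n_kv x GGML_PAD(n_tokens, GGML_KQ_MASK_PAD) dimensions, named via cb(...), marked as a graph input, and returned early; the non-flash-attention path below is left unchanged. Presumably this lets the flash-attention kernels consume the mask without a separate F32-to-F16 conversion, though the commit itself does not say so. The following is a hedged standalone sketch (not code from this commit) that allocates a mask the same way using the public ggml C API; the sizes are arbitrary.

    // Standalone sketch: allocate a KQ-mask tensor the way the new branch does,
    // choosing F16 for the flash-attention path and F32 otherwise. Assumes the
    // public ggml API (ggml_init, ggml_new_tensor_2d, ggml_set_input, ggml_free)
    // and the GGML_PAD / GGML_KQ_MASK_PAD macros from ggml.h.
    #include "ggml.h"
    #include <cstdio>

    int main() {
        ggml_init_params params = {
            /*.mem_size   =*/ 16u * 1024 * 1024,
            /*.mem_buffer =*/ nullptr,
            /*.no_alloc   =*/ false,
        };
        ggml_context * ctx = ggml_init(params);

        const int64_t n_kv     = 256;  // arbitrary example sizes
        const int64_t n_tokens = 1;
        const bool causal = true, flash_attn = true;

        ggml_tensor * kq_mask = ggml_new_tensor_2d(
            ctx,
            (causal && flash_attn) ? GGML_TYPE_F16 : GGML_TYPE_F32,
            n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
        ggml_set_input(kq_mask);

        std::printf("mask: %lld x %lld, type %s\n",
                    (long long) kq_mask->ne[0], (long long) kq_mask->ne[1],
                    ggml_type_name(kq_mask->type));

        ggml_free(ctx);
        return 0;
    }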
