
Commit e92f9fd

cursed hack for RNN models
1 parent 0cc0ea4 commit e92f9fd

1 file changed: +19 −3 lines changed


gpttype_adapter.cpp

Lines changed: 19 additions & 3 deletions
@@ -630,7 +630,7 @@ static void speculative_decoding_setup(std::string spec_model_filename, const ll
     {
         const llama_vocab * tmpvocab = llama_model_get_vocab(draftmodel);
         int draftvocab = llama_vocab_n_tokens(tmpvocab);
-        if(llama_model_is_recurrent(draftmodel))
+        if(llama_model_is_recurrent(draftmodel) || llama_model_is_hybrid(draftmodel))
         {
             printf("Error: Speculative decoding cannot be used with Recurrent draft models!\n");
             llama_free(draft_ctx);
@@ -2523,7 +2523,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in

     if(draftmodel_filename !="" && file_format==FileFormat::GGUF_GENERIC)
     {
-        if(llama_model_is_recurrent(llamamodel))
+        if(llama_model_is_recurrent(llamamodel) || llama_model_is_hybrid(llamamodel))
         {
             printf("Error: Speculative decoding cannot be used with Recurrent models!\n");
         }
@@ -3758,7 +3758,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
     if(file_format==FileFormat::GGUF_GENERIC)
     {
         const llama_model * mdl = llama_get_model(llama_ctx_v4);
-        if(llama_model_is_recurrent(mdl) || llama_model_is_hybrid(mdl))
+        if(llama_model_is_recurrent(mdl) || llama_model_is_hybrid(mdl) || file_format_meta.model_architecture==GGUFArch::ARCH_MAMBALIKE || file_format_meta.model_architecture==GGUFArch::ARCH_RWKV)
         {
             is_recurrent = true;
         }
@@ -3789,6 +3789,22 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
                 embd_inp.push_back(current_context_tokens[current_context_tokens.size()-1]);
                 n_past -= 1;
             }
+            else if(embd_inp.size()>0 && current_context_tokens.size()>0 && last_n_tokens.size()>0)
+            {
+                int maxedpos = llama_memory_seq_pos_max(llama_get_memory(llama_ctx_v4),0);
+                if(maxedpos+2==n_past)
+                {
+                    //kcpp: a very dirty hack for rnn models. this happens because the very last token of the last turn
+                    //does not actually get processed but is still added to current_context_tokens. if the instruct start tag starts with that same token
+                    //it might get wrongly fast forwarded and we will get an off by 1 error.
+                    //todo: figure out a better way to solve this rubbish
+                    int tail = last_n_tokens[last_n_tokens.size()-1];
+                    last_n_tokens.pop_back();
+                    current_context_tokens.pop_back();
+                    n_past -=1;
+                    embd_inp.insert(embd_inp.begin(), 1, tail);
+                }
+            }
         }
     }
     else
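
For context, here is a minimal standalone sketch (plain C++ with made-up values, not koboldcpp code) of the bookkeeping situation the new hunk corrects. The names mirror the diff, but the containers are simple int vectors, and max_state_pos stands in for the result of llama_memory_seq_pos_max. Assuming processed tokens occupy positions 0..n_past-1, a trailing token that was recorded in current_context_tokens without ever being evaluated leaves the recurrent state's highest position two behind n_past instead of one; the hack pops that phantom token from the tracking vectors, rewinds n_past, and re-queues the token at the front of embd_inp so it actually gets processed this turn.

    // Hypothetical, self-contained illustration (not koboldcpp code) of the rollback above.
    // Positions 0..2 were actually evaluated by the recurrent state, but token 99 was also
    // recorded as "in context", and the new turn's instruct tag starting with 99 was
    // fast-forwarded, leaving n_past one ahead of what the model really processed.
    #include <cstdio>
    #include <vector>

    int main()
    {
        std::vector<int> current_context_tokens = {10, 11, 12, 99}; // 99 was recorded but never evaluated
        std::vector<int> last_n_tokens          = {10, 11, 12, 99};
        std::vector<int> embd_inp               = {20, 21};         // 99 was already fast-forwarded out of the new input
        int n_past = 4;          // bookkeeping claims 4 tokens are in context
        int max_state_pos = 2;   // stand-in for llama_memory_seq_pos_max(): state only reached position 2

        if (max_state_pos + 2 == n_past) // same off-by-one signature the hack checks for
        {
            int tail = last_n_tokens.back();
            last_n_tokens.pop_back();
            current_context_tokens.pop_back();
            n_past -= 1;                                 // stop counting the phantom token as processed
            embd_inp.insert(embd_inp.begin(), 1, tail);  // re-queue it so it is evaluated this turn
        }

        printf("n_past=%d, first input token=%d\n", n_past, embd_inp.front()); // prints n_past=3, token 99
        return 0;
    }

In the ordinary case (every recorded token was evaluated) the maximum state position should be n_past - 1, so the maxedpos+2==n_past check only fires when exactly one recorded token at the tail was skipped.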
