@@ -630,7 +630,7 @@ static void speculative_decoding_setup(std::string spec_model_filename, const ll
     {
         const llama_vocab * tmpvocab = llama_model_get_vocab(draftmodel);
         int draftvocab = llama_vocab_n_tokens(tmpvocab);
-        if (llama_model_is_recurrent(draftmodel))
+        if (llama_model_is_recurrent(draftmodel) || llama_model_is_hybrid(draftmodel))
         {
             printf("Error: Speculative decoding cannot be used with Recurrent draft models!\n");
             llama_free(draft_ctx);
@@ -2523,7 +2523,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
 
     if (draftmodel_filename != "" && file_format == FileFormat::GGUF_GENERIC)
     {
-        if (llama_model_is_recurrent(llamamodel))
+        if (llama_model_is_recurrent(llamamodel) || llama_model_is_hybrid(llamamodel))
         {
             printf("Error: Speculative decoding cannot be used with Recurrent models!\n");
         }
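
Note: both hunks above extend the same guard. Speculative decoding was already rejected for purely recurrent models; hybrid models (attention layers mixed with recurrent/SSM layers) also carry per-sequence recurrent state, so llama_model_is_hybrid() is now checked alongside llama_model_is_recurrent() for both the draft model and the main model. A minimal sketch of the shared condition, using only the two llama.cpp predicates that appear in the diff (the helper name is hypothetical and not part of this commit):

#include "llama.h"

// Hypothetical helper: a model is unsuitable for speculative decoding whenever it
// keeps recurrent state, which covers pure recurrent models and hybrid models alike.
static bool has_recurrent_state(const llama_model * model)
{
    return llama_model_is_recurrent(model) || llama_model_is_hybrid(model);
}
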
@@ -3758,7 +3758,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
     if (file_format == FileFormat::GGUF_GENERIC)
     {
         const llama_model * mdl = llama_get_model(llama_ctx_v4);
-        if (llama_model_is_recurrent(mdl) || llama_model_is_hybrid(mdl))
+        if (llama_model_is_recurrent(mdl) || llama_model_is_hybrid(mdl) || file_format_meta.model_architecture == GGUFArch::ARCH_MAMBALIKE || file_format_meta.model_architecture == GGUFArch::ARCH_RWKV)
         {
             is_recurrent = true;
         }
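
Note: is_recurrent previously relied only on what the loaded model reports about itself; this change additionally consults the GGUF file-format metadata, so models classified at load time as Mamba-like or RWKV are treated as recurrent even when the runtime flags do not say so. A sketch of the combined test factored into one predicate (illustrative only; GGUFArch and its ARCH_* values are taken from the diff, the helper itself is not part of the commit):

// Illustrative consolidation of the four-way test above.
static bool model_uses_recurrent_state(const llama_model * mdl, GGUFArch arch)
{
    return llama_model_is_recurrent(mdl)       // pure recurrent, reported by llama.cpp
        || llama_model_is_hybrid(mdl)          // hybrid attention + recurrent layers
        || arch == GGUFArch::ARCH_MAMBALIKE    // Mamba-like per file-format metadata
        || arch == GGUFArch::ARCH_RWKV;        // RWKV per file-format metadata
}
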
@@ -3789,6 +3789,22 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
                 embd_inp.push_back(current_context_tokens[current_context_tokens.size()-1]);
                 n_past -= 1;
             }
+            else if (embd_inp.size() > 0 && current_context_tokens.size() > 0 && last_n_tokens.size() > 0)
+            {
+                int maxedpos = llama_memory_seq_pos_max(llama_get_memory(llama_ctx_v4), 0);
+                if (maxedpos + 2 == n_past)
+                {
+                    // kcpp: a very dirty hack for rnn models. this happens because the very last token of the last turn
+                    // does not actually get processed but is still added to current_context_tokens. if the instruct start tag starts with that same token
+                    // it might get wrongly fast forwarded and we will get an off by 1 error.
+                    // todo: figure out a better way to solve this rubbish
+                    int tail = last_n_tokens[last_n_tokens.size()-1];
+                    last_n_tokens.pop_back();
+                    current_context_tokens.pop_back();
+                    n_past -= 1;
+                    embd_inp.insert(embd_inp.begin(), 1, tail);
+                }
+            }
         }
     }
     else
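
Note: the new else-if branch handles a bookkeeping mismatch specific to recurrent-style models. According to the comment in the hunk, the final token of the previous turn is appended to current_context_tokens (and counted in n_past) without ever being decoded, so llama_memory_seq_pos_max() reports one position fewer than expected (maxedpos + 2 == n_past with zero-based positions). If the next prompt starts with that same token, context fast-forwarding skips it and decoding ends up off by one. The rollback pops the undecoded token out of the history, decrements n_past, and prepends the token to embd_inp so it is actually processed this turn. A standalone sketch of the same rollback, with the llama.cpp call replaced by a plain integer (everything below is illustrative, not kcpp code):

#include <cstdio>
#include <vector>

int main()
{
    // Previous turn ended with token 13; it was recorded in the history but never decoded.
    std::vector<int> current_context_tokens = {10, 11, 12, 13};
    std::vector<int> last_n_tokens          = {10, 11, 12, 13};
    int n_past          = 4;  // bookkeeping claims four positions are done
    int decoded_max_pos = 2;  // stand-in for llama_memory_seq_pos_max(): only positions 0..2 were decoded

    // New turn: fast-forwarding also matched token 13, so only genuinely new tokens remain queued.
    std::vector<int> embd_inp = {20, 21};

    if (decoded_max_pos + 2 == n_past) // one-token gap between bookkeeping and decoded state
    {
        int tail = last_n_tokens.back();
        last_n_tokens.pop_back();
        current_context_tokens.pop_back();
        n_past -= 1;
        embd_inp.insert(embd_inp.begin(), 1, tail); // re-queue the undecoded token
    }

    printf("n_past=%d, queued: %d %d %d\n", n_past, embd_inp[0], embd_inp[1], embd_inp[2]);
    // expected output: n_past=3, queued: 13 20 21
    return 0;
}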