Skip to content

Commit 5d89a48

Browse files
committed
add more rnn models supported
1 parent c7a1eec commit 5d89a48

File tree

3 files changed

+8
-13
lines changed

3 files changed

+8
-13
lines changed

gpttype_adapter.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -484,8 +484,8 @@ void ContextRewind(std::vector<int> &embd, std::vector<int> &current_context_tok
484484
printf("\nWARNING: Don't use context rewind when in batch processing phase!\n");
485485
return;
486486
}
487-
bool is_recurrent = (file_format == FileFormat::GGUF_GENERIC && (file_format_meta.model_architecture==GGUFArch::ARCH_MAMBA
488-
|| file_format_meta.model_architecture==GGUFArch::ARCH_RWKV || file_format_meta.model_architecture==GGUFArch::ARCH_JAMBA));
487+
bool is_recurrent = (file_format == FileFormat::GGUF_GENERIC && (file_format_meta.model_architecture==GGUFArch::ARCH_MAMBALIKE
488+
|| file_format_meta.model_architecture==GGUFArch::ARCH_RWKV));
489489
if(file_format == FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2 || is_recurrent)
490490
{
491491
printf("\nWARNING: RNN models do not support context rewind!\n");
@@ -3747,8 +3747,8 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
37473747
printf("%s\n", RemoveBell(outstr).c_str());
37483748
}
37493749

3750-
bool is_recurrent = (file_format == FileFormat::GGUF_GENERIC && (file_format_meta.model_architecture==GGUFArch::ARCH_MAMBA
3751-
|| file_format_meta.model_architecture==GGUFArch::ARCH_RWKV || file_format_meta.model_architecture==GGUFArch::ARCH_JAMBA));
3750+
bool is_recurrent = (file_format == FileFormat::GGUF_GENERIC && (file_format_meta.model_architecture==GGUFArch::ARCH_MAMBALIKE
3751+
|| file_format_meta.model_architecture==GGUFArch::ARCH_RWKV));
37523752
bool blank_prompt = (addedmemory=="" && kcpp_data->prompt=="");
37533753

37543754
if (file_format == FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2 || is_recurrent)

model_adapter.cpp

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -367,13 +367,9 @@ std::string gguf_get_model_arch(const std::string & gguf_filename)
367367
{
368368
fileformatmeta->model_architecture = GGUFArch::ARCH_FALCON;
369369
}
370-
else if(modelarch=="mamba")
370+
else if(modelarch=="mamba" || modelarch=="mamba2" || modelarch=="nemotron_h" || modelarch=="jamba") //lazy approach, put all RNN models
371371
{
372-
fileformatmeta->model_architecture = GGUFArch::ARCH_MAMBA;
373-
}
374-
else if(modelarch=="jamba")
375-
{
376-
fileformatmeta->model_architecture = GGUFArch::ARCH_JAMBA;
372+
fileformatmeta->model_architecture = GGUFArch::ARCH_MAMBALIKE;
377373
}
378374
else if(modelarch=="llama" && freq_base_train==10000.0f && (n_tensors==435 || n_tensors==611))
379375
{

model_adapter.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,16 +55,15 @@ enum GGUFArch
5555
ARCH_DEFAULT = 0, //used for llama3 and other generic gguf
5656
ARCH_FALCON = 1,
5757
ARCH_PHI = 2,
58-
ARCH_MAMBA = 3,
58+
ARCH_MAMBALIKE = 3,
5959
ARCH_SOLAR = 4,
6060
ARCH_QWEN2 = 5,
6161
ARCH_RWKV = 6,
6262
ARCH_QWEN2VL = 7,
6363
ARCH_GEMMA3 = 8,
6464
ARCH_GLM4 = 9,
6565
ARCH_GEMMA3N = 10,
66-
ARCH_JAMBA = 11,
67-
ARCH_GPTOSS = 12,
66+
ARCH_GPTOSS = 11,
6867
};
6968

7069
struct FileFormatExtraMeta

0 commit comments

Comments
 (0)