Skip to content

Commit 7d316fa

Browse files
committed
feat: Add llama_model_is_hybrid API call
Also, split llama_model_is_recurrent into llm_arch_is_recurrent in llama-arch with llama_model_is_recurrent delegating to llm_arch_is_recurrent. The same split is done for hybird. This is needed because there are places where the llama_model has not yet been initialized but we need to check if the model is recurrent (specifically for the per-layer recurrent check array in hparams). Branch: GraniteFour Signed-off-by: Gabe Goodhart <[email protected]>
1 parent d4e0d95 commit 7d316fa

File tree

4 files changed

+33
-8
lines changed

4 files changed

+33
-8
lines changed

include/llama.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -569,6 +569,9 @@ extern "C" {
569569
// Returns true if the model is recurrent (like Mamba, RWKV, etc.)
570570
LLAMA_API bool llama_model_is_recurrent(const struct llama_model * model);
571571

572+
// Returns true if the model is hybrid-recurrent (like Jamba, Bamba, etc.)
573+
LLAMA_API bool llama_model_is_hybrid(const struct llama_model * model);
574+
572575
// Returns 0 on success
573576
LLAMA_API uint32_t llama_model_quantize(
574577
const char * fname_inp,

src/llama-arch.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1752,3 +1752,25 @@ llm_arch llm_arch_from_string(const std::string & name) {
17521752
const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor) {
17531753
return LLM_TENSOR_INFOS.at(tensor);
17541754
}
1755+
1756+
bool llm_arch_is_recurrent(const llm_arch & arch) {
1757+
switch (arch) {
1758+
case LLM_ARCH_MAMBA:
1759+
case LLM_ARCH_RWKV6:
1760+
case LLM_ARCH_RWKV6QWEN2:
1761+
case LLM_ARCH_RWKV7:
1762+
case LLM_ARCH_ARWKV7:
1763+
return true;
1764+
default:
1765+
return false;
1766+
}
1767+
}
1768+
1769+
bool llm_arch_is_hybrid(const llm_arch & arch) {
1770+
// TODO: There are currently no hybrid models! Once there are, this will be
1771+
// the place to identify them
1772+
switch (arch) {
1773+
default:
1774+
return false;
1775+
}
1776+
}

src/llama-arch.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -436,3 +436,6 @@ const char * llm_arch_name(llm_arch arch);
436436
llm_arch llm_arch_from_string(const std::string & name);
437437

438438
const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor);
439+
440+
bool llm_arch_is_recurrent(const llm_arch& arch);
441+
bool llm_arch_is_hybrid(const llm_arch& arch);

src/llama-model.cpp

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13821,14 +13821,11 @@ llama_token llama_model_decoder_start_token(const llama_model * model) {
1382113821
}
1382213822

1382313823
bool llama_model_is_recurrent(const llama_model * model) {
13824-
switch (model->arch) {
13825-
case LLM_ARCH_MAMBA: return true;
13826-
case LLM_ARCH_RWKV6: return true;
13827-
case LLM_ARCH_RWKV6QWEN2: return true;
13828-
case LLM_ARCH_RWKV7: return true;
13829-
case LLM_ARCH_ARWKV7: return true;
13830-
default: return false;
13831-
}
13824+
return llm_arch_is_recurrent(model->arch);
13825+
}
13826+
13827+
bool llama_model_is_hybrid(const llama_model * model) {
13828+
return llm_arch_is_hybrid(model->arch);
1383213829
}
1383313830

1383413831
const std::vector<std::pair<std::string, ggml_tensor *>> & llama_internal_get_tensor_map(const llama_model * model) {

0 commit comments

Comments
 (0)