62 changes: 62 additions & 0 deletions convert_hf_to_gguf.py
@@ -9219,6 +9219,68 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter

return [] # skip other tensors


@ModelBase.register("Glm4vMoeForConditionalGeneration")
class GLM4VMoEModel(Glm4MoeModel):
"""Text model from [zai-org/GLM-4.5V](https://huggingface.co/zai-org/GLM-4.5V)

ref: [#16600](https://github.com/ggml-org/llama.cpp/pull/16600)"""
model_arch = gguf.MODEL_ARCH.GLM4_MOE

def set_gguf_parameters(self):
# parameters specific to GLM-4.5V like rope_theta=10000 and context_length=65536
# should be correctly picked up from the text_config by the base classes
super().set_gguf_parameters()

def modify_tensors(
self, data_torch: Tensor, name: str, bid: int | None
) -> Iterable[tuple[str, Tensor]]:
# skip vision tensors for the text model
if name.startswith("model.visual."):
return []

# the Glm4MoeModel class expects tensor names to start with 'model.',
# so we strip the 'language_model.' prefix
if name.startswith("model.language_model."):
name = name.replace("model.language_model.", "model.", 1)

# let the parent class handle the MoE logic and tensor mapping
yield from super().modify_tensors(data_torch, name, bid)


@ModelBase.register("Glm4vMoeForConditionalGeneration")
class GLM4VMoEVisionModel(MmprojModel):
"""Multimodal projector from [zai-org/GLM-4.5V](https://huggingface.co/zai-org/GLM-4.5V).

ref: [#16600](https://github.com/ggml-org/llama.cpp/pull/16600)"""
#
# TODO: this is not complete yet!
#
def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.GLM4V)
self.gguf_writer.add_vision_use_gelu(True)

if (ln_eps := self.find_vparam(["layer_norm_eps"], optional=True)) is not None:
self.gguf_writer.add_vision_attention_layernorm_eps(ln_eps)

# the ViT in GLM-4.5V applies its own RoPE inside its attention blocks
if (rope_theta := self.find_vparam(["rope_theta"], optional=True)) is not None:
self.gguf_writer.add_vision_rope_freq_base(rope_theta)
logger.info(f"gguf: vision rope theta = {rope_theta}")
else:
logger.warning('gguf: -------------------------------------------------------------')
logger.warning('gguf: missing vision rope theta! the conversion might be incorrect!')
logger.warning('gguf: -------------------------------------------------------------')

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused
if name.startswith("model.visual."):
yield self.map_tensor_name(name), data_torch
else:
return


###### CONVERSION LOGIC ######


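For reference, a minimal standalone sketch (not part of the PR) of the tensor-name handling that GLM4VMoEModel.modify_tensors applies; the example tensor names are hypothetical:

def remap_glm4v_text_tensor(name: str) -> str | None:
    # vision tensors are skipped here; GLM4VMoEVisionModel handles them instead
    if name.startswith("model.visual."):
        return None
    # strip the 'language_model.' prefix so the Glm4MoeModel base class sees 'model.*' names
    if name.startswith("model.language_model."):
        name = name.replace("model.language_model.", "model.", 1)
    return name

# hypothetical GLM-4.5V tensor names, for illustration only
assert remap_glm4v_text_tensor("model.visual.patch_embed.proj.weight") is None
assert (remap_glm4v_text_tensor("model.language_model.layers.0.self_attn.q_proj.weight")
        == "model.layers.0.self_attn.q_proj.weight")
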
51 changes: 47 additions & 4 deletions gguf-py/gguf/constants.py
@@ -276,6 +276,21 @@ class ClipVision:
USE_SILU = "clip.use_silu"
N_WA_PATTERN = "clip.vision.n_wa_pattern" # used by qwen2.5vl

class Rope:
DIMENSION_COUNT = "clip.vision.rope.dimension_count"
DIMENSION_SECTIONS = "clip.vision.rope.dimension_sections"
FREQ_BASE = "clip.vision.rope.freq_base"
SCALING_TYPE = "clip.vision.rope.scaling.type"
SCALING_FACTOR = "clip.vision.rope.scaling.factor"
SCALING_ATTN_FACTOR = "clip.vision.rope.scaling.attn_factor"
SCALING_ORIG_CTX_LEN = "clip.vision.rope.scaling.original_context_length"
SCALING_FINETUNED = "clip.vision.rope.scaling.finetuned"
SCALING_YARN_LOG_MUL = "clip.vision.rope.scaling.yarn_log_multiplier"
SCALING_YARN_EXT_FACTOR = "clip.vision.rope.scaling.yarn_ext_factor"
SCALING_YARN_ATTN_FACTOR = "clip.vision.rope.scaling.yarn_attn_factor"
SCALING_YARN_BETA_FAST = "clip.vision.rope.scaling.yarn_beta_fast"
SCALING_YARN_BETA_SLOW = "clip.vision.rope.scaling.yarn_beta_slow"

class Attention:
HEAD_COUNT = "clip.vision.attention.head_count"
LAYERNORM_EPS = "clip.vision.attention.layer_norm_epsilon"
@@ -385,6 +400,7 @@ class MODEL_ARCH(IntEnum):
CHATGLM = auto()
GLM4 = auto()
GLM4_MOE = auto()
GLM4V_MOE = auto()
BITNET = auto()
T5 = auto()
T5ENCODER = auto()
@@ -427,6 +443,7 @@ class VISION_PROJECTOR_TYPE(IntEnum):
GLM_EDGE = auto()
MERGER = auto()
GEMMA3 = auto()
GLM4V = auto()


class MODEL_TENSOR(IntEnum):
@@ -656,10 +673,10 @@ class MODEL_TENSOR(IntEnum):
A_MM_NORM_PRE = auto()
A_MM_NORM_MID = auto()
# nextn/mtp
NEXTN_EH_PROJ = auto()
NEXTN_EMBED_TOKENS = auto()
NEXTN_ENORM = auto()
NEXTN_HNORM = auto()
NEXTN_EH_PROJ = auto()
NEXTN_EMBED_TOKENS = auto()
NEXTN_ENORM = auto()
NEXTN_HNORM = auto()
NEXTN_SHARED_HEAD_HEAD = auto()
NEXTN_SHARED_HEAD_NORM = auto()

@@ -729,6 +746,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.CHATGLM: "chatglm",
MODEL_ARCH.GLM4: "glm4",
MODEL_ARCH.GLM4_MOE: "glm4moe",
MODEL_ARCH.GLM4V_MOE: "glm4v_moe",
MODEL_ARCH.BITNET: "bitnet",
MODEL_ARCH.T5: "t5",
MODEL_ARCH.T5ENCODER: "t5encoder",
@@ -2273,6 +2291,30 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD,
MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM,
],
MODEL_ARCH.GLM4V_MOE: [ # same as GLM4_MOE without MTP tensors
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_POST_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_K_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
MODEL_TENSOR.FFN_GATE_SHEXP,
MODEL_TENSOR.FFN_DOWN_SHEXP,
MODEL_TENSOR.FFN_UP_SHEXP,
MODEL_TENSOR.FFN_EXP_PROBS_B,
],
MODEL_ARCH.BITNET: [
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
@@ -3029,6 +3071,7 @@ class VisionProjectorType:
VOXTRAL = "voxtral"
LFM2 = "lfm2"
KIMIVL = "kimivl"
GLM4V = "glm4v_moe"


# Items here are (block size, type size)
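As a cross-check of the comment above ("same as GLM4_MOE without MTP tensors"), the new GLM4V_MOE entry should equal the GLM4_MOE tensor list with the NEXTN_* (nextn/mtp) tensors removed. A hypothetical sanity check, assuming the two lists differ only in those entries and keep the same ordering:

from gguf.constants import MODEL_ARCH, MODEL_TENSORS

glm4_moe_without_mtp = [
    t for t in MODEL_TENSORS[MODEL_ARCH.GLM4_MOE]
    if not t.name.startswith("NEXTN_")
]
assert MODEL_TENSORS[MODEL_ARCH.GLM4V_MOE] == glm4_moe_without_mtp
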
3 changes: 3 additions & 0 deletions gguf-py/gguf/gguf_writer.py
@@ -1038,6 +1038,9 @@ def add_vision_head_count(self, value: int) -> None:
def add_vision_attention_layernorm_eps(self, value: float) -> None:
self.add_float32(Keys.ClipVision.Attention.LAYERNORM_EPS, value)

def add_vision_rope_freq_base(self, value: float) -> None:
self.add_float32(Keys.ClipVision.Rope.FREQ_BASE, value)

def add_vision_image_size(self, value: int) -> None:
self.add_uint32(Keys.ClipVision.IMAGE_SIZE, value)

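A minimal usage sketch of the new add_vision_rope_freq_base method, following the usual GGUFWriter write sequence; the output path, arch string, and theta value here are hypothetical, and the key string comes from Keys.ClipVision.Rope.FREQ_BASE in constants.py:

from gguf import GGUFWriter

writer = GGUFWriter("mmproj-glm4v.gguf", arch="clip")  # hypothetical output file
writer.add_vision_rope_freq_base(10000.0)  # written as float32 under "clip.vision.rope.freq_base"
writer.write_header_to_file()
writer.write_kv_data_to_file()
writer.close()
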
28 changes: 28 additions & 0 deletions src/llama-arch.cpp
@@ -66,6 +66,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_CHATGLM, "chatglm" },
{ LLM_ARCH_GLM4, "glm4" },
{ LLM_ARCH_GLM4_MOE, "glm4moe" },
{ LLM_ARCH_GLM4V_MOE, "glm4v_moe" },
{ LLM_ARCH_BITNET, "bitnet" },
{ LLM_ARCH_T5, "t5" },
{ LLM_ARCH_T5ENCODER, "t5encoder" },
@@ -1507,6 +1508,33 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "blk.%d.nextn.shared_head_norm" },
},
},
{
LLM_ARCH_GLM4V_MOE,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
{ LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
},
},
{
LLM_ARCH_BITNET,
{
2 changes: 1 addition & 1 deletion src/llama-arch.h
@@ -70,6 +70,7 @@ enum llm_arch {
LLM_ARCH_CHATGLM,
LLM_ARCH_GLM4,
LLM_ARCH_GLM4_MOE,
LLM_ARCH_GLM4V_MOE,
LLM_ARCH_BITNET,
LLM_ARCH_T5,
LLM_ARCH_T5ENCODER,
@@ -123,7 +124,6 @@ enum llm_kv {
LLM_KV_GENERAL_LICENSE,
LLM_KV_GENERAL_SOURCE_URL,
LLM_KV_GENERAL_SOURCE_HF_REPO,

LLM_KV_VOCAB_SIZE,
LLM_KV_CONTEXT_LENGTH,
LLM_KV_EMBEDDING_LENGTH,
4 changes: 2 additions & 2 deletions src/llama-graph.cpp
@@ -817,7 +817,7 @@ ggml_tensor * llm_graph_context::build_ffn(

if (down) {
cur = build_lora_mm(down, cur);
if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE || arch == LLM_ARCH_GLM4V_MOE) {
// GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
}
@@ -1583,7 +1583,7 @@ ggml_tensor * llm_graph_context::build_attn(

if (wo) {
cur = build_lora_mm(wo, cur);
if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE || arch == LLM_ARCH_GLM4V_MOE) {
// GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
}
20 changes: 20 additions & 0 deletions src/llama-model.cpp
@@ -1611,6 +1611,10 @@ void llama_model::load_hparams(llama_model_loader & ml) {
default: type = LLM_TYPE_UNKNOWN;
}
} break;
case LLM_ARCH_GLM4V_MOE:
{
// TODO
} break;
case LLM_ARCH_BITNET:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@@ -4892,6 +4896,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
}
}
break;
case LLM_ARCH_GLM4V_MOE:
{
// TODO
}
break;
case LLM_ARCH_NEMOTRON:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -14683,6 +14692,12 @@ struct llm_build_glm4_moe : public llm_graph_context {
}
};

struct llm_build_glm4v_moe : public llm_graph_context {
llm_build_glm4v_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
// TODO
}
};

struct llm_build_nemotron : public llm_graph_context {
llm_build_nemotron(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
@@ -19750,6 +19765,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
{
llm = std::make_unique<llm_build_glm4_moe>(*this, params);
} break;
case LLM_ARCH_GLM4V_MOE:
{
llm = std::make_unique<llm_build_glm4_moe>(*this, params);
} break;
case LLM_ARCH_BITNET:
{
llm = std::make_unique<llm_build_bitnet>(*this, params);
@@ -20119,6 +20138,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
return LLAMA_ROPE_TYPE_NEOX;

case LLM_ARCH_QWEN2VL:
case LLM_ARCH_GLM4V_MOE:
return LLAMA_ROPE_TYPE_MROPE;

// all model arches should be listed explicitly here
2 changes: 1 addition & 1 deletion src/llama-model.h
@@ -110,7 +110,7 @@ enum llm_type {
LLM_TYPE_8B_A1B, // lfm2moe
LLM_TYPE_21B_A3B, // Ernie MoE small
LLM_TYPE_30B_A3B,
LLM_TYPE_106B_A12B, // GLM-4.5-Air
LLM_TYPE_106B_A12B, // GLM-4.5-Air (and GLM-4.5V text model)
LLM_TYPE_235B_A22B,
LLM_TYPE_300B_A47B, // Ernie MoE big
LLM_TYPE_355B_A32B, // GLM-4.5