Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 168 additions & 26 deletions convert_hf_to_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -4250,9 +4250,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
yield from super().modify_tensors(data_torch, name, bid)


@ModelBase.register("Qwen2_5OmniModel")
class Qwen25OmniModel(Qwen2VLVisionModel):
has_vision_encoder = True
class Qwen25AudioModel(MmprojModel):
has_audio_encoder = True

def __init__(self, *args, **kwargs):
Expand All @@ -4268,12 +4266,6 @@ def set_gguf_parameters(self):
self.gguf_writer.add_audio_num_mel_bins(self.hparams_audio["num_mel_bins"])
self.gguf_writer.add_audio_attention_layernorm_eps(self.hparams_audio.get("layer_norm_eps", 1e-5))

def get_vision_config(self) -> dict[str, Any] | None:
return self.global_config["thinker_config"].get("vision_config")

def get_audio_config(self) -> dict[str, Any] | None:
return self.global_config["thinker_config"].get("audio_config")

def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
# SinusoidsPositionEmbedding
assert self.hparams_audio is not None
Expand Down Expand Up @@ -4303,8 +4295,33 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
if "audio_bos_eos_token" in name:
# this tensor is left unused in transformers code
# https://github.com/huggingface/transformers/blob/6e3063422c4b1c014aa60c32b9254fd2902f0f28/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py#L1809
return
yield from super().modify_tensors(data_torch, name, bid)
return []
return [(self.map_tensor_name(name), data_torch)]

return [] # skip other tensors
Comment on lines +4298 to +4301
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return []
return [(self.map_tensor_name(name), data_torch)]
return [] # skip other tensors
return
yield from super().modify_tensors(data_torch, name, bid)
return # skip other tensors



@ModelBase.register("Qwen2_5OmniModel")
class Qwen25OmniModel(Qwen2VLVisionModel, Qwen25AudioModel):
    """Converter for Qwen2.5-Omni checkpoints carrying both a vision and an audio encoder.

    Vision tensors are delegated to ``Qwen2VLVisionModel`` and audio tensors to
    ``Qwen25AudioModel``; every other tensor is dropped by ``modify_tensors``.
    """

    has_audio_encoder = True
    has_vision_encoder = True

    def get_vision_config(self) -> dict[str, Any] | None:
        """Return the ``vision_config`` entry of ``thinker_config``, or None if absent."""
        return self.global_config["thinker_config"].get("vision_config")

    def get_audio_config(self) -> dict[str, Any] | None:
        """Return the ``audio_config`` entry of ``thinker_config``, or None if absent."""
        return self.global_config["thinker_config"].get("audio_config")

    def set_gguf_parameters(self):
        """Write the inherited parameters, then tag the projector as the omni type."""
        super().set_gguf_parameters()
        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN25O)

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        """Route a checkpoint tensor to the matching encoder's conversion logic.

        Yields the ``(new_name, tensor)`` pairs produced by the vision or audio
        base class. Tensors whose name contains neither ``"visual."`` nor
        ``"audio_tower."`` yield nothing, i.e. they are skipped.
        """
        # The base classes are called explicitly (not via super()) so that each
        # modality is handled by its own implementation regardless of the MRO.
        if "visual." in name:
            yield from Qwen2VLVisionModel.modify_tensors(self, data_torch, name, bid)
        elif "audio_tower." in name:
            yield from Qwen25AudioModel.modify_tensors(self, data_torch, name, bid)
        # Any other tensor is skipped: this is a generator, so simply falling
        # off the end yields nothing (a list-valued ``return`` would never be
        # seen by callers iterating over the result).
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return [] # skip other tensors
return # skip other tensors



@ModelBase.register("InternVisionModel")
Expand Down Expand Up @@ -4808,7 +4825,10 @@ def set_gguf_parameters(self):
class Qwen3VLVisionModel(MmprojModel):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
assert self.hparams_vision is not None
if self.hparams_vision is None:
logger.info("No vision config found, skipping vision tensor processing")
return

# Compute image_size if not present
if "image_size" not in self.hparams_vision:
# For Qwen3VL/Qwen3VLMoe, compute from num_position_embeddings
Expand All @@ -4829,7 +4849,9 @@ def __init__(self, *args, **kwargs):

def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN3VL)
# in case mixed modalities, the arch will be handled by subclass
if not self.has_audio_encoder:
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN3VL)
self.gguf_writer.add_vision_use_gelu(True)

if self.hparams_vision is not None:
Expand Down Expand Up @@ -4917,11 +4939,64 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
return

if name.startswith("visual."):
yield from super().modify_tensors(data_torch, name, bid)
return
yield (self.map_tensor_name(name), data_torch)
return [] # skip other tensors
Comment on lines +4942 to +4943
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
yield (self.map_tensor_name(name), data_torch)
return [] # skip other tensors
yield from super().modify_tensors(data_torch, name, bid)
return # skip other tensors


# Fall back to parent class for other tensors
yield from super().modify_tensors(data_torch, name, bid)

@ModelBase.register("Qwen3OmniMoeForConditionalGeneration")
class Qwen3OmniMmprojModel(Qwen3VLVisionModel, Qwen25AudioModel):
    """Multimodal-projector converter combining the Qwen3-VL vision path and the Qwen2.5 audio path.

    Subclasses may switch off either modality through ``has_vision_encoder`` /
    ``has_audio_encoder``; every method below honours those two flags.
    """

    has_audio_encoder = True
    has_vision_encoder = True

    def get_vision_config(self) -> dict[str, Any] | None:
        """Return ``thinker_config["vision_config"]``, or None when vision is disabled or absent."""
        if not self.has_vision_encoder:
            return None
        return self.global_config["thinker_config"].get("vision_config")

    def get_audio_config(self) -> dict[str, Any] | None:
        """Return ``thinker_config["audio_config"]``, or None when audio is disabled or absent."""
        if not self.has_audio_encoder:
            return None
        return self.global_config["thinker_config"].get("audio_config")

    def set_gguf_parameters(self):
        """Write the parameters of each enabled modality plus its own projector type."""
        writer = self.gguf_writer
        if self.has_vision_encoder:
            Qwen3VLVisionModel.set_gguf_parameters(self)
            writer.add_clip_vision_projector_type(gguf.VisionProjectorType.QWEN3VL)
        if self.has_audio_encoder:
            Qwen25AudioModel.set_gguf_parameters(self)
            writer.add_clip_audio_projector_type(gguf.VisionProjectorType.QWEN3A)

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        """Rename/reshape one checkpoint tensor and hand it to the matching base class.

        Raises:
            ValueError: if a vision (or audio) tensor is encountered while the
                corresponding encoder is disabled on this class.

        Tensors that are neither vision nor audio tensors yield nothing.
        """
        if "visual." in name:
            if not self.has_vision_encoder:
                raise ValueError(f"Model does not have vision encoder, but found tensor {name}")
            # Rewrite the checkpoint naming into the form the Qwen3-VL
            # modify_tensors() logic expects.
            name = name.replace("thinker.visual.", "model.visual.")
            is_deepstack_merger = ".merger_list." in name
            if is_deepstack_merger:
                name = name.replace(".merger_list.", ".deepstack_merger_list.")
            if is_deepstack_merger or ".merger." in name:
                # Both merger flavours share the same sub-module renames.
                merger_renames = (
                    (".ln_q", ".norm"),
                    (".mlp.0", ".linear_fc1"),
                    (".mlp.2", ".linear_fc2"),
                )
                for old_part, new_part in merger_renames:
                    name = name.replace(old_part, new_part)
            yield from Qwen3VLVisionModel.modify_tensors(self, data_torch, name, bid)
            return

        if "audio_tower." in name:
            if not self.has_audio_encoder:
                raise ValueError(f"Model does not have audio encoder, but found tensor {name}")
            is_conv2d_bias = "conv2d" in name and name.endswith(".bias")
            if is_conv2d_bias:
                # Append two trailing singleton dims to the conv2d bias.
                # NOTE(review): assumes a 1-D [n_embd] bias; the resulting
                # [1, 1, n_embd] is presumably in ggml dimension order - confirm.
                data_torch = data_torch.unsqueeze(-1).unsqueeze(-1)
            yield from Qwen25AudioModel.modify_tensors(self, data_torch, name, bid)


@ModelBase.register("Qwen3ASRForConditionalGeneration")
class Qwen3ASRMmprojModel(Qwen3OmniMmprojModel):
    """Audio-only variant: reuses Qwen3OmniMmprojModel with the vision encoder switched off."""

    has_vision_encoder = False
    has_audio_encoder = True


@ModelBase.register("Glm4vForConditionalGeneration", "Glm4vMoeForConditionalGeneration", "GlmOcrForConditionalGeneration")
Expand Down Expand Up @@ -4955,9 +5030,10 @@ class Qwen3VLTextModel(Qwen3Model):

def set_gguf_parameters(self):
super().set_gguf_parameters()

# Handle MRoPE (Multi-axis Rotary Position Embedding) for Qwen3-VL
vision_config = self.hparams.get("vision_config", {})
if "thinker_config" in self.hparams:
vision_config = self.hparams["thinker_config"].get("vision_config", {})
else:
vision_config = self.hparams.get("vision_config", {})
Comment on lines +5033 to +5036
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of handling this everywhere, can't we just merge in all sub-configs in thinker_config here:

if "thinker_config" in config:
# rename for Qwen2.5-Omni
config["text_config"] = config["thinker_config"]["text_config"]

deepstack_layer_num = len(vision_config.get("deepstack_visual_indexes", []))
self.gguf_writer.add_num_deepstack_layers(deepstack_layer_num)

Expand All @@ -4969,20 +5045,66 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
yield from super().modify_tensors(data_torch, name, bid)


@ModelBase.register("Qwen3VLMoeForConditionalGeneration")
@ModelBase.register("Qwen3ASRForConditionalGeneration")
class Qwen3ASRTextModel(Qwen3VLTextModel):
    """Text-model converter for Qwen3-ASR checkpoints, emitted with the QWEN3VL architecture."""

    model_arch = gguf.MODEL_ARCH.QWEN3VL

    def set_gguf_parameters(self):
        """Write the inherited parameters, then record zero deepstack layers."""
        super().set_gguf_parameters()
        # NOTE(review): presumably overrides the value written by the parent
        # class because this model has no deepstack layers - confirm.
        self.gguf_writer.add_num_deepstack_layers(0)

    def set_vocab(self):
        """Write the inherited vocab, a ChatML chat template and the BOS/EOS token ids."""
        super().set_vocab()
        # fix chat template, use correct chatml format
        chatml_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}"
        self.gguf_writer.add_chat_template(chatml_template)

        # correct BOS/EOS tokens: both are set to the id of "<|im_end|>"
        tokenizer_config_path = self.dir_model / "tokenizer_config.json"
        with open(tokenizer_config_path, "r", encoding="utf-8") as f:
            added_tokens = json.load(f).get("added_tokens_decoder", {})
        for token_id, data in added_tokens.items():
            if data.get("content") != "<|im_end|>":
                continue
            im_end_id = int(token_id)
            self.gguf_writer.add_bos_token_id(im_end_id)
            self.gguf_writer.add_eos_token_id(im_end_id)

    def modify_tensors(self, data_torch, name, bid):
        """Strip the ``thinker.`` prefix and forward only text-model tensors to the parent."""
        # qwen3-omni
        name = name.replace("thinker.", "")

        # Skip vision and audio tensors - they go in the mmproj file.
        # Talker and code2wav tensors are skipped here as well.
        skipped_markers = ("visual.", "audio_tower.", "talker.", "code2wav.")
        if any(marker in name for marker in skipped_markers):
            return

        yield from super().modify_tensors(data_torch, name, bid)


@ModelBase.register("Qwen3VLMoeForConditionalGeneration", "Qwen3OmniMoeForConditionalGeneration")
class Qwen3VLMoeTextModel(Qwen3MoeModel):
model_arch = gguf.MODEL_ARCH.QWEN3VLMOE

def set_vocab(self):
super().set_vocab()
# correct BOS/EOS tokens
with open(self.dir_model / "tokenizer_config.json", "r", encoding="utf-8") as f:
tokenizer_config = json.load(f)
added_tokens = tokenizer_config.get("added_tokens_decoder", {})
for token_id, data in added_tokens.items():
if data.get("content") == "<|im_end|>":
self.gguf_writer.add_bos_token_id(int(token_id))
self.gguf_writer.add_eos_token_id(int(token_id))

def set_gguf_parameters(self):
super().set_gguf_parameters()
vision_config = self.hparams.get("vision_config", {})
deepstack_layer_num = len(vision_config.get("deepstack_visual_indexes", []))
self.gguf_writer.add_num_deepstack_layers(deepstack_layer_num)
self.gguf_writer.add_num_deepstack_layers(0)

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# Skip vision tensors - they go in the mmproj file
if name.startswith("model.visual."):
return
if "visual." in name or "audio_tower." in name \
or "talker." in name or "code2wav." in name:
return []
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return []
return


# qwen3-omni
name = name.replace("thinker.", "")

# Qwen3VL has transposed packed tensors, so we treat it differently from general Qwen2MoE packed tensors
if name.endswith("mlp.experts.down_proj") or name.endswith("mlp.experts.down_proj.weight"):
Expand Down Expand Up @@ -5016,6 +5138,26 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
yield from super().modify_tensors(data_torch, name, bid)


@ModelBase.register("Qwen3OmniMoeForConditionalGeneration")
class Qwen3OmniMoeTextModel(Qwen3VLMoeTextModel):
    """Text-model converter for Qwen3-Omni MoE checkpoints, emitted with the QWEN3VLMOE architecture.

    NOTE(review): the parent class as shown in this change appears to be
    registered for the same architecture name and to apply the same BOS/EOS
    and deepstack handling - confirm whether these overrides are still needed.
    """

    model_arch = gguf.MODEL_ARCH.QWEN3VLMOE

    def set_vocab(self):
        """Write the inherited vocab, then set BOS and EOS to the id of ``<|im_end|>``."""
        super().set_vocab()
        # correct BOS/EOS tokens
        tokenizer_config_path = self.dir_model / "tokenizer_config.json"
        with open(tokenizer_config_path, "r", encoding="utf-8") as f:
            added_tokens = json.load(f).get("added_tokens_decoder", {})
        for token_id, data in added_tokens.items():
            if data.get("content") != "<|im_end|>":
                continue
            im_end_id = int(token_id)
            self.gguf_writer.add_bos_token_id(im_end_id)
            self.gguf_writer.add_eos_token_id(im_end_id)

    def set_gguf_parameters(self):
        """Write the inherited parameters, then record zero deepstack layers."""
        super().set_gguf_parameters()
        self.gguf_writer.add_num_deepstack_layers(0)


class _LinearAttentionVReorderBase(Qwen3NextModel):
model_arch = gguf.MODEL_ARCH.QWEN3NEXT # overridden by subclasses
"""reorders V heads from grouped to tiled order for ggml broadcast
Expand Down
7 changes: 7 additions & 0 deletions gguf-py/gguf/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -783,6 +783,8 @@ class MODEL_TENSOR(IntEnum):
A_ENC_EMBD_TO_LOGITS = auto() # lfm2
A_ENC_CONV1D = auto()
A_ENC_CONV1D_NORM = auto() # gemma3n
A_ENC_CONV2D = auto()
A_ENC_CONV_OUT = auto()
A_PRE_NORM = auto()
A_POST_NORM = auto()
A_ENC_LAYER_PRE_NORM = auto() # gemma3n
Expand Down Expand Up @@ -1244,6 +1246,8 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.A_ENC_EMBD_NORM: "a.position_embd_norm",
MODEL_TENSOR.A_ENC_EMBD_TO_LOGITS: "a.embd_to_logits",
MODEL_TENSOR.A_ENC_CONV1D: "a.conv1d.{bid}",
MODEL_TENSOR.A_ENC_CONV2D: "a.conv2d.{bid}",
MODEL_TENSOR.A_ENC_CONV_OUT: "a.conv_out",
MODEL_TENSOR.A_ENC_CONV1D_NORM: "a.conv1d.{bid}.norm",
MODEL_TENSOR.A_PRE_NORM: "a.pre_ln",
MODEL_TENSOR.A_POST_NORM: "a.post_ln",
Expand Down Expand Up @@ -1376,6 +1380,8 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.A_ENC_EMBD_NORM,
MODEL_TENSOR.A_ENC_EMBD_TO_LOGITS,
MODEL_TENSOR.A_ENC_CONV1D,
MODEL_TENSOR.A_ENC_CONV2D,
MODEL_TENSOR.A_ENC_CONV_OUT,
MODEL_TENSOR.A_ENC_CONV1D_NORM,
MODEL_TENSOR.A_PRE_NORM,
MODEL_TENSOR.A_POST_NORM,
Expand Down Expand Up @@ -4020,6 +4026,7 @@ class VisionProjectorType:
ULTRAVOX = "ultravox"
INTERNVL = "internvl"
QWEN2A = "qwen2a" # audio
QWEN3A = "qwen3a" # audio
GLMA = "glma" # audio
QWEN25O = "qwen2.5o" # omni
VOXTRAL = "voxtral"
Expand Down
11 changes: 10 additions & 1 deletion gguf-py/gguf/tensor_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -1788,6 +1788,14 @@ class TensorNameMap:
"model.audio_tower.subsample_conv_projection.conv_{bid}.norm", # gemma3n
),

MODEL_TENSOR.A_ENC_CONV2D: (
"audio_tower.conv2d{bid}", # qwen3omni
),

MODEL_TENSOR.A_ENC_CONV_OUT: (
"audio_tower.conv_out", # qwen3omni
),

MODEL_TENSOR.A_PRE_NORM: (),

MODEL_TENSOR.A_POST_NORM: (
Expand Down Expand Up @@ -1912,7 +1920,8 @@ class TensorNameMap:

MODEL_TENSOR.A_MMPROJ: (
"audio.multi_modal_projector.linear_{bid}", # ultravox
"audio_adapter.model.{bid}" # lfm2
"audio_adapter.model.{bid}", # lfm2
"audio_tower.proj{bid}", # qwen3omni
),

MODEL_TENSOR.A_MMPROJ_FC: (
Expand Down
1 change: 1 addition & 0 deletions tools/mtmd/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ add_library(mtmd
models/pixtral.cpp
models/qwen2vl.cpp
models/qwen3vl.cpp
models/qwen3a.cpp
models/siglip.cpp
models/whisper-enc.cpp
models/deepseekocr.cpp
Expand Down
4 changes: 4 additions & 0 deletions tools/mtmd/clip-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,8 @@

// ultravox
#define TN_CONV1D "a.conv1d.%d.%s"
#define TN_CONV2D "a.conv2d.%d.%s"
#define TN_CONV_OUT "a.conv_out.%s"
#define TN_MM_AUDIO_MLP "mm.a.mlp.%d.%s"
#define TN_MM_AUDIO_FC "mm.a.fc.%s" // fully connected layer
#define TN_MM_NORM_PRE "mm.a.norm_pre.%s"
Expand Down Expand Up @@ -241,6 +243,7 @@ enum projector_type {
PROJECTOR_TYPE_INTERNVL,
PROJECTOR_TYPE_LLAMA4,
PROJECTOR_TYPE_QWEN2A,
PROJECTOR_TYPE_QWEN3A,
PROJECTOR_TYPE_GLMA,
PROJECTOR_TYPE_QWEN25O, // will be replaced by QWEN2A or QWEN25VL depending on clip_ctx
PROJECTOR_TYPE_VOXTRAL,
Expand Down Expand Up @@ -279,6 +282,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
{ PROJECTOR_TYPE_INTERNVL, "internvl"},
{ PROJECTOR_TYPE_LLAMA4, "llama4"},
{ PROJECTOR_TYPE_QWEN2A, "qwen2a"},
{ PROJECTOR_TYPE_QWEN3A, "qwen3a"},
{ PROJECTOR_TYPE_GLMA, "glma"},
{ PROJECTOR_TYPE_QWEN25O, "qwen2.5o"},
{ PROJECTOR_TYPE_VOXTRAL, "voxtral"},
Expand Down
10 changes: 10 additions & 0 deletions tools/mtmd/clip-model.h
Original file line number Diff line number Diff line change
Expand Up @@ -401,10 +401,20 @@ struct clip_model {
ggml_tensor * conv1d_1_b = nullptr;
ggml_tensor * conv1d_2_w = nullptr;
ggml_tensor * conv1d_2_b = nullptr;
ggml_tensor * conv_out_w = nullptr;
ggml_tensor * conv_out_b = nullptr;
ggml_tensor * mm_norm_pre_w = nullptr;
ggml_tensor * mm_norm_pre_b = nullptr;
ggml_tensor * mm_norm_mid_w = nullptr;

// qwen3a
ggml_tensor * conv2d_1_w = nullptr;
ggml_tensor * conv2d_1_b = nullptr;
ggml_tensor * conv2d_2_w = nullptr;
ggml_tensor * conv2d_2_b = nullptr;
ggml_tensor * conv2d_3_w = nullptr;
ggml_tensor * conv2d_3_b = nullptr;

// cogvlm
ggml_tensor * mm_post_fc_norm_w = nullptr;
ggml_tensor * mm_post_fc_norm_b = nullptr;
Expand Down
Loading
Loading