Commit c44d55a

merge
2 parents: b66c2fd + b52edd2


47 files changed (+2512 additions, -483 deletions)

common/arg.cpp

Lines changed: 1 addition & 1 deletion
@@ -3203,7 +3203,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
     ).set_examples({LLAMA_EXAMPLE_IMATRIX}));
     add_opt(common_arg(
         {"--parse-special"},
-        string_format("prase special tokens (chat, tool, etc) (default: %s)", params.parse_special ? "true" : "false"),
+        string_format("parse special tokens (chat, tool, etc) (default: %s)", params.parse_special ? "true" : "false"),
         [](common_params & params) {
             params.parse_special = true;
         }

convert_hf_to_gguf.py

Lines changed: 250 additions & 2 deletions
@@ -1531,7 +1531,7 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_vision_embedding_length(self.find_vparam(["hidden_size"]))
         self.gguf_writer.add_vision_feed_forward_length(self.find_vparam(["intermediate_size"]))
         self.gguf_writer.add_vision_block_count(self.find_vparam(self.n_block_keys))
-        self.gguf_writer.add_vision_head_count(self.find_vparam(["num_attention_heads"]))
+        self.gguf_writer.add_vision_head_count(self.find_vparam(["num_attention_heads", "num_heads"]))
 
         # preprocessor config
         image_mean = _MISTRAL_COMMON_DATASET_MEAN if self.is_mistral_format else self.preprocessor_config["image_mean"]

@@ -3855,7 +3855,43 @@ def set_gguf_parameters(self):
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         # process the experts separately
         name = name.replace("language_model.", "") # InternVL
-        if name.startswith("mlp") or name.startswith("vision_model") or name.startswith("model.vision_tower") or name.startswith("model.multi_modal_projector"):
+
+        # handle aggregated expert tensors
+        # GGUF stores dimensions reversed from PyTorch, so:
+        # PyTorch (A,B,C) -> GGUF writes [C,B,A] -> GGML reads ne={C,B,A}
+        # Input shapes from HF: (n_expert, n_ff_exp, n_embd) or (n_expert, n_embd, n_ff_exp)
+        # Expected GGML ne: {n_embd, n_ff_exp, n_expert} for gate/up, {n_ff_exp, n_embd, n_expert} for down
+        if name.endswith("mlp.experts.down_proj") or name.endswith("mlp.experts.down_proj.weight"):
+            mapped = f"{name}.weight" if not name.endswith(".weight") else name
+            # Input: (n_expert=128, n_ff_exp=768, n_embd=2048)
+            # Want GGML ne: {n_ff_exp, n_embd, n_expert} = {768, 2048, 128}
+            # Need PyTorch: (128, 2048, 768) [reversed of GGML]
+            # So: permute(0, 2, 1): (128, 768, 2048) -> (128, 2048, 768)
+            permuted = data_torch.permute(0, 2, 1).contiguous()
+            return [(self.map_tensor_name(mapped), permuted)]
+
+        if name.endswith("mlp.experts.gate_up_proj") or name.endswith("mlp.experts.gate_up_proj.weight"):
+            if data_torch.ndim < 3 or data_torch.shape[-1] % 2 != 0:
+                raise ValueError(f"Unexpected gate_up_proj shape for {name}: {tuple(data_torch.shape)}")
+            split_dim = data_torch.shape[-1] // 2
+            gate = data_torch[..., :split_dim].contiguous()
+            up = data_torch[..., split_dim:].contiguous()
+            # Input gate/up: (n_expert=128, n_embd=2048, n_ff_exp=768)
+            # Want GGML ne: {n_embd, n_ff_exp, n_expert} = {2048, 768, 128}
+            # Need PyTorch: (128, 768, 2048) [reversed of GGML]
+            # So: permute(0, 2, 1): (128, 2048, 768) -> (128, 768, 2048)
+            base_name = name.removesuffix(".weight")
+            base = base_name.rsplit('.', 1)[0]
+            mapped_gate = f"{base}.gate_proj.weight"
+            mapped_up = f"{base}.up_proj.weight"
+            perm_gate = gate.permute(0, 2, 1).contiguous()
+            perm_up = up.permute(0, 2, 1).contiguous()
+            return [
+                (self.map_tensor_name(mapped_gate), perm_gate),
+                (self.map_tensor_name(mapped_up), perm_up),
+            ]
+
+        if name.startswith("mlp") or name.startswith("vision_model") or name.startswith("model.vision_tower") or name.startswith("model.multi_modal_projector") or name.startswith("model.visual"):
             # skip visual tensors
             return []
         if name.find("experts") != -1:
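As a quick standalone check of the shape reasoning in the comments above (a sketch, not part of the commit; the sizes 128/768/2048 are just the example values quoted in those comments):

import torch

n_expert, n_ff_exp, n_embd = 128, 768, 2048

# down_proj arrives from HF as (n_expert, n_ff_exp, n_embd); GGML wants
# ne = {n_ff_exp, n_embd, n_expert}, and GGUF stores dims reversed from
# PyTorch, so the tensor to write is (n_expert, n_embd, n_ff_exp)
down = torch.zeros(n_expert, n_ff_exp, n_embd)
assert down.permute(0, 2, 1).shape == (n_expert, n_embd, n_ff_exp)

# gate_up_proj arrives fused as (n_expert, n_embd, 2 * n_ff_exp); splitting
# the last dim in half is equivalent to the slicing at split_dim above
gate_up = torch.zeros(n_expert, n_embd, 2 * n_ff_exp)
gate, up = gate_up.chunk(2, dim=-1)
assert gate.permute(0, 2, 1).shape == (n_expert, n_ff_exp, n_embd)
assert up.permute(0, 2, 1).shape == (n_expert, n_ff_exp, n_embd)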
@@ -4007,6 +4043,187 @@ def set_vocab(self):
         super().set_vocab()
 
 
+@ModelBase.register("Qwen3VLForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration")
+class Qwen3VLVisionModel(MmprojModel):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        assert self.hparams_vision is not None
+        # Compute image_size if not present
+        if "image_size" not in self.hparams_vision:
+            # For Qwen3VL/Qwen3VLMoe, compute from num_position_embeddings
+            num_pos = self.hparams_vision.get("num_position_embeddings", 2304)
+            patch_size = self.hparams_vision.get("patch_size", 16)
+            # num_position_embeddings = (image_size / patch_size) ** 2
+            # So image_size = sqrt(num_position_embeddings) * patch_size
+            image_size = int(num_pos**0.5 * patch_size)
+            self.hparams_vision["image_size"] = image_size
+
+        # Rename config values for compatibility
+        self.hparams_vision["num_attention_heads"] = self.hparams_vision.get("num_heads")
+        self.hparams_vision["num_hidden_layers"] = self.hparams_vision.get("depth")
+
+        self.is_deepstack_layers = [False] * int(self.hparams_vision["num_hidden_layers"] or 0)
+        for idx in self.hparams_vision.get("deepstack_visual_indexes", []):
+            self.is_deepstack_layers[idx] = True
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN3VL)
+        self.gguf_writer.add_vision_use_gelu(True)
+
+        if self.hparams_vision is not None:
+            merge_size = self.hparams_vision.get("spatial_merge_size")
+            if merge_size is not None:
+                self.gguf_writer.add_vision_spatial_merge_size(int(merge_size))
+
+        # Use text config's rms_norm_eps for vision attention layernorm eps
+        rms_norm_eps = self.global_config.get("text_config", {}).get("rms_norm_eps", 1e-6)
+        self.gguf_writer.add_vision_attention_layernorm_eps(rms_norm_eps)
+
+        if self.is_deepstack_layers:
+            self.gguf_writer.add_vision_is_deepstack_layers(self.is_deepstack_layers)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        assert self.hparams_vision is not None
+        # Skip text model tensors - they go in the text model file
+        if name.startswith("model.language_model.") or name.startswith("lm_head."):
+            return []
+
+        if name.startswith("model.visual."):
+            name = name.replace("model.visual.", "visual.", 1)
+
+        if name.startswith("visual.deepstack_merger_list."):
+            prefix, rest = name.split(".", maxsplit=3)[2:]
+            # prefix is the layer index, convert to absolute clip layer index!
+            idx = self.hparams_vision.get("deepstack_visual_indexes", [])[int(prefix)]
+            target = rest
+
+            tensor_type: gguf.MODEL_TENSOR
+            if target.startswith("norm."):
+                tensor_type = gguf.MODEL_TENSOR.V_DS_NORM
+                suffix = target.split(".", 1)[1]
+            elif target.startswith("linear_fc1."):
+                tensor_type = gguf.MODEL_TENSOR.V_DS_FC1
+                suffix = target.split(".", 1)[1]
+            elif target.startswith("linear_fc2."):
+                tensor_type = gguf.MODEL_TENSOR.V_DS_FC2
+                suffix = target.split(".", 1)[1]
+            else:
+                raise ValueError(f"Unexpected deepstack tensor: {name}")
+
+            new_name = self.format_tensor_name(tensor_type, idx, suffix=f".{suffix}")
+            return [(new_name, data_torch)]
+
+        if name.startswith("visual.merger."):
+            suffix = name.split(".", 2)[2]
+            if suffix.startswith("linear_fc"):
+                fc_idx_str, tail = suffix.split(".", 1)
+                fc_num = int(fc_idx_str.replace("linear_fc", ""))
+                # Qwen3VL has linear_fc1 and linear_fc2
+                # Map to indices 0 and 2 (matching Qwen2VL which uses indices 0 and 2)
+                if fc_num == 1:
+                    fc_idx = 0
+                elif fc_num == 2:
+                    fc_idx = 2
+                else:
+                    raise ValueError(f"unexpected fc index {fc_num} in {name}")
+                new_name = self.format_tensor_name(gguf.MODEL_TENSOR.V_MMPROJ, fc_idx, suffix=f".{tail}")
+            elif suffix.startswith("norm."):
+                new_name = self.format_tensor_name(gguf.MODEL_TENSOR.V_POST_NORM, suffix=f".{suffix.split('.', 1)[1]}")
+            else:
+                raise ValueError(f"Unexpected merger tensor: {name}")
+            return [(new_name, data_torch)]
+
+        if name == "visual.patch_embed.proj.weight":
+            # split Conv3D into Conv2Ds along temporal dimension
+            c1, c2, kt, _, _ = data_torch.shape
+            del c1, c2
+            if kt != 2:
+                raise ValueError("Current implementation only supports temporal_patch_size of 2")
+            return [
+                (gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".weight", data_torch[:, :, 0, ...]),
+                (gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".weight.1", data_torch[:, :, 1, ...]),
+            ]
+
+        if name == "visual.patch_embed.proj.bias":
+            # Include the bias - it's used by the C++ code
+            return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".bias", data_torch)]
+
+        if name.startswith("visual."):
+            return [(self.map_tensor_name(name), data_torch)]
+
+        # Fall back to parent class for other tensors
+        return super().modify_tensors(data_torch, name, bid)
+
+
+@ModelBase.register("Qwen3VLForConditionalGeneration")
+class Qwen3VLTextModel(Qwen3Model):
+    model_arch = gguf.MODEL_ARCH.QWEN3VL
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+
+        # Handle MRoPE (Multi-axis Rotary Position Embedding) for Qwen3-VL
+        text_config = self.hparams.get("text_config", {})
+        # rope_scaling is deprecated in V5, use rope_parameters instead
+        rope_scaling = text_config.get("rope_scaling") or text_config.get("rope_parameters") or {}
+
+        if rope_scaling.get("mrope_section"):
+            # mrope_section contains [time, height, width] dimensions
+            mrope_section = rope_scaling["mrope_section"]
+            # Pad to 4 dimensions [time, height, width, extra]
+            while len(mrope_section) < 4:
+                mrope_section.append(0)
+            self.gguf_writer.add_rope_dimension_sections(mrope_section[:4])
+
+            logger.info(f"MRoPE sections: {mrope_section[:4]}")
+
+        vision_config = self.hparams.get("vision_config", {})
+        deepstack_layer_num = len(vision_config.get("deepstack_visual_indexes", []))
+        self.gguf_writer.add_num_deepstack_layers(deepstack_layer_num)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # Skip vision tensors - they go in the mmproj file
+        if name.startswith("model.visual."):
+            return []
+
+        return super().modify_tensors(data_torch, name, bid)
+
+
+@ModelBase.register("Qwen3VLMoeForConditionalGeneration")
+class Qwen3VLMoeTextModel(Qwen3MoeModel):
+    model_arch = gguf.MODEL_ARCH.QWEN3VLMOE
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+
+        # Handle MRoPE (Multi-axis Rotary Position Embedding) for Qwen3-VL
+        text_config = self.hparams.get("text_config", {})
+        # rope_scaling is deprecated in V5, use rope_parameters instead
+        rope_scaling = text_config.get("rope_scaling") or text_config.get("rope_parameters") or {}
+
+        if rope_scaling.get("mrope_section"):
+            # mrope_section contains [time, height, width] dimensions
+            mrope_section = rope_scaling["mrope_section"]
+            # Pad to 4 dimensions [time, height, width, extra]
+            while len(mrope_section) < 4:
+                mrope_section.append(0)
+            self.gguf_writer.add_rope_dimension_sections(mrope_section[:4])
+
+            logger.info(f"MRoPE sections: {mrope_section[:4]}")
+
+        vision_config = self.hparams.get("vision_config", {})
+        deepstack_layer_num = len(vision_config.get("deepstack_visual_indexes", []))
+        self.gguf_writer.add_num_deepstack_layers(deepstack_layer_num)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # Skip vision tensors - they go in the mmproj file
+        if name.startswith("model.visual."):
+            return []
+
+        return super().modify_tensors(data_torch, name, bid)
+
+
 @ModelBase.register("GPT2LMHeadModel")
 class GPT2Model(TextModel):
     model_arch = gguf.MODEL_ARCH.GPT2
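The two small pieces of arithmetic in the new Qwen3VL classes, worked through as a standalone sketch (the values 2304 and 16 are the fallback defaults used above; the mrope_section values are hypothetical examples, not taken from a real config):

# image_size from num_position_embeddings:
#   num_position_embeddings = (image_size / patch_size) ** 2
#   so image_size = sqrt(num_position_embeddings) * patch_size
num_pos, patch_size = 2304, 16
image_size = int(num_pos**0.5 * patch_size)  # sqrt(2304) = 48, and 48 * 16 = 768
assert image_size == 768

# mrope_section padding: [time, height, width] -> [time, height, width, extra]
mrope_section = [24, 20, 20]  # hypothetical values
while len(mrope_section) < 4:
    mrope_section.append(0)
assert mrope_section[:4] == [24, 20, 20, 0]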
@@ -9535,6 +9752,37 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
 
         return [] # skip other tensors
 
+
+@ModelBase.register("CogVLMForCausalLM")
+class CogVLMVisionModel(MmprojModel):
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams.get("layer_norm_eps", 1e-6))
+        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.COGVLM)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid # unused
+
+        if not name.startswith("model.vision."):
+            return []
+
+        return [(self.map_tensor_name(name), data_torch)]
+
+
+@ModelBase.register("CogVLMForCausalLM")
+class CogVLMModel(LlamaModel):
+    model_arch = gguf.MODEL_ARCH.COGVLM
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid # unused
+
+        # block vision tensors
+        if name.startswith("model.vision."):
+            return []
+
+        return [(self.map_tensor_name(name), data_torch)]
+
 ###### CONVERSION LOGIC ######
 
 
ggml/include/ggml.h

Lines changed: 1 addition & 0 deletions
@@ -242,6 +242,7 @@
 #define GGML_ROPE_TYPE_NEOX 2
 #define GGML_ROPE_TYPE_MROPE 8
 #define GGML_ROPE_TYPE_VISION 24
+#define GGML_ROPE_TYPE_IMROPE 40 // binary: 101000
 
 #define GGML_MROPE_SECTIONS 4
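Read as a bit field, the new constant lines up with the existing rope types. A small sketch of my reading of the values (the grouping into flag bits is inferred from the binary comment, not spelled out in the header):

GGML_ROPE_TYPE_NEOX   = 2   # 0b000010
GGML_ROPE_TYPE_MROPE  = 8   # 0b001000
GGML_ROPE_TYPE_VISION = 24  # 0b011000 = MROPE | 16
GGML_ROPE_TYPE_IMROPE = 40  # 0b101000 = MROPE | 32

# every multi-axis variant keeps the MROPE bit set, so a single mask
# test detects the whole mrope family
for rope_type in (GGML_ROPE_TYPE_MROPE, GGML_ROPE_TYPE_VISION, GGML_ROPE_TYPE_IMROPE):
    assert rope_type & GGML_ROPE_TYPE_MROPE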

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 0 additions & 5 deletions
@@ -1613,13 +1613,8 @@ static void ggml_compute_forward_mul_mat_id(
         chunk_size = 64;
     }
 
-#if defined(__aarch64__)
-    // disable for ARM
-    const bool disable_chunking = true;
-#else
     // disable for NUMA
     const bool disable_chunking = ggml_is_numa();
-#endif // defined(__aarch64__)
 
     int64_t nchunk0 = (nr0 + chunk_size - 1) / chunk_size;
     int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size;
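With the #if removed, aarch64 follows the same rule as every other platform: chunking is disabled only when ggml_is_numa() reports a NUMA system. The chunk counts computed just below the removed block are plain ceiling divisions; a tiny sketch with made-up row counts:

def ceil_div(nr, chunk_size):
    # mirrors (nr + chunk_size - 1) / chunk_size in the C code above
    return (nr + chunk_size - 1) // chunk_size

nr0, nr1, chunk_size = 4096, 100, 64  # hypothetical matrix row counts
assert ceil_div(nr0, chunk_size) == 64
assert ceil_div(nr1, chunk_size) == 2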
