Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 73 additions & 5 deletions convert_hf_to_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,12 +608,13 @@ def get_vocab_base(self) -> tuple[list[str], list[int], str]:

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
vocab_size = self.hparams.get("vocab_size", len(tokenizer.vocab))
assert max(tokenizer.vocab.values()) < vocab_size
vocab = getattr(tokenizer, 'vocab', tokenizer.get_vocab())
vocab_size = self.hparams.get("vocab_size", len(vocab))
assert max(vocab.values()) < vocab_size

tokpre = self.get_vocab_base_pre(tokenizer)

reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()}
reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in vocab.items()}
added_vocab = tokenizer.get_added_vocab()

added_tokens_decoder = tokenizer.added_tokens_decoder
Expand Down Expand Up @@ -2998,7 +2999,13 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
@ModelBase.register("InternVisionModel")
class InternVisionModel(MmprojModel):
def set_gguf_parameters(self):
assert self.hparams_vision is not None
if isinstance(self.hparams_vision['image_size'], list):
self.hparams_vision['image_size'] = self.hparams_vision['image_size'][0]
if isinstance(self.hparams_vision['patch_size'], list):
self.hparams_vision['patch_size'] = self.hparams_vision['patch_size'][0]
super().set_gguf_parameters()

hparams = self.hparams
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.INTERNVL)
self.gguf_writer.add_vision_attention_layernorm_eps(hparams["layer_norm_eps"])
Expand All @@ -3022,14 +3029,30 @@ def tensor_force_quant(self, name, new_name, bid, n_dims):
return gguf.GGMLQuantizationType.F32
return False

def _mapping_interns1_name(self, name):
names_map = {
"model.multi_modal_projector.layer_norm.bias": "mlp1.0.bias",
"model.multi_modal_projector.layer_norm.weight": "mlp1.0.weight",
"model.multi_modal_projector.linear_1.bias": "mlp1.1.bias",
"model.multi_modal_projector.linear_1.weight": "mlp1.1.weight",
"model.multi_modal_projector.linear_2.bias": "mlp1.3.bias",
"model.multi_modal_projector.linear_2.weight": "mlp1.3.weight",
}
if name in names_map:
name = names_map[name]
return name

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused
if name.startswith("vision_model") or name.startswith("mlp"):
vision_prefix = ['vision_model', 'mlp', 'model.vision_tower', 'model.multi_modal_projector']
# deal with intern-s1 special case
name = self._mapping_interns1_name(name)
if any([name.startswith(prefix) for prefix in vision_prefix]):
# process visual tensors
# correct name
if name.startswith("vision_model"):
name = "vision_tower." + name
if (".ls" in name or "position_embedding" in name) and not name.endswith(".weight"):
if (".ls" in name or ".lambda_" in name or "position_embedding" in name) and not name.endswith(".weight"):
name += ".weight"
# split QKV tensors if needed
if ".qkv." in name:
Expand Down Expand Up @@ -3115,6 +3138,10 @@ def set_gguf_parameters(self):

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# process the experts separately
name = name.replace("language_model.", "") # InternVL
if name.startswith("mlp") or name.startswith("vision_model") or name.startswith("model.vision_tower") or name.startswith("model.multi_modal_projector"):
# skip visual tensors
return []
if name.find("experts") != -1:
n_experts = self.hparams["num_experts"]
assert bid is not None
Expand Down Expand Up @@ -3168,6 +3195,47 @@ class Qwen3Model(Qwen2Model):
class Qwen3MoeModel(Qwen2MoeModel):
model_arch = gguf.MODEL_ARCH.QWEN3MOE

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
hparams = ModelBase.load_hparams(self.dir_model)
self.origin_hf_arch = hparams.get('architectures', [None])[0]

def set_vocab(self):
# deal with intern-s1
if self.origin_hf_arch == 'InternS1ForConditionalGeneration':
self._set_vocab_interns1()
return

try:
self._set_vocab_sentencepiece()
except FileNotFoundError:
self._set_vocab_gpt2()

def _set_vocab_interns1(self):
tokens, toktypes, tokpre = self.get_vocab_base()
self.gguf_writer.add_tokenizer_model("gpt2")
self.gguf_writer.add_tokenizer_pre(tokpre)
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)

special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
special_tokens_map_file = self.dir_model / 'special_tokens_map.json'
additional_special_tokens = []
if special_tokens_map_file.is_file():
with open(special_tokens_map_file, encoding = 'utf-8') as f:
additional_special_tokens = json.load(f).get('additional_special_tokens', [])
tokenizer_cfg_file = self.dir_model / 'special_tokens_map.json'
if tokenizer_cfg_file.is_file():
with open(tokenizer_cfg_file, encoding = 'utf-8') as f:
added_tokens_decoder = json.load(f).get('added_tokens_decoder', {})
token2ids_map = {data['content'] : int(token) for token, data in added_tokens_decoder.items() if data['special']}
for token in additional_special_tokens:
if token in token2ids_map:
special_vocab._set_special_token(token, token2ids_map[token])
special_vocab._set_special_token('eos', 151645)
special_vocab._set_special_token("bos", 151643)
special_vocab.add_to_gguf(self.gguf_writer)


@ModelBase.register("GPT2LMHeadModel")
class GPT2Model(TextModel):
Expand Down
15 changes: 15 additions & 0 deletions gguf-py/gguf/tensor_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -1054,11 +1054,13 @@ class TensorNameMap:

MODEL_TENSOR.V_ENC_EMBD_CLS: (
"vision_tower.vision_model.embeddings.class_embedding",
"model.vision_tower.embeddings.cls_token", # Intern-S1
"vision_model.class_embedding", # llama 4
),

MODEL_TENSOR.V_ENC_EMBD_PATCH: (
"vision_tower.vision_model.embeddings.patch_embedding",
"model.vision_tower.embeddings.patch_embeddings.projection", # Intern-S1
"vpm.embeddings.patch_embedding",
"model.vision_model.embeddings.patch_embedding", # SmolVLM
"vision_tower.patch_conv", # pixtral
Expand All @@ -1068,13 +1070,15 @@ class TensorNameMap:

MODEL_TENSOR.V_ENC_EMBD_POS: (
"vision_tower.vision_model.embeddings.position_embedding",
"model.vision_tower.embeddings.position_embeddings", # Intern-S1
"vpm.embeddings.position_embedding",
"model.vision_model.embeddings.position_embedding", # SmolVLM
"vision_model.positional_embedding_vlm", # llama 4
),

MODEL_TENSOR.V_ENC_ATTN_Q: (
"vision_tower.vision_model.encoder.layers.{bid}.self_attn.q_proj",
"model.vision_tower.encoder.layer.{bid}.attention.q_proj", # Intern-S1
"vpm.encoder.layers.{bid}.self_attn.q_proj",
"model.vision_model.encoder.layers.{bid}.self_attn.q_proj", # SmolVLM
"vision_model.model.layers.{bid}.self_attn.q_proj", # llama4
Expand All @@ -1084,10 +1088,12 @@ class TensorNameMap:

MODEL_TENSOR.V_ENC_ATTN_Q_NORM: (
"vision_tower.vision_model.encoder.layers.{bid}.attn.q_norm", # InternVL
"model.vision_tower.encoder.layer.{bid}.attention.q_norm", # Intern-S1
),

MODEL_TENSOR.V_ENC_ATTN_K: (
"vision_tower.vision_model.encoder.layers.{bid}.self_attn.k_proj",
"model.vision_tower.encoder.layer.{bid}.attention.k_proj", # Intern-S1
"vpm.encoder.layers.{bid}.self_attn.k_proj",
"model.vision_model.encoder.layers.{bid}.self_attn.k_proj", # SmolVLM
"vision_model.model.layers.{bid}.self_attn.k_proj", # llama4
Expand All @@ -1097,10 +1103,12 @@ class TensorNameMap:

MODEL_TENSOR.V_ENC_ATTN_K_NORM: (
"vision_tower.vision_model.encoder.layers.{bid}.attn.k_norm", # InternVL
"model.vision_tower.encoder.layer.{bid}.attention.k_norm", # Intern-S1
),

MODEL_TENSOR.V_ENC_ATTN_V: (
"vision_tower.vision_model.encoder.layers.{bid}.self_attn.v_proj",
"model.vision_tower.encoder.layer.{bid}.attention.v_proj", # Intern-S1
"vpm.encoder.layers.{bid}.self_attn.v_proj",
"model.vision_model.encoder.layers.{bid}.self_attn.v_proj", # SmolVLM
"vision_model.model.layers.{bid}.self_attn.v_proj", # llama4
Expand All @@ -1111,6 +1119,7 @@ class TensorNameMap:
MODEL_TENSOR.V_ENC_INPUT_NORM: (
"vision_tower.vision_model.encoder.layers.{bid}.layer_norm1",
"vision_tower.vision_model.encoder.layers.{bid}.norm1", # InternVL
"model.vision_tower.encoder.layer.{bid}.layernorm_before", # Intern-S1
"vpm.encoder.layers.{bid}.layer_norm1",
"model.vision_model.encoder.layers.{bid}.layer_norm1", # SmolVLM
"vision_tower.transformer.layers.{bid}.attention_norm", # pixtral
Expand All @@ -1121,6 +1130,7 @@ class TensorNameMap:
MODEL_TENSOR.V_ENC_ATTN_O: (
"vision_tower.vision_model.encoder.layers.{bid}.self_attn.out_proj",
"vision_tower.vision_model.encoder.layers.{bid}.attn.proj", # InternVL
"model.vision_tower.encoder.layer.{bid}.attention.projection_layer", # Intern-S1
"vpm.encoder.layers.{bid}.self_attn.out_proj",
"model.vision_model.encoder.layers.{bid}.self_attn.out_proj", # SmolVLM
"vision_model.model.layers.{bid}.self_attn.o_proj", # llama4
Expand All @@ -1131,6 +1141,7 @@ class TensorNameMap:
MODEL_TENSOR.V_ENC_POST_ATTN_NORM: (
"vision_tower.vision_model.encoder.layers.{bid}.layer_norm2",
"vision_tower.vision_model.encoder.layers.{bid}.norm2", # InternVL
"model.vision_tower.encoder.layer.{bid}.layernorm_after", # Intern-S1
"vpm.encoder.layers.{bid}.layer_norm2",
"model.vision_model.encoder.layers.{bid}.layer_norm2", # SmolVLM
"vision_model.model.layers.{bid}.post_attention_layernorm", # llama4
Expand All @@ -1140,6 +1151,7 @@ class TensorNameMap:

MODEL_TENSOR.V_ENC_FFN_UP: (
"vision_tower.vision_model.encoder.layers.{bid}.mlp.fc1",
"model.vision_tower.encoder.layer.{bid}.mlp.fc1", # Intern-S1
"vpm.encoder.layers.{bid}.mlp.fc1",
"model.vision_model.encoder.layers.{bid}.mlp.fc1", # SmolVLM, gemma3
"vision_tower.transformer.layers.{bid}.feed_forward.up_proj", # pixtral
Expand All @@ -1155,6 +1167,7 @@ class TensorNameMap:

MODEL_TENSOR.V_ENC_FFN_DOWN: (
"vision_tower.vision_model.encoder.layers.{bid}.mlp.fc2",
"model.vision_tower.encoder.layer.{bid}.mlp.fc2", # Intern-S1
"vpm.encoder.layers.{bid}.mlp.fc2",
"model.vision_model.encoder.layers.{bid}.mlp.fc2", # SmolVLM, gemma3
"vision_tower.transformer.layers.{bid}.feed_forward.down_proj", # pixtral
Expand All @@ -1165,10 +1178,12 @@ class TensorNameMap:

MODEL_TENSOR.V_LAYER_SCALE_1: (
"vision_tower.vision_model.encoder.layers.{bid}.ls1", # InternVL
"model.vision_tower.encoder.layer.{bid}.lambda_1", # Intern-S1
),

MODEL_TENSOR.V_LAYER_SCALE_2: (
"vision_tower.vision_model.encoder.layers.{bid}.ls2", # InternVL
"model.vision_tower.encoder.layer.{bid}.lambda_2", # Intern-S1
),

MODEL_TENSOR.V_PRE_NORM: (
Expand Down