Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 41 additions & 28 deletions convert_hf_to_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,8 +453,11 @@ def from_model_architecture(cls, arch: str, model_type = ModelType.TEXT) -> type


class TextModel(ModelBase):
model_type = ModelType.TEXT

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.hf_arch = get_model_architecture(self.hparams, self.model_type)

if "text_config" in self.hparams:
# move the text_config to the root level
Expand Down Expand Up @@ -1073,10 +1076,36 @@ def _set_vocab_builtin(self, model_name: Literal["gpt-neox", "llama-spm"], vocab
if (field := vocab_reader.get_field(gguf.Keys.Tokenizer.ADD_EOS)) is not None:
self.gguf_writer.add_add_eos_token(field.parts[-1].tolist()[0])

def _try_set_pooling_type(self) -> None:
# get pooling path
pooling_path = None
module_path = self.dir_model / "modules.json"
if module_path.is_file():
with open(module_path, encoding="utf-8") as f:
modules = json.load(f)
for mod in modules:
if mod["type"] == "sentence_transformers.models.Pooling":
pooling_path = mod["path"]
break

# get pooling type
if pooling_path is not None:
with open(self.dir_model / pooling_path / "config.json", encoding="utf-8") as f:
pooling = json.load(f)
if pooling["pooling_mode_mean_tokens"]:
pooling_type = gguf.PoolingType.MEAN
elif pooling["pooling_mode_cls_token"]:
pooling_type = gguf.PoolingType.CLS
elif pooling["pooling_mode_lasttoken"]:
pooling_type = gguf.PoolingType.LAST
else:
raise NotImplementedError("Only MEAN, CLS, and LAST pooling types supported")
self.gguf_writer.add_pooling_type(pooling_type)


class VisionModel(ModelBase):
model_type = ModelType.VISION
model_arch = gguf.MODEL_ARCH.CLIP_VISION
n_text_embd = 0
preprocessor_config: dict[str, Any]
global_config: dict[str, Any]

Expand Down Expand Up @@ -2538,7 +2567,7 @@ def set_gguf_parameters(self):
self.gguf_writer.add_file_type(self.ftype)


@ModelBase.register("Qwen2ForCausalLM")
@ModelBase.register("Qwen2Model", "Qwen2ForCausalLM")
class Qwen2Model(TextModel):
model_arch = gguf.MODEL_ARCH.QWEN2

Expand All @@ -2550,12 +2579,18 @@ def set_vocab(self):

def set_gguf_parameters(self):
super().set_gguf_parameters()
self._try_set_pooling_type()
if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
if self.hparams["rope_scaling"].get("type") == "yarn":
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])
self.gguf_writer.add_rope_scaling_orig_ctx_len(self.hparams["rope_scaling"]["original_max_position_embeddings"])

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if self.hf_arch == "Qwen2Model":
name = f"model.{name}" # map to Qwen2ForCausalLM tensors
yield from super().modify_tensors(data_torch, name, bid)


@ModelBase.register("Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration")
class Qwen2VLModel(TextModel):
Expand Down Expand Up @@ -3316,29 +3351,7 @@ def __init__(self, *args, **kwargs):
def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_causal_attention(False)

# get pooling path
pooling_path = None
module_path = self.dir_model / "modules.json"
if module_path.is_file():
with open(module_path, encoding="utf-8") as f:
modules = json.load(f)
for mod in modules:
if mod["type"] == "sentence_transformers.models.Pooling":
pooling_path = mod["path"]
break

# get pooling type
if pooling_path is not None:
with open(self.dir_model / pooling_path / "config.json", encoding="utf-8") as f:
pooling = json.load(f)
if pooling["pooling_mode_mean_tokens"]:
pooling_type = gguf.PoolingType.MEAN
elif pooling["pooling_mode_cls_token"]:
pooling_type = gguf.PoolingType.CLS
else:
raise NotImplementedError("Only MEAN and CLS pooling types supported")
self.gguf_writer.add_pooling_type(pooling_type)
self._try_set_pooling_type()

def set_vocab(self):
tokens, toktypes, tokpre = self.get_vocab_base()
Expand Down Expand Up @@ -5877,8 +5890,7 @@ def split_str_to_n_bytes(split_str: str) -> int:
return n


def get_model_architecture(dir_model: Path, model_type: ModelType, hparams: Any = None) -> str:
hparams = ModelBase.load_hparams(dir_model) if hparams is None else hparams
def get_model_architecture(hparams: dict[str, Any], model_type: ModelType) -> str:
text_config = hparams.get("text_config", {})
vision_config = hparams.get("vision_config", {})
arch = hparams["architectures"][0]
Expand Down Expand Up @@ -5949,7 +5961,8 @@ def main() -> None:
with torch.inference_mode():
output_type = ftype_map[args.outtype]
model_type = ModelType.VISION if args.mmproj else ModelType.TEXT
model_architecture = get_model_architecture(dir_model, model_type)
hparams = ModelBase.load_hparams(dir_model)
model_architecture = get_model_architecture(hparams, model_type)
logger.info(f"Model architecture: {model_architecture}")
try:
model_class = ModelBase.from_model_architecture(model_architecture, model_type=model_type)
Expand Down
2 changes: 2 additions & 0 deletions gguf-py/gguf/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -2032,6 +2032,8 @@ class PoolingType(IntEnum):
NONE = 0
MEAN = 1
CLS = 2
LAST = 3
RANK = 4


class GGMLQuantizationType(IntEnum):
Expand Down
1 change: 1 addition & 0 deletions src/llama-model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -773,6 +773,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
// fall through
case LLM_ARCH_QWEN2:
{
ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false);
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
switch (hparams.n_layer) {
case 24: type = hparams.n_embd == 1024 ? LLM_TYPE_0_5B : LLM_TYPE_1B; break;
Expand Down
Loading