Skip to content

Commit da8cf83

Browse files
committed
Add deepseek v1 arch & gigachat template
1 parent 43041d2 commit da8cf83

File tree

5 files changed

+428
-1
lines changed

5 files changed

+428
-1
lines changed

convert_hf_to_gguf.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -664,6 +664,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
664664
if chkhsh == "8b5a93ed704057481f240da0be7e7dca721d7f8f4755263b6807227a2cbeae65":
665665
# ref: https://huggingface.co/sentence-transformers/stsb-roberta-base
666666
res = "roberta-bpe"
667+
if chkhsh == "ad851be1dba641f2e3711822f816db2c265f788b37c63b4e1aeacb9ee92de8eb":
668+
# ref: https://huggingface.co/ai-sage/GigaChat-20B-A3B-instruct
669+
res = "gigachat"
667670

668671
if res is None:
669672
logger.warning("\n")
@@ -3483,6 +3486,102 @@ def prepare_tensors(self):
34833486
raise ValueError(f"Unprocessed experts: {experts}")
34843487

34853488

3489+
@Model.register("DeepseekForCausalLM")
class DeepseekModel(Model):
    """HF -> GGUF converter for DeepSeek (v1) MoE checkpoints (DeepseekForCausalLM)."""

    model_arch = gguf.MODEL_ARCH.DEEPSEEK

    def set_vocab(self):
        """Load the tokenizer: SentencePiece first, GPT-2/BPE as fallback."""
        try:
            self._set_vocab_sentencepiece()
        except FileNotFoundError:
            # No tokenizer.model on disk -> checkpoint ships a BPE tokenizer
            # (e.g. GigaChat-20B-A3B, which reuses this architecture).
            self._set_vocab_gpt2()

    def set_gguf_parameters(self):
        """Write DeepSeek-specific hyperparameters into the GGUF header."""
        super().set_gguf_parameters()
        hparams = self.hparams

        # RoPE dimension: prefer an explicit head_dim, else derive it from
        # hidden_size / num_attention_heads.
        if "head_dim" in hparams:
            rope_dim = hparams["head_dim"]
        else:
            rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"]
        self.gguf_writer.add_rope_dimension_count(rope_dim)

        self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
        self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"])
        self.gguf_writer.add_vocab_size(hparams["vocab_size"])
        self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
        self.gguf_writer.add_expert_weights_scale(1.0)
        self.gguf_writer.add_expert_count(hparams["n_routed_experts"])
        self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"])

    # Staging area: one dict per transformer block, filled until every expert
    # tensor of that block has been seen and can be merged into a 3D tensor.
    _experts: list[dict[str, Tensor]] | None = None

    @staticmethod
    def permute(weights: Tensor, n_head: int, n_head_kv: int | None):
        """Re-order interleaved Q/K rows into the rotary-embedding layout GGML expects."""
        if n_head_kv is not None and n_head != n_head_kv:
            n_head = n_head_kv
        return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
                .swapaxes(1, 2)
                .reshape(weights.shape))

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        """Rename/permute one HF tensor; buffer per-expert FFN weights and emit
        them as a single stacked tensor per projection once a layer is complete."""
        n_head = self.hparams["num_attention_heads"]
        n_kv_head = self.hparams.get("num_key_value_heads")

        if name.endswith(("q_proj.weight", "q_proj.bias")):
            data_torch = DeepseekModel.permute(data_torch, n_head, n_head)
        if name.endswith(("k_proj.weight", "k_proj.bias")):
            data_torch = DeepseekModel.permute(data_torch, n_head, n_kv_head)

        # process the experts separately
        if "mlp.experts" in name:
            n_experts = self.hparams["n_routed_experts"]
            assert bid is not None

            if self._experts is None:
                self._experts = [{} for _ in range(self.block_count)]

            self._experts[bid][name] = data_torch

            # Each expert contributes 3 tensors (down/gate/up); only merge once
            # all of them for this block have arrived.
            if len(self._experts[bid]) >= n_experts * 3:
                tensors: list[tuple[str, Tensor]] = []

                # merge the experts into a single 3d tensor
                for w_name in ["down_proj", "gate_proj", "up_proj"]:
                    datas: list[Tensor] = []

                    for xid in range(n_experts):
                        ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
                        datas.append(self._experts[bid][ename])
                        del self._experts[bid][ename]

                    data_torch = torch.stack(datas, dim=0)

                    merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
                    new_name = self.map_tensor_name(merged_name)
                    tensors.append((new_name, data_torch))
                return tensors
            else:
                return []

        return [(self.map_tensor_name(name), data_torch)]

    def prepare_tensors(self):
        """Finalize conversion; fail loudly if any expert tensor was left unmerged."""
        super().prepare_tensors()

        if self._experts is not None:
            # flatten `list[dict[str, Tensor]]` into `list[str]`
            experts = [k for d in self._experts for k in d.keys()]
            if len(experts) > 0:
                raise ValueError(f"Unprocessed experts: {experts}")
3583+
3584+
34863585
@Model.register("T5WithLMHeadModel")
34873586
@Model.register("T5ForConditionalGeneration")
34883587
@Model.register("MT5ForConditionalGeneration")

convert_hf_to_gguf_update.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ class TOKENIZER_TYPE(IntEnum):
104104
{"name": "chameleon", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/facebook/chameleon-7b", },
105105
{"name": "minerva-7b", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0", },
106106
{"name": "roberta-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sentence-transformers/stsb-roberta-base"},
107+
{"name": "gigachat", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ai-sage/GigaChat-20B-A3B-instruct"},
107108
]
108109

109110

gguf-py/gguf/constants.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,7 @@ class MODEL_ARCH(IntEnum):
248248
OPENELM = auto()
249249
ARCTIC = auto()
250250
DEEPSEEK2 = auto()
251+
DEEPSEEK = auto()
251252
CHATGLM = auto()
252253
BITNET = auto()
253254
T5 = auto()
@@ -410,6 +411,7 @@ class MODEL_TENSOR(IntEnum):
410411
MODEL_ARCH.OPENELM: "openelm",
411412
MODEL_ARCH.ARCTIC: "arctic",
412413
MODEL_ARCH.DEEPSEEK2: "deepseek2",
414+
MODEL_ARCH.DEEPSEEK: "deepseek",
413415
MODEL_ARCH.CHATGLM: "chatglm",
414416
MODEL_ARCH.BITNET: "bitnet",
415417
MODEL_ARCH.T5: "t5",
@@ -1141,6 +1143,31 @@ class MODEL_TENSOR(IntEnum):
11411143
MODEL_TENSOR.FFN_DOWN_EXP,
11421144
MODEL_TENSOR.FFN_UP_EXP,
11431145
],
1146+
###############
1147+
MODEL_ARCH.DEEPSEEK: [
1148+
MODEL_TENSOR.TOKEN_EMBD,
1149+
MODEL_TENSOR.OUTPUT_NORM,
1150+
MODEL_TENSOR.OUTPUT,
1151+
MODEL_TENSOR.ROPE_FREQS,
1152+
MODEL_TENSOR.ATTN_NORM,
1153+
MODEL_TENSOR.ATTN_Q,
1154+
MODEL_TENSOR.ATTN_K,
1155+
MODEL_TENSOR.ATTN_V,
1156+
MODEL_TENSOR.ATTN_OUT,
1157+
MODEL_TENSOR.ATTN_ROT_EMBD,
1158+
MODEL_TENSOR.FFN_GATE_INP,
1159+
MODEL_TENSOR.FFN_NORM,
1160+
MODEL_TENSOR.FFN_GATE,
1161+
MODEL_TENSOR.FFN_DOWN,
1162+
MODEL_TENSOR.FFN_UP,
1163+
MODEL_TENSOR.FFN_GATE_EXP,
1164+
MODEL_TENSOR.FFN_DOWN_EXP,
1165+
MODEL_TENSOR.FFN_UP_EXP,
1166+
MODEL_TENSOR.FFN_GATE_SHEXP,
1167+
MODEL_TENSOR.FFN_DOWN_SHEXP,
1168+
MODEL_TENSOR.FFN_UP_SHEXP,
1169+
],
1170+
###############
11441171
MODEL_ARCH.DEEPSEEK2: [
11451172
MODEL_TENSOR.TOKEN_EMBD,
11461173
MODEL_TENSOR.OUTPUT_NORM,
@@ -1363,6 +1390,10 @@ class MODEL_TENSOR(IntEnum):
13631390
MODEL_TENSOR.ROPE_FREQS,
13641391
MODEL_TENSOR.ATTN_ROT_EMBD,
13651392
],
1393+
MODEL_ARCH.DEEPSEEK: [
1394+
MODEL_TENSOR.ROPE_FREQS,
1395+
MODEL_TENSOR.ATTN_ROT_EMBD,
1396+
],
13661397
MODEL_ARCH.DEEPSEEK2: [
13671398
MODEL_TENSOR.ROPE_FREQS,
13681399
MODEL_TENSOR.ATTN_ROT_EMBD,

0 commit comments

Comments
 (0)