
Commit 1805a54

2015aroras authored and Nexesenex committed

llama : support OLMoE (ggml-org#9462)

1 parent 46ce6fc commit 1805a54

File tree

4 files changed, +297 −15 lines

convert_hf_to_gguf.py

Lines changed: 60 additions & 0 deletions
@@ -2912,6 +2912,66 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         return [(self.map_tensor_name(name), data_torch)]


+@Model.register("OlmoeForCausalLM")
+class OlmoeModel(Model):
+    model_arch = gguf.MODEL_ARCH.OLMOE
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_layer_norm_rms_eps(1e-5)
+        if (n_experts := self.hparams.get("num_experts")) is not None:
+            self.gguf_writer.add_expert_count(n_experts)
+
+    _experts: list[dict[str, Tensor]] | None = None
+
+    # Copied from: Qwen2MoeModel
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # process the experts separately
+        if name.find("experts") != -1:
+            n_experts = self.hparams["num_experts"]
+            assert bid is not None
+
+            if self._experts is None:
+                self._experts = [{} for _ in range(self.block_count)]
+
+            self._experts[bid][name] = data_torch
+
+            if len(self._experts[bid]) >= n_experts * 3:
+                tensors: list[tuple[str, Tensor]] = []
+
+                # merge the experts into a single 3d tensor
+                for w_name in ["down_proj", "gate_proj", "up_proj"]:
+                    datas: list[Tensor] = []
+
+                    for xid in range(n_experts):
+                        ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
+                        datas.append(self._experts[bid][ename])
+                        del self._experts[bid][ename]
+
+                    data_torch = torch.stack(datas, dim=0)
+
+                    merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
+
+                    new_name = self.map_tensor_name(merged_name)
+
+                    tensors.append((new_name, data_torch))
+                return tensors
+            else:
+                return []
+
+        return [(self.map_tensor_name(name), data_torch)]
+
+    # Copied from: Qwen2MoeModel
+    def prepare_tensors(self):
+        super().prepare_tensors()
+
+        if self._experts is not None:
+            # flatten `list[dict[str, Tensor]]` into `list[str]`
+            experts = [k for d in self._experts for k in d.keys()]
+            if len(experts) > 0:
+                raise ValueError(f"Unprocessed experts: {experts}")
+
+
 @Model.register("JinaBertModel", "JinaBertForMaskedLM")
 class JinaBertV2Model(BertModel):
     model_arch = gguf.MODEL_ARCH.JINA_BERT_V2
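
The key step in `OlmoeModel.modify_tensors` is the expert merge: once all per-expert weights for a block have been collected, each of `down_proj` / `gate_proj` / `up_proj` is stacked into a single 3D tensor indexed by expert. A minimal standalone sketch of that step, using hypothetical dimensions rather than OLMoE's real hyperparameters:

```python
import torch

# Hypothetical sizes, for illustration only; OLMoE's actual hparams differ.
n_experts, n_embd, n_ff = 4, 64, 128

# One 2D up_proj weight per expert, as they arrive from the HF checkpoint.
experts = [torch.randn(n_ff, n_embd) for _ in range(n_experts)]

# Stack along a new leading dimension -> shape (n_experts, n_ff, n_embd),
# matching the single merged tensor that gets written to the GGUF file.
merged = torch.stack(experts, dim=0)
print(merged.shape)  # torch.Size([4, 128, 64])
```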

gguf-py/gguf/constants.py

Lines changed: 19 additions & 0 deletions
@@ -213,6 +213,7 @@ class MODEL_ARCH(IntEnum):
     COMMAND_R = auto()
     DBRX      = auto()
     OLMO      = auto()
+    OLMOE     = auto()
     OPENELM   = auto()
     ARCTIC    = auto()
     DEEPSEEK2 = auto()
@@ -344,6 +345,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.COMMAND_R: "command-r",
     MODEL_ARCH.DBRX:      "dbrx",
     MODEL_ARCH.OLMO:      "olmo",
+    MODEL_ARCH.OLMOE:     "olmoe",
     MODEL_ARCH.OPENELM:   "openelm",
     MODEL_ARCH.ARCTIC:    "arctic",
     MODEL_ARCH.DEEPSEEK2: "deepseek2",
@@ -942,6 +944,23 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
     ],
+    MODEL_ARCH.OLMOE: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q_NORM,
+        MODEL_TENSOR.ATTN_K_NORM,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_GATE_EXP,
+        MODEL_TENSOR.FFN_UP_EXP,
+        MODEL_TENSOR.FFN_DOWN_EXP,
+    ],
     MODEL_ARCH.OPENELM: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
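
Assuming the two dictionaries extended above are gguf-py's `MODEL_ARCH_NAMES` and `MODEL_TENSORS` (the hunk headers only show the surrounding enum classes), the new entries can be inspected roughly like this; a sketch, not part of the commit:

```python
import gguf

# Architecture string an OLMoE GGUF file is tagged with.
print(gguf.MODEL_ARCH_NAMES[gguf.MODEL_ARCH.OLMOE])  # "olmoe"

# Tensor kinds expected for the OLMOE architecture, including the
# per-expert FFN tensors and the Q/K norms.
for t in gguf.MODEL_TENSORS[gguf.MODEL_ARCH.OLMOE]:
    print(t.name)
```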

gguf-py/gguf/tensor_mapping.py

Lines changed: 15 additions & 15 deletions
@@ -13,7 +13,7 @@ class TensorNameMap:
             "transformer.wte",                           # gpt2 gpt-j mpt refact qwen dbrx jais exaone
             "transformer.word_embeddings",               # falcon
             "word_embeddings",                           # bloom
-            "model.embed_tokens",                        # llama-hf nemotron
+            "model.embed_tokens",                        # llama-hf nemotron olmoe
             "tok_embeddings",                            # llama-pth
             "embeddings.word_embeddings",                # bert nomic-bert
             "language_model.embedding.word_embeddings",  # persimmon
@@ -52,7 +52,7 @@ class TensorNameMap:
         # Output
         MODEL_TENSOR.OUTPUT: (
             "embed_out",                 # gptneox
-            "lm_head",                   # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais nemotron exaone
+            "lm_head",                   # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais nemotron exaone olmoe
             "output",                    # llama-pth bloom internlm2
             "word_embeddings_for_head",  # persimmon
             "lm_head.linear",            # phi2
@@ -63,7 +63,7 @@ class TensorNameMap:
         MODEL_TENSOR.OUTPUT_NORM: (
             "gpt_neox.final_layer_norm",  # gptneox
             "transformer.ln_f",           # gpt2 gpt-j falcon jais exaone
-            "model.norm",                 # llama-hf baichuan internlm2
+            "model.norm",                 # llama-hf baichuan internlm2 olmoe
             "norm",                       # llama-pth
             "transformer.norm_f",         # mpt dbrx
             "ln_f",                       # refact bloom qwen gpt2
@@ -94,7 +94,7 @@ class TensorNameMap:
             "transformer.h.{bid}.input_layernorm",                  # falcon7b
             "h.{bid}.input_layernorm",                              # bloom
             "transformer.h.{bid}.ln_mlp",                           # falcon40b
-            "model.layers.{bid}.input_layernorm",                   # llama-hf nemotron
+            "model.layers.{bid}.input_layernorm",                   # llama-hf nemotron olmoe
             "layers.{bid}.attention_norm",                          # llama-pth
             "language_model.encoder.layers.{bid}.input_layernorm",  # persimmon
             "model.layers.{bid}.ln1",                               # yi
@@ -136,7 +136,7 @@ class TensorNameMap:

         # Attention query
         MODEL_TENSOR.ATTN_Q: (
-            "model.layers.{bid}.self_attn.q_proj",       # llama-hf nemotron
+            "model.layers.{bid}.self_attn.q_proj",       # llama-hf nemotron olmoe
             "layers.{bid}.attention.wq",                 # llama-pth
             "encoder.layer.{bid}.attention.self.query",  # bert
             "transformer.h.{bid}.attn.q_proj",           # gpt-j
@@ -148,7 +148,7 @@ class TensorNameMap:

         # Attention key
         MODEL_TENSOR.ATTN_K: (
-            "model.layers.{bid}.self_attn.k_proj",     # llama-hf nemotron
+            "model.layers.{bid}.self_attn.k_proj",     # llama-hf nemotron olmoe
             "layers.{bid}.attention.wk",               # llama-pth
             "encoder.layer.{bid}.attention.self.key",  # bert
             "transformer.h.{bid}.attn.k_proj",         # gpt-j
@@ -161,7 +161,7 @@ class TensorNameMap:

         # Attention value
         MODEL_TENSOR.ATTN_V: (
-            "model.layers.{bid}.self_attn.v_proj",       # llama-hf nemotron
+            "model.layers.{bid}.self_attn.v_proj",       # llama-hf nemotron olmoe
             "layers.{bid}.attention.wv",                 # llama-pth
             "encoder.layer.{bid}.attention.self.value",  # bert
             "transformer.h.{bid}.attn.v_proj",           # gpt-j
@@ -179,7 +179,7 @@ class TensorNameMap:
             "transformer.blocks.{bid}.attn.out_proj",      # mpt
             "transformer.h.{bid}.self_attention.dense",    # falcon
             "h.{bid}.self_attention.dense",                # bloom
-            "model.layers.{bid}.self_attn.o_proj",         # llama-hf nemotron
+            "model.layers.{bid}.self_attn.o_proj",         # llama-hf nemotron olmoe
             "layers.{bid}.attention.wo",                   # llama-pth
             "encoder.layer.{bid}.attention.output.dense",  # bert
             "transformer.h.{bid}.attn.out_proj",           # gpt-j
@@ -223,7 +223,7 @@ class TensorNameMap:
             "transformer.h.{bid}.ln_2",                                       # gpt2 refact qwen jais exaone
             "h.{bid}.post_attention_layernorm",                               # bloom
             "transformer.blocks.{bid}.norm_2",                                # mpt
-            "model.layers.{bid}.post_attention_layernorm",                    # llama-hf nemotron
+            "model.layers.{bid}.post_attention_layernorm",                    # llama-hf nemotron olmoe
             "layers.{bid}.ffn_norm",                                          # llama-pth
             "language_model.encoder.layers.{bid}.post_attention_layernorm",  # persimmon
             "model.layers.{bid}.ln2",                                         # yi
@@ -247,7 +247,7 @@ class TensorNameMap:
         MODEL_TENSOR.FFN_GATE_INP: (
             "layers.{bid}.feed_forward.gate",             # mixtral
             "model.layers.{bid}.block_sparse_moe.gate",   # mixtral
-            "model.layers.{bid}.mlp.gate",                # qwen2moe
+            "model.layers.{bid}.mlp.gate",                # qwen2moe olmoe
             "transformer.decoder_layer.{bid}.router",     # Grok
             "transformer.blocks.{bid}.ffn.router.layer",  # dbrx
         ),
@@ -289,7 +289,7 @@ class TensorNameMap:
             "layers.{bid}.feed_forward.experts.w3",          # mixtral (merged)
             "transformer.decoder_layer.{bid}.moe.linear_v",  # Grok (merged)
             "transformer.blocks.{bid}.ffn.experts.mlp.v1",   # dbrx
-            "model.layers.{bid}.mlp.experts.up_proj",        # qwen2moe (merged)
+            "model.layers.{bid}.mlp.experts.up_proj",        # qwen2moe olmoe (merged)
         ),

         MODEL_TENSOR.FFN_UP_SHEXP: (
@@ -321,7 +321,7 @@ class TensorNameMap:
             "layers.{bid}.feed_forward.experts.w1",         # mixtral (merged)
             "transformer.decoder_layer.{bid}.moe.linear",   # Grok (merged)
             "transformer.blocks.{bid}.ffn.experts.mlp.w1",  # dbrx
-            "model.layers.{bid}.mlp.experts.gate_proj",     # qwen2moe (merged)
+            "model.layers.{bid}.mlp.experts.gate_proj",     # qwen2moe olmoe (merged)
         ),

         MODEL_TENSOR.FFN_GATE_SHEXP: (
@@ -361,7 +361,7 @@ class TensorNameMap:
             "layers.{bid}.feed_forward.experts.w2",          # mixtral (merged)
             "transformer.decoder_layer.{bid}.moe.linear_1",  # Grok (merged)
             "transformer.blocks.{bid}.ffn.experts.mlp.w2",   # dbrx
-            "model.layers.{bid}.mlp.experts.down_proj",      # qwen2moe (merged)
+            "model.layers.{bid}.mlp.experts.down_proj",      # qwen2moe olmoe (merged)
         ),

         MODEL_TENSOR.FFN_DOWN_SHEXP: (
@@ -372,7 +372,7 @@ class TensorNameMap:
         MODEL_TENSOR.ATTN_Q_NORM: (
             "language_model.encoder.layers.{bid}.self_attention.q_layernorm",
             "model.layers.{bid}.self_attn.q_layernorm",                        # persimmon
-            "model.layers.{bid}.self_attn.q_norm",                             # cohere
+            "model.layers.{bid}.self_attn.q_norm",                             # cohere olmoe
             "transformer.blocks.{bid}.attn.q_ln",                              # sea-lion
             "encoder.layer.{bid}.attention.self.layer_norm_q",                 # jina-bert-v2
             "transformer.layers.{bid}.attn.q_norm",                            # openelm
@@ -381,7 +381,7 @@ class TensorNameMap:
         MODEL_TENSOR.ATTN_K_NORM: (
             "language_model.encoder.layers.{bid}.self_attention.k_layernorm",
             "model.layers.{bid}.self_attn.k_layernorm",                        # persimmon
-            "model.layers.{bid}.self_attn.k_norm",                             # cohere
+            "model.layers.{bid}.self_attn.k_norm",                             # cohere olmoe
             "transformer.blocks.{bid}.attn.k_ln",                              # sea-lion
             "encoder.layer.{bid}.attention.self.layer_norm_k",                 # jina-bert-v2
             "transformer.layers.{bid}.attn.k_norm",                            # openelm
