Commit 3a3c9ae

Implement OLMoE architecture

1 parent bd35cb0 · commit 3a3c9ae
5 files changed: +298 -15 lines changed

README.md
Lines changed: 1 addition & 0 deletions

@@ -77,6 +77,7 @@ Typically finetunes of the base models below are supported as well.
 - [x] [SEA-LION](https://huggingface.co/models?search=sea-lion)
 - [x] [GritLM-7B](https://huggingface.co/GritLM/GritLM-7B) + [GritLM-8x7B](https://huggingface.co/GritLM/GritLM-8x7B)
 - [x] [OLMo](https://allenai.org/olmo)
+- [x] [OLMoE](https://huggingface.co/allenai/OLMoE-1B-7B-0924)
 - [x] [Granite models](https://huggingface.co/collections/ibm-granite/granite-code-models-6624c5cec322e4c148c8b330)
 - [x] [GPT-NeoX](https://github.com/EleutherAI/gpt-neox) + [Pythia](https://github.com/EleutherAI/pythia)
 - [x] [Snowflake-Arctic MoE](https://huggingface.co/collections/Snowflake/arctic-66290090abe542894a5ac520)

convert_hf_to_gguf.py
Lines changed: 60 additions & 0 deletions

@@ -2944,6 +2944,66 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         return [(self.map_tensor_name(name), data_torch)]
 
 
+@Model.register("OlmoeForCausalLM")
+class OlmoeModel(Model):
+    model_arch = gguf.MODEL_ARCH.OLMOE
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_layer_norm_rms_eps(1e-5)
+        if (n_experts := self.hparams.get("num_experts")) is not None:
+            self.gguf_writer.add_expert_count(n_experts)
+
+    _experts: list[dict[str, Tensor]] | None = None
+
+    # Copied from: Qwen2MoeModel
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # process the experts separately
+        if name.find("experts") != -1:
+            n_experts = self.hparams["num_experts"]
+            assert bid is not None
+
+            if self._experts is None:
+                self._experts = [{} for _ in range(self.block_count)]
+
+            self._experts[bid][name] = data_torch
+
+            if len(self._experts[bid]) >= n_experts * 3:
+                tensors: list[tuple[str, Tensor]] = []
+
+                # merge the experts into a single 3d tensor
+                for w_name in ["down_proj", "gate_proj", "up_proj"]:
+                    datas: list[Tensor] = []
+
+                    for xid in range(n_experts):
+                        ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
+                        datas.append(self._experts[bid][ename])
+                        del self._experts[bid][ename]
+
+                    data_torch = torch.stack(datas, dim=0)
+
+                    merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
+
+                    new_name = self.map_tensor_name(merged_name)
+
+                    tensors.append((new_name, data_torch))
+                return tensors
+            else:
+                return []
+
+        return [(self.map_tensor_name(name), data_torch)]
+
+    # Copied from: Qwen2MoeModel
+    def prepare_tensors(self):
+        super().prepare_tensors()
+
+        if self._experts is not None:
+            # flatten `list[dict[str, Tensor]]` into `list[str]`
+            experts = [k for d in self._experts for k in d.keys()]
+            if len(experts) > 0:
+                raise ValueError(f"Unprocessed experts: {experts}")
+
+
 @Model.register("JinaBertModel", "JinaBertForMaskedLM")
 class JinaBertV2Model(BertModel):
     model_arch = gguf.MODEL_ARCH.JINA_BERT_V2

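Note on the converter change above: OlmoeModel buffers every per-expert 2D projection for a block and, once all n_experts * 3 weights have arrived, emits a single stacked 3D tensor per projection. Below is a minimal, hypothetical sketch of just that stacking step; the sizes n_experts, n_embd and n_ff are made-up toy values, not OLMoE's real dimensions.

import torch

# toy sizes (assumption for illustration only)
n_experts, n_embd, n_ff = 4, 8, 16
bid = 0  # block id

# pretend these per-expert 2D matrices came from the HF checkpoint
experts = {
    f"model.layers.{bid}.mlp.experts.{xid}.up_proj.weight": torch.randn(n_ff, n_embd)
    for xid in range(n_experts)
}

# same merge as OlmoeModel.modify_tensors(): stack along a new leading experts dimension
datas = [experts[f"model.layers.{bid}.mlp.experts.{xid}.up_proj.weight"]
         for xid in range(n_experts)]
merged = torch.stack(datas, dim=0)
print(merged.shape)  # torch.Size([4, 16, 8]) == (n_experts, n_ff, n_embd)

The merged name "model.layers.{bid}.mlp.experts.up_proj.weight" is then mapped to its GGUF name via map_tensor_name(), using the tensor_mapping.py entries further down in this commit.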
gguf-py/gguf/constants.py
Lines changed: 19 additions & 0 deletions

@@ -219,6 +219,7 @@ class MODEL_ARCH(IntEnum):
     COMMAND_R = auto()
     DBRX = auto()
     OLMO = auto()
+    OLMOE = auto()
     OPENELM = auto()
     ARCTIC = auto()
     DEEPSEEK2 = auto()

@@ -373,6 +374,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.COMMAND_R: "command-r",
     MODEL_ARCH.DBRX: "dbrx",
     MODEL_ARCH.OLMO: "olmo",
+    MODEL_ARCH.OLMOE: "olmoe",
     MODEL_ARCH.OPENELM: "openelm",
     MODEL_ARCH.ARCTIC: "arctic",
     MODEL_ARCH.DEEPSEEK2: "deepseek2",

@@ -1008,6 +1010,23 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
     ],
+    MODEL_ARCH.OLMOE: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q_NORM,
+        MODEL_TENSOR.ATTN_K_NORM,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_GATE_EXP,
+        MODEL_TENSOR.FFN_UP_EXP,
+        MODEL_TENSOR.FFN_DOWN_EXP,
+    ],
     MODEL_ARCH.OPENELM: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,

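For reference, the new registration can be inspected from gguf-py. This is a small sketch, assuming the gguf package from this repo is on the import path and that MODEL_ARCH_NAMES, MODEL_TENSORS and TENSOR_NAMES (the tables edited in this file) are exposed at package level:

import gguf

arch = gguf.MODEL_ARCH.OLMOE
print(gguf.MODEL_ARCH_NAMES[arch])  # "olmoe", the architecture string written into the GGUF metadata

# list every tensor type the architecture declares together with its canonical GGUF base name
for t in gguf.MODEL_TENSORS[arch]:
    print(f"{t.name:14} -> {gguf.TENSOR_NAMES[t]}")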
gguf-py/gguf/tensor_mapping.py
Lines changed: 15 additions & 15 deletions

@@ -13,7 +13,7 @@ class TensorNameMap:
             "transformer.wte",  # gpt2 gpt-j mpt refact qwen dbrx jais exaone
             "transformer.word_embeddings",  # falcon
             "word_embeddings",  # bloom
-            "model.embed_tokens",  # llama-hf nemotron
+            "model.embed_tokens",  # llama-hf nemotron olmoe
             "tok_embeddings",  # llama-pth
             "embeddings.word_embeddings",  # bert nomic-bert
             "language_model.embedding.word_embeddings",  # persimmon

@@ -54,7 +54,7 @@ class TensorNameMap:
         # Output
         MODEL_TENSOR.OUTPUT: (
             "embed_out",  # gptneox
-            "lm_head",  # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais nemotron exaone
+            "lm_head",  # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais nemotron exaone olmoe
             "output",  # llama-pth bloom internlm2
             "word_embeddings_for_head",  # persimmon
             "lm_head.linear",  # phi2

@@ -66,7 +66,7 @@ class TensorNameMap:
         MODEL_TENSOR.OUTPUT_NORM: (
             "gpt_neox.final_layer_norm",  # gptneox
             "transformer.ln_f",  # gpt2 gpt-j falcon jais exaone
-            "model.norm",  # llama-hf baichuan internlm2
+            "model.norm",  # llama-hf baichuan internlm2 olmoe
             "norm",  # llama-pth
             "transformer.norm_f",  # mpt dbrx
             "ln_f",  # refact bloom qwen gpt2

@@ -98,7 +98,7 @@ class TensorNameMap:
             "transformer.h.{bid}.input_layernorm",  # falcon7b
             "h.{bid}.input_layernorm",  # bloom
             "transformer.h.{bid}.ln_mlp",  # falcon40b
-            "model.layers.{bid}.input_layernorm",  # llama-hf nemotron
+            "model.layers.{bid}.input_layernorm",  # llama-hf nemotron olmoe
             "layers.{bid}.attention_norm",  # llama-pth
             "language_model.encoder.layers.{bid}.input_layernorm",  # persimmon
             "model.layers.{bid}.ln1",  # yi

@@ -142,7 +142,7 @@ class TensorNameMap:
 
         # Attention query
         MODEL_TENSOR.ATTN_Q: (
-            "model.layers.{bid}.self_attn.q_proj",  # llama-hf nemotron
+            "model.layers.{bid}.self_attn.q_proj",  # llama-hf nemotron olmoe
             "layers.{bid}.attention.wq",  # llama-pth
             "encoder.layer.{bid}.attention.self.query",  # bert
             "transformer.h.{bid}.attn.q_proj",  # gpt-j

@@ -154,7 +154,7 @@ class TensorNameMap:
 
         # Attention key
         MODEL_TENSOR.ATTN_K: (
-            "model.layers.{bid}.self_attn.k_proj",  # llama-hf nemotron
+            "model.layers.{bid}.self_attn.k_proj",  # llama-hf nemotron olmoe
             "layers.{bid}.attention.wk",  # llama-pth
             "encoder.layer.{bid}.attention.self.key",  # bert
             "transformer.h.{bid}.attn.k_proj",  # gpt-j

@@ -167,7 +167,7 @@ class TensorNameMap:
 
         # Attention value
         MODEL_TENSOR.ATTN_V: (
-            "model.layers.{bid}.self_attn.v_proj",  # llama-hf nemotron
+            "model.layers.{bid}.self_attn.v_proj",  # llama-hf nemotron olmoe
             "layers.{bid}.attention.wv",  # llama-pth
             "encoder.layer.{bid}.attention.self.value",  # bert
             "transformer.h.{bid}.attn.v_proj",  # gpt-j

@@ -185,7 +185,7 @@ class TensorNameMap:
             "transformer.blocks.{bid}.attn.out_proj",  # mpt
             "transformer.h.{bid}.self_attention.dense",  # falcon
             "h.{bid}.self_attention.dense",  # bloom
-            "model.layers.{bid}.self_attn.o_proj",  # llama-hf nemotron
+            "model.layers.{bid}.self_attn.o_proj",  # llama-hf nemotron olmoe
             "layers.{bid}.attention.wo",  # llama-pth
             "encoder.layer.{bid}.attention.output.dense",  # bert
             "transformer.h.{bid}.attn.out_proj",  # gpt-j

@@ -229,7 +229,7 @@ class TensorNameMap:
             "transformer.h.{bid}.ln_2",  # gpt2 refact qwen jais exaone
             "h.{bid}.post_attention_layernorm",  # bloom
             "transformer.blocks.{bid}.norm_2",  # mpt
-            "model.layers.{bid}.post_attention_layernorm",  # llama-hf nemotron
+            "model.layers.{bid}.post_attention_layernorm",  # llama-hf nemotron olmoe
             "layers.{bid}.ffn_norm",  # llama-pth
             "language_model.encoder.layers.{bid}.post_attention_layernorm",  # persimmon
             "model.layers.{bid}.ln2",  # yi

@@ -253,7 +253,7 @@ class TensorNameMap:
         MODEL_TENSOR.FFN_GATE_INP: (
             "layers.{bid}.feed_forward.gate",  # mixtral
             "model.layers.{bid}.block_sparse_moe.gate",  # mixtral
-            "model.layers.{bid}.mlp.gate",  # qwen2moe
+            "model.layers.{bid}.mlp.gate",  # qwen2moe olmoe
             "transformer.decoder_layer.{bid}.router",  # Grok
             "transformer.blocks.{bid}.ffn.router.layer",  # dbrx
         ),

@@ -295,7 +295,7 @@ class TensorNameMap:
             "layers.{bid}.feed_forward.experts.w3",  # mixtral (merged)
             "transformer.decoder_layer.{bid}.moe.linear_v",  # Grok (merged)
             "transformer.blocks.{bid}.ffn.experts.mlp.v1",  # dbrx
-            "model.layers.{bid}.mlp.experts.up_proj",  # qwen2moe (merged)
+            "model.layers.{bid}.mlp.experts.up_proj",  # qwen2moe olmoe (merged)
         ),
 
         MODEL_TENSOR.FFN_UP_SHEXP: (

@@ -327,7 +327,7 @@ class TensorNameMap:
             "layers.{bid}.feed_forward.experts.w1",  # mixtral (merged)
             "transformer.decoder_layer.{bid}.moe.linear",  # Grok (merged)
             "transformer.blocks.{bid}.ffn.experts.mlp.w1",  # dbrx
-            "model.layers.{bid}.mlp.experts.gate_proj",  # qwen2moe (merged)
+            "model.layers.{bid}.mlp.experts.gate_proj",  # qwen2moe olmoe (merged)
         ),
 
         MODEL_TENSOR.FFN_GATE_SHEXP: (

@@ -367,7 +367,7 @@ class TensorNameMap:
             "layers.{bid}.feed_forward.experts.w2",  # mixtral (merged)
             "transformer.decoder_layer.{bid}.moe.linear_1",  # Grok (merged)
             "transformer.blocks.{bid}.ffn.experts.mlp.w2",  # dbrx
-            "model.layers.{bid}.mlp.experts.down_proj",  # qwen2moe (merged)
+            "model.layers.{bid}.mlp.experts.down_proj",  # qwen2moe olmoe (merged)
         ),
 
         MODEL_TENSOR.FFN_DOWN_SHEXP: (

@@ -378,7 +378,7 @@ class TensorNameMap:
         MODEL_TENSOR.ATTN_Q_NORM: (
             "language_model.encoder.layers.{bid}.self_attention.q_layernorm",
             "model.layers.{bid}.self_attn.q_layernorm",  # persimmon
-            "model.layers.{bid}.self_attn.q_norm",  # cohere
+            "model.layers.{bid}.self_attn.q_norm",  # cohere olmoe
             "transformer.blocks.{bid}.attn.q_ln",  # sea-lion
             "encoder.layer.{bid}.attention.self.layer_norm_q",  # jina-bert-v2
             "transformer.layers.{bid}.attn.q_norm",  # openelm

@@ -387,7 +387,7 @@ class TensorNameMap:
         MODEL_TENSOR.ATTN_K_NORM: (
             "language_model.encoder.layers.{bid}.self_attention.k_layernorm",
             "model.layers.{bid}.self_attn.k_layernorm",  # persimmon
-            "model.layers.{bid}.self_attn.k_norm",  # cohere
+            "model.layers.{bid}.self_attn.k_norm",  # cohere olmoe
             "transformer.blocks.{bid}.attn.k_ln",  # sea-lion
             "encoder.layer.{bid}.attention.self.layer_norm_k",  # jina-bert-v2
             "transformer.layers.{bid}.attn.k_norm",  # openelm

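To sanity-check the mapping entries above, TensorNameMap can be queried directly. A sketch, assuming gguf-py from this repo is importable; the block count (16) and layer index (3) are arbitrary example values, and the HF names follow the olmoe entries added in this diff:

import gguf

tmap = gguf.get_tensor_name_map(gguf.MODEL_ARCH.OLMOE, 16)

for hf_name in (
    "model.layers.3.self_attn.q_norm.weight",     # ATTN_Q_NORM
    "model.layers.3.mlp.gate.weight",             # FFN_GATE_INP (the router)
    "model.layers.3.mlp.experts.up_proj.weight",  # merged FFN_UP_EXP emitted by the converter
):
    print(hf_name, "->", tmap.get_name(hf_name, try_suffixes=(".weight", ".bias")))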