
Commit a88ad00

llama : add OLMo November 2024 support (ggml-org#10394)
* Add OLMo November 2024 constants
* Add OLMo November 2024 converter
* Add loading of OLMo November 2024 tensors and hyper parameters
* Add building of OLMo November 2024 model
1 parent 2a1507c commit a88ad00

File tree: 4 files changed, +223 −14 lines

convert_hf_to_gguf.py

Lines changed: 5 additions & 0 deletions

```diff
@@ -3040,6 +3040,11 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         return [(self.map_tensor_name(name), data_torch)]
 
 
+@Model.register("Olmo1124ForCausalLM")
+class Olmo1124Model(Model):
+    model_arch = gguf.MODEL_ARCH.OLMO_1124
+
+
 @Model.register("OlmoeForCausalLM")
 class OlmoeModel(Model):
     model_arch = gguf.MODEL_ARCH.OLMOE
```
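Because the new model reuses the generic converter logic, registering the class under the Hugging Face architecture string `Olmo1124ForCausalLM` and pointing it at `gguf.MODEL_ARCH.OLMO_1124` is essentially the whole change to this file. The snippet below is a minimal, self-contained sketch of how such a decorator-based registry can work; `ModelBase`, `_registry`, and `from_architecture` are hypothetical names used for illustration, not the actual convert_hf_to_gguf.py internals.

```python
# Illustrative sketch of a decorator-based model registry; the names used here
# (ModelBase, _registry, from_architecture) are hypothetical, not the real
# convert_hf_to_gguf.py implementation.
from __future__ import annotations


class ModelBase:
    _registry: dict[str, type["ModelBase"]] = {}
    model_arch: str = "unknown"

    @classmethod
    def register(cls, *names: str):
        # Decorator: map one or more HF architecture strings to a subclass.
        def wrapper(subclass: type["ModelBase"]) -> type["ModelBase"]:
            for name in names:
                cls._registry[name] = subclass
            return subclass
        return wrapper

    @classmethod
    def from_architecture(cls, arch: str) -> "ModelBase":
        # Look up the subclass registered for the "architectures" entry in config.json.
        return cls._registry[arch]()


@ModelBase.register("Olmo1124ForCausalLM")
class Olmo1124Sketch(ModelBase):
    model_arch = "olmo_1124"


if __name__ == "__main__":
    model = ModelBase.from_architecture("Olmo1124ForCausalLM")
    print(type(model).__name__, model.model_arch)  # Olmo1124Sketch olmo_1124
```

Since the added `Olmo1124Model` class overrides nothing beyond `model_arch`, tensor renaming and hyperparameter export appear to fall through to the shared base-class code paths plus the gguf-py tables added below.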

gguf-py/gguf/constants.py

Lines changed: 18 additions & 0 deletions

```diff
@@ -243,6 +243,7 @@ class MODEL_ARCH(IntEnum):
     COMMAND_R = auto()
     DBRX = auto()
     OLMO = auto()
+    OLMO_1124 = auto()
     OLMOE = auto()
     OPENELM = auto()
     ARCTIC = auto()
@@ -404,6 +405,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.COMMAND_R: "command-r",
     MODEL_ARCH.DBRX: "dbrx",
     MODEL_ARCH.OLMO: "olmo",
+    MODEL_ARCH.OLMO_1124: "olmo_1124",
     MODEL_ARCH.OLMOE: "olmoe",
     MODEL_ARCH.OPENELM: "openelm",
     MODEL_ARCH.ARCTIC: "arctic",
@@ -1069,6 +1071,22 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
     ],
+    MODEL_ARCH.OLMO_1124: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_POST_NORM,
+        MODEL_TENSOR.ATTN_Q_NORM,
+        MODEL_TENSOR.ATTN_K_NORM,
+        MODEL_TENSOR.FFN_POST_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
     MODEL_ARCH.OLMOE: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
```
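These constants follow gguf-py's usual pattern: an `IntEnum` member for the architecture, a mapping from that member to its serialized name (`"olmo_1124"`), and a per-architecture list of the tensor kinds the GGUF file may contain. Below is a small, self-contained sketch of that pattern and how such tables can be consumed; `Arch`, `ARCH_NAMES`, `ARCH_TENSORS`, and `expected_tensor_kinds` are illustrative names, not the gguf-py API.

```python
# Self-contained sketch of the enum + name-map + per-arch tensor-list pattern;
# Arch, ARCH_NAMES, ARCH_TENSORS, and expected_tensor_kinds are illustrative
# names, not the actual gguf-py API.
from enum import IntEnum, auto


class Arch(IntEnum):
    OLMO = auto()
    OLMO_1124 = auto()


class TensorKind(IntEnum):
    TOKEN_EMBD = auto()
    ATTN_Q_NORM = auto()
    FFN_POST_NORM = auto()


# Architecture enum -> serialized name written into the GGUF metadata.
ARCH_NAMES: dict[Arch, str] = {
    Arch.OLMO: "olmo",
    Arch.OLMO_1124: "olmo_1124",
}

# Architecture enum -> tensor kinds the architecture is allowed to contain
# (the OLMO_1124 list in the diff above includes Q/K norms and post-norms).
ARCH_TENSORS: dict[Arch, list[TensorKind]] = {
    Arch.OLMO: [TensorKind.TOKEN_EMBD],
    Arch.OLMO_1124: [TensorKind.TOKEN_EMBD, TensorKind.ATTN_Q_NORM, TensorKind.FFN_POST_NORM],
}


def expected_tensor_kinds(arch_name: str) -> list[TensorKind]:
    # Resolve the serialized architecture string back to its tensor whitelist.
    arch = next(a for a, n in ARCH_NAMES.items() if n == arch_name)
    return ARCH_TENSORS[arch]


print(ARCH_NAMES[Arch.OLMO_1124])          # olmo_1124
print(expected_tensor_kinds("olmo_1124"))  # includes ATTN_Q_NORM, FFN_POST_NORM
```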

gguf-py/gguf/tensor_mapping.py

Lines changed: 14 additions & 14 deletions

```diff
@@ -13,7 +13,7 @@ class TensorNameMap:
             "transformer.wte", # gpt2 gpt-j mpt refact qwen dbrx jais exaone
             "transformer.word_embeddings", # falcon
             "word_embeddings", # bloom
-            "model.embed_tokens", # llama-hf nemotron olmoe
+            "model.embed_tokens", # llama-hf nemotron olmoe olmo_1124
             "tok_embeddings", # llama-pth
             "embeddings.word_embeddings", # bert nomic-bert
             "language_model.embedding.word_embeddings", # persimmon
@@ -54,7 +54,7 @@ class TensorNameMap:
         # Output
         MODEL_TENSOR.OUTPUT: (
             "embed_out", # gptneox
-            "lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais nemotron exaone olmoe
+            "lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais nemotron exaone olmoe olmo_1124
             "output", # llama-pth bloom internlm2
             "word_embeddings_for_head", # persimmon
             "lm_head.linear", # phi2
@@ -66,7 +66,7 @@ class TensorNameMap:
         MODEL_TENSOR.OUTPUT_NORM: (
             "gpt_neox.final_layer_norm", # gptneox
             "transformer.ln_f", # gpt2 gpt-j falcon jais exaone
-            "model.norm", # llama-hf baichuan internlm2 olmoe
+            "model.norm", # llama-hf baichuan internlm2 olmoe olmo_1124
             "norm", # llama-pth
             "transformer.norm_f", # mpt dbrx
             "ln_f", # refact bloom qwen gpt2
@@ -145,7 +145,7 @@ class TensorNameMap:
 
         # Attention query
         MODEL_TENSOR.ATTN_Q: (
-            "model.layers.{bid}.self_attn.q_proj", # llama-hf nemotron olmoe
+            "model.layers.{bid}.self_attn.q_proj", # llama-hf nemotron olmoe olmo_1124
             "layers.{bid}.attention.wq", # llama-pth
             "encoder.layer.{bid}.attention.self.query", # bert
             "transformer.h.{bid}.attn.q_proj", # gpt-j
@@ -157,7 +157,7 @@ class TensorNameMap:
 
         # Attention key
         MODEL_TENSOR.ATTN_K: (
-            "model.layers.{bid}.self_attn.k_proj", # llama-hf nemotron olmoe
+            "model.layers.{bid}.self_attn.k_proj", # llama-hf nemotron olmoe olmo_1124
             "layers.{bid}.attention.wk", # llama-pth
             "encoder.layer.{bid}.attention.self.key", # bert
             "transformer.h.{bid}.attn.k_proj", # gpt-j
@@ -170,7 +170,7 @@ class TensorNameMap:
 
         # Attention value
         MODEL_TENSOR.ATTN_V: (
-            "model.layers.{bid}.self_attn.v_proj", # llama-hf nemotron olmoe
+            "model.layers.{bid}.self_attn.v_proj", # llama-hf nemotron olmoe olmo_1124
             "layers.{bid}.attention.wv", # llama-pth
             "encoder.layer.{bid}.attention.self.value", # bert
             "transformer.h.{bid}.attn.v_proj", # gpt-j
@@ -188,7 +188,7 @@ class TensorNameMap:
             "transformer.blocks.{bid}.attn.out_proj", # mpt
             "transformer.h.{bid}.self_attention.dense", # falcon
             "h.{bid}.self_attention.dense", # bloom
-            "model.layers.{bid}.self_attn.o_proj", # llama-hf nemotron olmoe
+            "model.layers.{bid}.self_attn.o_proj", # llama-hf nemotron olmoe olmo_1124
             "layers.{bid}.attention.wo", # llama-pth
             "encoder.layer.{bid}.attention.output.dense", # bert
             "transformer.h.{bid}.attn.out_proj", # gpt-j
@@ -215,7 +215,7 @@ class TensorNameMap:
         ),
 
         MODEL_TENSOR.ATTN_POST_NORM: (
-            "model.layers.{bid}.post_attention_layernorm", # gemma2
+            "model.layers.{bid}.post_attention_layernorm", # gemma2 olmo_1124
         ),
 
         # Rotary embeddings
@@ -250,7 +250,7 @@ class TensorNameMap:
 
         # Post feed-forward norm
         MODEL_TENSOR.FFN_POST_NORM: (
-            "model.layers.{bid}.post_feedforward_layernorm", # gemma2
+            "model.layers.{bid}.post_feedforward_layernorm", # gemma2 olmo_1124
         ),
 
         MODEL_TENSOR.FFN_GATE_INP: (
@@ -273,7 +273,7 @@ class TensorNameMap:
             "transformer.blocks.{bid}.ffn.up_proj", # mpt
             "transformer.h.{bid}.mlp.dense_h_to_4h", # falcon
             "h.{bid}.mlp.dense_h_to_4h", # bloom
-            "model.layers.{bid}.mlp.up_proj", # llama-hf refact nemotron
+            "model.layers.{bid}.mlp.up_proj", # llama-hf refact nemotron olmo_1124
             "layers.{bid}.feed_forward.w3", # llama-pth
             "encoder.layer.{bid}.intermediate.dense", # bert
             "transformer.h.{bid}.mlp.fc_in", # gpt-j
@@ -314,7 +314,7 @@ class TensorNameMap:
 
         # Feed-forward gate
         MODEL_TENSOR.FFN_GATE: (
-            "model.layers.{bid}.mlp.gate_proj", # llama-hf refact
+            "model.layers.{bid}.mlp.gate_proj", # llama-hf refact olmo_1124
             "layers.{bid}.feed_forward.w1", # llama-pth
             "transformer.h.{bid}.mlp.w2", # qwen
             "transformer.h.{bid}.mlp.c_fc2", # jais
@@ -346,7 +346,7 @@ class TensorNameMap:
             "transformer.blocks.{bid}.ffn.down_proj", # mpt
             "transformer.h.{bid}.mlp.dense_4h_to_h", # falcon
             "h.{bid}.mlp.dense_4h_to_h", # bloom
-            "model.layers.{bid}.mlp.down_proj", # llama-hf nemotron
+            "model.layers.{bid}.mlp.down_proj", # llama-hf nemotron olmo_1124
             "layers.{bid}.feed_forward.w2", # llama-pth
             "encoder.layer.{bid}.output.dense", # bert
             "transformer.h.{bid}.mlp.fc_out", # gpt-j
@@ -383,7 +383,7 @@ class TensorNameMap:
         MODEL_TENSOR.ATTN_Q_NORM: (
             "language_model.encoder.layers.{bid}.self_attention.q_layernorm",
             "model.layers.{bid}.self_attn.q_layernorm", # persimmon
-            "model.layers.{bid}.self_attn.q_norm", # cohere olmoe chameleon
+            "model.layers.{bid}.self_attn.q_norm", # cohere olmoe chameleon olmo_1124
             "transformer.blocks.{bid}.attn.q_ln", # sea-lion
             "encoder.layer.{bid}.attention.self.layer_norm_q", # jina-bert-v2
             "transformer.layers.{bid}.attn.q_norm", # openelm
@@ -392,7 +392,7 @@ class TensorNameMap:
         MODEL_TENSOR.ATTN_K_NORM: (
             "language_model.encoder.layers.{bid}.self_attention.k_layernorm",
             "model.layers.{bid}.self_attn.k_layernorm", # persimmon
-            "model.layers.{bid}.self_attn.k_norm", # cohere olmoe chameleon
+            "model.layers.{bid}.self_attn.k_norm", # cohere olmoe chameleon olmo_1124
             "transformer.blocks.{bid}.attn.k_ln", # sea-lion
             "encoder.layer.{bid}.attention.self.layer_norm_k", # jina-bert-v2
             "transformer.layers.{bid}.attn.k_norm", # openelm
```
