Commit 385cef2

[Model] Add support for OLMo architecture (#3046)
This PR adds support for the OLMo architecture. Additional support: clip-qkv. Test: already tested on Android (Pixel 4) and CUDA (with tensor_parallel_shards=2).
1 parent 86cf3f7 commit 385cef2
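
For context on the clip-qkv part: OLMo's clip_qkv option clamps the fused QKV projection output to [-clip_qkv, clip_qkv] before attention. Below is a minimal NumPy sketch of that clamping; the helper name and values are illustrative, not taken from this commit.

from typing import Optional

import numpy as np

def apply_clip_qkv(qkv: np.ndarray, clip_qkv: Optional[float]) -> np.ndarray:
    """Clamp fused QKV activations to [-clip_qkv, clip_qkv]; no-op when unset."""
    if clip_qkv is None:
        return qkv
    return np.clip(qkv, -clip_qkv, clip_qkv)

x = np.array([-12.0, 0.5, 9.0])
print(apply_clip_qkv(x, 8.0))  # [-8.   0.5  8. ]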

File tree

8 files changed: +902 -0 lines changed

python/mlc_llm/conversation_template/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -20,6 +20,7 @@
     llava,
     mistral,
     oasst,
+    olmo,
     orion,
     phi,
     qwen2,
python/mlc_llm/conversation_template/olmo.py

Lines changed: 28 additions & 0 deletions

@@ -0,0 +1,28 @@
+"""OLMo default templates"""
+
+from mlc_llm.protocol.conversation_protocol import Conversation, MessagePlaceholders
+
+from .registry import ConvTemplateRegistry
+
+# Note that the eos_token id is 50279 in both the Allenai and AMD versions,
+# so use the token id instead of the text.
+# Allenai version chat_template and eos_token:
+# https://huggingface.co/allenai/OLMo-7B-Instruct/blob/main/tokenizer_config.json
+# AMD version chat_template and eos_token:
+# https://huggingface.co/amd/AMD-OLMo-1B-SFT-DPO/blob/main/tokenizer_config.json
+ConvTemplateRegistry.register_conv_template(
+    Conversation(
+        name="olmo",
+        system_template=f"{MessagePlaceholders.SYSTEM.value}",
+        system_message="",
+        system_prefix_token_ids=[50279],
+        roles={
+            "user": "<|user|>",
+            "assistant": "<|assistant|>",
+        },
+        seps=["\n"],
+        role_content_sep="\n",
+        role_empty_sep="\n",
+        stop_token_ids=[50279],
+    )
+)
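
For reference, a single turn under this template renders roughly as follows: token id 50279 (eos) is prepended as a raw token id rather than text, each role tag is followed by a newline and the message content, and turns are separated by "\n". An illustrative rendering, not captured output:

<|user|>
What is the capital of France?
<|assistant|>
Paris.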

python/mlc_llm/interface/gen_config.py

Lines changed: 1 addition & 0 deletions

@@ -306,4 +306,5 @@ def gen_config(  # pylint: disable=too-many-locals,too-many-arguments,too-many-b
         "aya-23",
         "deepseek_v2",
         "deepseek",
+        "olmo",
     }

python/mlc_llm/model/model.py

Lines changed: 18 additions & 0 deletions

@@ -28,6 +28,7 @@
 from .minicpm import minicpm_loader, minicpm_model, minicpm_quantization
 from .mistral import mistral_loader, mistral_model, mistral_quantization
 from .mixtral import mixtral_loader, mixtral_model, mixtral_quantization
+from .olmo import olmo_loader, olmo_model, olmo_quantization
 from .orion import orion_loader, orion_model, orion_quantization
 from .phi import phi_loader, phi_model, phi_quantization
 from .phi3 import phi3_loader, phi3_model, phi3_quantization
@@ -532,4 +533,21 @@ class Model:
             "ft-quant": deepseek_quantization.ft_quant,
         },
     ),
+    "olmo": Model(
+        name="olmo",
+        model=olmo_model.OLMoForCausalLM,
+        config=olmo_model.OLMoConfig,
+        source={
+            "huggingface-torch": olmo_loader.huggingface,
+            "huggingface-safetensor": olmo_loader.huggingface,
+            "awq": olmo_loader.awq,
+        },
+        quantize={
+            "no-quant": olmo_quantization.no_quant,
+            "group-quant": olmo_quantization.group_quant,
+            "ft-quant": olmo_quantization.ft_quant,
+            "awq": olmo_quantization.awq_quant,
+            "per-tensor-quant": olmo_quantization.per_tensor_quant,
+        },
+    ),
 }
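
With this entry in place, the new architecture resolves through the same registry as every other model type. A minimal sketch of the lookup, assuming a hypothetical local config path (MODELS is the dict this hunk extends):

import json

from mlc_llm.model.model import MODELS

olmo = MODELS["olmo"]                      # the entry added above
with open("OLMo-7B/config.json") as f:     # hypothetical local checkout
    config = olmo.config.from_dict(json.load(f))
model = olmo.model(config)                 # OLMoForCausalLM, ready for export_tvm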

python/mlc_llm/model/olmo/__init__.py

Whitespace-only changes.
python/mlc_llm/model/olmo/olmo_loader.py

Lines changed: 172 additions & 0 deletions

@@ -0,0 +1,172 @@
+"""
+This file specifies how MLC's OLMo parameter maps from other formats, for example HuggingFace
+PyTorch, HuggingFace safetensors.
+"""
+
+import functools
+
+import numpy as np
+
+from mlc_llm.loader import ExternMapping
+from mlc_llm.quantization import Quantization
+
+from .olmo_model import OLMoConfig, OLMoForCausalLM
+from .olmo_quantization import awq_quant
+
+
+def huggingface(model_config: OLMoConfig, quantization: Quantization) -> ExternMapping:
+    """Returns a parameter mapping that maps from the names of MLC LLM parameters to
+    the names of HuggingFace PyTorch parameters.
+
+    Parameters
+    ----------
+    model_config : OLMoConfig
+        The configuration of the OLMo model.
+
+    quantization : Quantization
+        The quantization configuration.
+
+    Returns
+    -------
+    param_map : ExternMapping
+        The parameter mapping from MLC to HuggingFace PyTorch.
+    """
+    model = OLMoForCausalLM(model_config)
+    if quantization is not None:
+        model.to(quantization.model_dtype)
+    _, _named_params, _ = model.export_tvm(  # type: ignore[misc]
+        spec=model.get_default_spec(),
+        allow_extern=True,
+    )
+    named_parameters = dict(_named_params)
+
+    mapping = ExternMapping()
+
+    for i in range(model_config.num_hidden_layers):
+        # Add QKV in self attention
+        attn = f"model.layers.{i}.self_attn"
+        mlc_name = f"{attn}.qkv_proj.weight"
+        mlc_param = named_parameters[mlc_name]
+        mapping.add_mapping(
+            mlc_name,
+            [
+                f"{attn}.q_proj.weight",
+                f"{attn}.k_proj.weight",
+                f"{attn}.v_proj.weight",
+            ],
+            functools.partial(
+                lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype),
+                dtype=mlc_param.dtype,
+            ),
+        )
+        # Add gates in MLP
+        mlp = f"model.layers.{i}.mlp"
+        mlc_name = f"{mlp}.gate_up_proj.weight"
+        mlc_param = named_parameters[mlc_name]
+        mapping.add_mapping(
+            mlc_name,
+            [
+                f"{mlp}.gate_proj.weight",
+                f"{mlp}.up_proj.weight",
+            ],
+            functools.partial(
+                lambda gate, up, dtype: np.concatenate([gate, up], axis=0).astype(dtype),
+                dtype=mlc_param.dtype,
+            ),
+        )
+        # inv_freq is not used in the model
+        mapping.add_unused(f"{attn}.rotary_emb.inv_freq")
+
+    for mlc_name, mlc_param in named_parameters.items():
+        if mlc_name not in mapping.param_map:
+            mapping.add_mapping(
+                mlc_name,
+                [mlc_name],
+                functools.partial(
+                    lambda x, dtype: x.astype(dtype),
+                    dtype=mlc_param.dtype,
+                ),
+            )
+    return mapping
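
Before the AWQ variant below, a quick illustration of the QKV fusion just registered: HuggingFace Linear weights are stored as [out_features, in_features], so fusing q/k/v stacks them along axis 0. A toy-shape sketch (sizes are made up; OLMo-7B uses multi-head attention, so the k/v projections match q):

import numpy as np

hidden = 4096                                    # toy hidden size
q = np.ones((hidden, hidden), dtype="float16")   # q_proj.weight: [out, in]
k = np.ones((hidden, hidden), dtype="float16")
v = np.ones((hidden, hidden), dtype="float16")

# The same transform the mapping registers for qkv_proj.weight.
qkv = np.concatenate([q, k, v], axis=0)
assert qkv.shape == (3 * hidden, hidden)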
+
+
+def awq(model_config: OLMoConfig, quantization: Quantization) -> ExternMapping:
+    """Returns a parameter mapping that maps from the names of MLC LLM parameters to
+    the names of AWQ parameters.
+    Parameters
+    ----------
+    model_config : OLMoConfig
+        The configuration of the OLMo model.
+
+    quantization : Quantization
+        The quantization configuration.
+
+    Returns
+    -------
+    param_map : ExternMapping
+        The parameter mapping from MLC to AWQ.
+    """
+    model, _ = awq_quant(model_config, quantization)
+    _, _named_params, _ = model.export_tvm(  # type: ignore[misc]
+        spec=model.get_default_spec(),  # type: ignore[attr-defined]
+        allow_extern=True,
+    )
+    named_parameters = dict(_named_params)
+
+    mapping = ExternMapping()
+
+    for i in range(model_config.num_hidden_layers):
+        # Add QKV in self attention
+        attn = f"model.layers.{i}.self_attn"
+        for quantize_suffix in ["qweight", "qzeros", "scales"]:
+            mlc_name = f"{attn}.qkv_proj.{quantize_suffix}"
+            assert mlc_name in named_parameters
+            mlc_param = named_parameters[mlc_name]
+            mapping.add_mapping(
+                mlc_name,
+                [
+                    f"{attn}.q_proj.{quantize_suffix}",
+                    f"{attn}.k_proj.{quantize_suffix}",
+                    f"{attn}.v_proj.{quantize_suffix}",
+                ],
+                functools.partial(
+                    lambda q, k, v, dtype: np.concatenate(
+                        [q, k, v],
+                        axis=1,  # AWQ GEMM would transpose the weight
+                    ).astype(dtype),
+                    dtype=mlc_param.dtype,
+                ),
+            )
+
+        # Concat gate and up in MLP
+        mlp = f"model.layers.{i}.mlp"
+        for quantize_suffix in ["qweight", "qzeros", "scales"]:
+            mlc_name = f"{mlp}.gate_up_proj.{quantize_suffix}"
+            assert mlc_name in named_parameters
+            mlc_param = named_parameters[mlc_name]
+            mapping.add_mapping(
+                mlc_name,
+                [
+                    f"{mlp}.gate_proj.{quantize_suffix}",
+                    f"{mlp}.up_proj.{quantize_suffix}",
+                ],
+                functools.partial(
+                    lambda gate, up, dtype: np.concatenate(
+                        [gate, up],
+                        axis=1,  # AWQ GEMM would transpose the weight
+                    ).astype(dtype),
+                    dtype=mlc_param.dtype,
+                ),
+            )
+
+        # inv_freq is not used in the model
+        mapping.add_unused(f"{attn}.rotary_emb.inv_freq")
+
+    for mlc_name, mlc_param in named_parameters.items():
+        if mlc_name not in mapping.param_map:
+            mapping.add_mapping(
+                mlc_name,
+                [mlc_name],
+                functools.partial(lambda x, dtype: x.astype(dtype), dtype=mlc_param.dtype),
+            )
+    return mapping
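
As the inline comments note, the AWQ branch concatenates along axis=1 instead: AWQ's GEMM layout keeps the quantized weight transposed relative to HuggingFace Linear, roughly [in_features, out_features // pack_factor], so the fused output channels sit on the second axis. A toy sketch with illustrative shapes and pack factor:

import numpy as np

in_features, out_features, pack = 4096, 4096, 8   # toy int4 packing: 8 values per uint32
q = np.zeros((in_features, out_features // pack), dtype="uint32")  # q_proj.qweight
k = np.zeros_like(q)
v = np.zeros_like(q)

# The same transform the mapping registers for qkv_proj.qweight.
qkv = np.concatenate([q, k, v], axis=1)
assert qkv.shape == (in_features, 3 * out_features // pack)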
