Skip to content

Commit 88ebe6f

Browse files
authored
[Model] Add support for GPTJ architecture (#3012)
This PR supports GPTJ architecture.
1 parent 5c9ebcb commit 88ebe6f

File tree

7 files changed

+543
-10
lines changed

7 files changed

+543
-10
lines changed

python/mlc_llm/model/gpt_j/__init__.py

Whitespace-only changes.
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
"""
2+
This file specifies how MLC's GPTJ parameters map from other formats, for example HuggingFace
3+
PyTorch, HuggingFace safetensors.
4+
"""
5+
6+
import functools
7+
8+
import numpy as np
9+
10+
from mlc_llm.loader import ExternMapping
11+
from mlc_llm.quantization import Quantization
12+
13+
from .gpt_j_model import GPTJConfig, GPTJForCausalLM
14+
15+
16+
def huggingface(model_config: GPTJConfig, quantization: Quantization) -> ExternMapping:
    """Build the MLC LLM -> HuggingFace PyTorch parameter mapping for GPTJ.

    A ``GPTJForCausalLM`` is instantiated from ``model_config`` (and cast to
    ``quantization.model_dtype`` when a quantization is given) so that the
    names and dtypes of the MLC parameters can be read from its TVM export.

    Parameters
    ----------
    model_config : GPTJConfig
        The configuration of the GPTJ model.

    quantization : Quantization
        The quantization configuration. When ``None``, the model is left in
        its default dtype.

    Returns
    -------
    param_map : ExternMapping
        The parameter mapping from MLC to HuggingFace PyTorch. Each layer's
        ``c_attn.weight`` is mapped from the HuggingFace ``q_proj``, ``k_proj``
        and ``v_proj`` weights; every other parameter is mapped from the
        HuggingFace parameter of the same name.
    """

    def _fuse_qkv(q, k, v, dtype):
        # Stack the three projection weights along axis 0, then cast.
        return np.concatenate([q, k, v], axis=0).astype(dtype)

    def _cast(x, dtype):
        # Pass the source weight through unchanged apart from the dtype cast.
        return x.astype(dtype)

    model = GPTJForCausalLM(model_config)
    if quantization is not None:
        model.to(quantization.model_dtype)

    # Only the (name, parameter) pairs of the export are needed here.
    _mod, param_pairs, _ext = model.export_tvm(  # type: ignore[misc]
        spec=model.get_default_spec(),
        allow_extern=True,
    )
    named_parameters = dict(param_pairs)

    mapping = ExternMapping()

    # Per layer: the single MLC c_attn weight is built from the separate
    # HuggingFace q/k/v projection weights.
    for layer_idx in range(model_config.n_layer):
        attn = f"transformer.h.{layer_idx}.attn"
        fused_name = f"{attn}.c_attn.weight"
        hf_sources = [
            f"{attn}.q_proj.weight",
            f"{attn}.k_proj.weight",
            f"{attn}.v_proj.weight",
        ]
        fused_dtype = named_parameters[fused_name].dtype
        mapping.add_mapping(
            fused_name,
            hf_sources,
            functools.partial(_fuse_qkv, dtype=fused_dtype),
        )

    # Everything not handled above maps one-to-one by identical name.
    for param_name, param in named_parameters.items():
        if param_name in mapping.param_map:
            continue
        mapping.add_mapping(
            param_name,
            [param_name],
            functools.partial(_cast, dtype=param.dtype),
        )

    return mapping

0 commit comments

Comments
 (0)