Add ArceeForCausalLM support #3294

Open · wants to merge 4 commits into main
153 changes: 153 additions & 0 deletions python/mlc_llm/model/arcee/arcee_loader.py
@@ -0,0 +1,153 @@
"""
This file specifies how MLC's Arcee parameter maps from other formats, for example HuggingFace
PyTorch, HuggingFace safetensors.
"""

import functools

import numpy as np

from mlc_llm.loader import ExternMapping
from mlc_llm.quantization import Quantization

from .arcee_model import ArceeConfig, ArceeForCausalLM
from .arcee_quantization import awq_quant


def huggingface(model_config: ArceeConfig, quantization: Quantization) -> ExternMapping:
    """Returns a parameter mapping that maps from the names of MLC LLM parameters to
    the names of HuggingFace PyTorch parameters.

    Parameters
    ----------
    model_config : ArceeConfig
        The configuration of the Arcee model.

    quantization : Quantization
        The quantization configuration.

    Returns
    -------
    param_map : ExternMapping
        The parameter mapping from MLC to HuggingFace PyTorch.
    """
    model = ArceeForCausalLM(model_config)
    if quantization is not None:
        model.to(quantization.model_dtype)
    _, _named_params, _ = model.export_tvm(  # type: ignore[misc]
        spec=model.get_default_spec(),
        allow_extern=True,
    )
    named_parameters = dict(_named_params)

    mapping = ExternMapping()

    for i in range(model_config.num_hidden_layers):
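        # HuggingFace keeps q_proj/k_proj/v_proj as separate matrices of shape
        # [out_features, hidden_size]; MLC fuses them into one qkv_proj weight
        # by stacking along the output (row) axis.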
        # Add QKV in self attention
        attn = f"model.layers.{i}.self_attn"
        mlc_name = f"{attn}.qkv_proj.weight"
        mlc_param = named_parameters[mlc_name]
        mapping.add_mapping(
            mlc_name,
            [
                f"{attn}.q_proj.weight",
                f"{attn}.k_proj.weight",
                f"{attn}.v_proj.weight",
            ],
            functools.partial(
                lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype),
                dtype=mlc_param.dtype,
            ),
        )
        # Handle biases if present
        if model_config.attention_bias:
            mlc_bias_name = f"{attn}.qkv_proj.bias"
            if mlc_bias_name in named_parameters:
                mlc_param = named_parameters[mlc_bias_name]
                mapping.add_mapping(
                    mlc_bias_name,
                    [
                        f"{attn}.q_proj.bias",
                        f"{attn}.k_proj.bias",
                        f"{attn}.v_proj.bias",
                    ],
                    functools.partial(
                        lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype),
                        dtype=mlc_param.dtype,
                    ),
                )

        # Note: Arcee's MLP doesn't use a gate projection, so no MLP
        # concatenation is needed; up_proj and down_proj map directly.

        # inv_freq is not used in the model
        mapping.add_unused(f"{attn}.rotary_emb.inv_freq")

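    # Remaining parameters (embeddings, norms, MLP projections, lm_head) keep
    # the same names in MLC and HuggingFace, so they map 1:1 with a dtype cast.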
    for mlc_name, mlc_param in named_parameters.items():
        if mlc_name not in mapping.param_map:
            mapping.add_mapping(
                mlc_name,
                [mlc_name],
                functools.partial(lambda x, dtype: x.astype(dtype), dtype=mlc_param.dtype),
            )
    return mapping


def awq(model_config: ArceeConfig, quantization: Quantization) -> ExternMapping:
    """Returns a parameter mapping that maps from the names of MLC LLM parameters to
    the names of AWQ parameters.

    Parameters
    ----------
    model_config : ArceeConfig
        The configuration of the Arcee model.

    quantization : Quantization
        The quantization configuration.

    Returns
    -------
    param_map : ExternMapping
        The parameter mapping from MLC to AWQ.
    """
    model, _ = awq_quant(model_config, quantization)
    _, _named_params, _ = model.export_tvm(  # type: ignore[misc]
        spec=model.get_default_spec(),  # type: ignore[attr-defined]
        allow_extern=True,
    )
    named_parameters = dict(_named_params)

    mapping = ExternMapping()

    for i in range(model_config.num_hidden_layers):
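        # AWQ splits each linear layer into qweight/qzeros/scales tensors;
        # the q/k/v tensors of each kind are fused the same way as the
        # unquantized weights above.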
        # Add QKV in self attention
        attn = f"model.layers.{i}.self_attn"

        for quantize_suffix in ["qweight", "qzeros", "scales"]:
            mlc_name = f"{attn}.qkv_proj.{quantize_suffix}"
            assert mlc_name in named_parameters
            mlc_param = named_parameters[mlc_name]
            mapping.add_mapping(
                mlc_name,
                [
                    f"{attn}.q_proj.{quantize_suffix}",
                    f"{attn}.k_proj.{quantize_suffix}",
                    f"{attn}.v_proj.{quantize_suffix}",
                ],
                functools.partial(
                    lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype),
                    dtype=mlc_param.dtype,
                ),
            )

        # inv_freq is not used in the model
        mapping.add_unused(f"{attn}.rotary_emb.inv_freq")

    for mlc_name, mlc_param in named_parameters.items():
        if mlc_name not in mapping.param_map:
            mapping.add_mapping(
                mlc_name,
                [mlc_name],
                functools.partial(lambda x, dtype: x.astype(dtype), dtype=mlc_param.dtype),
            )
    return mapping
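
A quick sanity check of the QKV fusion both loaders rely on: HuggingFace stores q_proj/k_proj/v_proj as separate [out_features, hidden_size] matrices, and the mapping lambda stacks them along axis 0, so only the output dimension grows. A toy script (the dimensions below are illustrative, not Arcee's actual configuration):

import numpy as np

hidden_size = 64
num_q_heads, num_kv_heads, head_dim = 8, 2, 8  # made-up GQA shapes
q = np.random.randn(num_q_heads * head_dim, hidden_size)   # (64, 64)
k = np.random.randn(num_kv_heads * head_dim, hidden_size)  # (16, 64)
v = np.random.randn(num_kv_heads * head_dim, hidden_size)  # (16, 64)

qkv = np.concatenate([q, k, v], axis=0).astype("float16")
assert qkv.shape == ((num_q_heads + 2 * num_kv_heads) * head_dim, hidden_size)  # (96, 64)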
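
For the loader to be reachable, the PR presumably also registers Arcee in python/mlc_llm/model/model.py. A sketch of that entry, following the pattern of the existing Llama registration (the no_quant and group_quant hooks from arcee_quantization are assumptions; only awq_quant is visible in this file):

"arcee": Model(
    name="arcee",
    model=arcee_model.ArceeForCausalLM,
    config=arcee_model.ArceeConfig,
    source={
        "huggingface-torch": arcee_loader.huggingface,
        "huggingface-safetensor": arcee_loader.huggingface,
        "awq": arcee_loader.awq,
    },
    quantize={
        "no-quant": arcee_quantization.no_quant,        # assumed hook
        "group-quant": arcee_quantization.group_quant,  # assumed hook
        "awq": arcee_quantization.awq_quant,
    },
),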