
Commit 9f7f73d

alyosha-swamy authored and jeejeelee committed
Add arcee model (vllm-project#21296)
Signed-off-by: alyosha-swamy <[email protected]>
Signed-off-by: Jee Jee Li <[email protected]>
Co-authored-by: Jee Jee Li <[email protected]>
Signed-off-by: avigny <[email protected]>
1 parent 1d3342b commit 9f7f73d

File tree

docs/models/supported_models.md
tests/models/registry.py
vllm/model_executor/models/arcee.py
vllm/model_executor/models/registry.py

4 files changed: +351 -0 lines changed

docs/models/supported_models.md

Lines changed: 1 addition & 0 deletions
@@ -324,6 +324,7 @@ th {
 | Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/distributed_serving.md) | [V1](gh-issue:8779) |
 |--------------|--------|-------------------|----------------------|---------------------------|---------------------|
 | `AquilaForCausalLM` | Aquila, Aquila2 | `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc. | ✅︎ | ✅︎ | ✅︎ |
+| `ArceeForCausalLM` | Arcee (AFM) | `arcee-ai/AFM-4.5B-Base`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `ArcticForCausalLM` | Arctic | `Snowflake/snowflake-arctic-base`, `Snowflake/snowflake-arctic-instruct`, etc. | | ✅︎ | ✅︎ |
 | `BaiChuanForCausalLM` | Baichuan2, Baichuan | `baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `BailingMoeForCausalLM` | Ling | `inclusionAI/Ling-lite-1.5`, `inclusionAI/Ling-plus`, etc. | ✅︎ | ✅︎ | ✅︎ |
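For reference, a minimal offline-inference sketch against the newly listed architecture (a sketch only: it assumes the `arcee-ai/AFM-4.5B-Base` checkpoint from the table is accessible, and the sampling settings are illustrative, not part of this commit):

from vllm import LLM, SamplingParams

# Load the Arcee (AFM) model through vLLM's standard offline API.
llm = LLM(model="arcee-ai/AFM-4.5B-Base")
params = SamplingParams(temperature=0.7, max_tokens=64)

# Generate a single completion and print it.
outputs = llm.generate(["The Arcee Foundation Model is"], params)
print(outputs[0].outputs[0].text)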

tests/models/registry.py

Lines changed: 2 additions & 0 deletions
@@ -135,6 +135,8 @@ def check_available_online(
                                          trust_remote_code=True),
     "AquilaForCausalLM": _HfExamplesInfo("BAAI/AquilaChat2-7B",
                                          trust_remote_code=True),
+    "ArceeForCausalLM": _HfExamplesInfo("arcee-ai/AFM-4.5B-Base",
+                                        is_available_online=False),
     "ArcticForCausalLM": _HfExamplesInfo("Snowflake/snowflake-arctic-instruct",
                                          trust_remote_code=True),
     "BaiChuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan-7B",

vllm/model_executor/models/arcee.py

Lines changed: 347 additions & 0 deletions
@@ -0,0 +1,347 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2023-2025 vLLM Team
# Licensed under the Apache License, Version 2.0 (the "License");
# You may not use this file except in compliance with the License.
# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
#
# Inference-only Arcee (AFM) model – adds support for ReLU^2 feed-forward
# activation.

from collections.abc import Iterable
from typing import Any, Optional, Union

import torch
from torch import nn
from transformers import LlamaConfig

from vllm.compilation.decorators import support_torch_compile
from vllm.distributed import get_pp_group
from vllm.model_executor.layers.activation import ReLUSquaredActivation
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, PPMissingLayer,
                    make_empty_intermediate_tensors_factory, make_layers)


class ArceeMLP(nn.Module):
    """Feed-forward layer for Arcee using ReLU^2 activation
    (no gating as in LLaMA)."""

    def __init__(self,
                 hidden_size: int,
                 intermediate_size: int,
                 hidden_act: str,
                 quant_config: Optional[Any] = None,
                 bias: bool = False,
                 prefix: str = "",
                 reduce_results: bool = True) -> None:
        super().__init__()
        # Single linear projection up to intermediate size
        # (no separate gate projection)
        self.up_proj = ColumnParallelLinear(
            input_size=hidden_size,
            output_size=intermediate_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.up_proj",
        )
        # Down projection back to hidden size
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=f"{prefix}.down_proj",
        )
        if hidden_act != "relu2":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only 'relu2' is supported for AFM.")
        # Define ReLU^2 activation: (ReLU(x))^2 elementwise
        self.act_fn = ReLUSquaredActivation()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x, _ = self.up_proj(x)  # Project to intermediate size
        x = self.act_fn(x)  # Apply ReLU^2 activation elementwise
        x, _ = self.down_proj(x)  # Project back down to hidden size
        return x


class ArceeDecoderLayer(nn.Module):
    """Transformer decoder block for Arcee, with self-attention and
    ReLU^2 MLP."""

    def __init__(self,
                 config: LlamaConfig,
                 cache_config: Optional[Any] = None,
                 quant_config: Optional[Any] = None,
                 prefix: str = "") -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        # Rotary embedding parameters (reuse LLaMA defaults)
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is not None and getattr(
                config, "original_max_position_embeddings", None):
            rope_scaling["original_max_position_embeddings"] = (
                config.original_max_position_embeddings)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        # Determine if attention bias is needed (some variants use bias terms)
        attention_bias = getattr(config, "attention_bias", False) or getattr(
            config, "bias", False)
        bias_o_proj = attention_bias
        if hasattr(config, "qkv_bias"):
            attention_bias = config.qkv_bias

        # Self-Attention (using LLaMA's attention structure)
        from vllm.model_executor.models.llama import (
            LlamaAttention)  # import here to avoid circular import
        self.self_attn = LlamaAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=getattr(config, "num_key_value_heads",
                                 config.num_attention_heads),
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=attention_bias,
            bias_o_proj=bias_o_proj,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
            attn_type=getattr(
                config, "attn_type",
                "decoder"),  # assume decoder (causal) unless specified
        )
        # MLP with ReLU^2 activation
        self.mlp = ArceeMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
            prefix=f"{prefix}.mlp",
        )
        # Layer normalization layers (RMSNorm as in LLaMA)
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self, positions: torch.Tensor, hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Self-Attention block
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            # Fused residual add + layernorm if supported
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(positions=positions,
                                       hidden_states=hidden_states)
        # Feed-forward block
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


@support_torch_compile
class ArceeModel(nn.Module):
    """The transformer model backbone for Arcee (embedding layer + stacked
    decoder blocks + final norm)."""

    def __init__(self,
                 *,
                 vllm_config,
                 prefix: str = "",
                 layer_type: type[nn.Module] = ArceeDecoderLayer) -> None:
        super().__init__()
        config: LlamaConfig = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        self.quant_config = quant_config
        self.config = config
        self.vocab_size = config.vocab_size
        self.org_vocab_size = config.vocab_size

        # Word embeddings (parallelized if using pipeline parallel)
        if get_pp_group().is_first_rank or (config.tie_word_embeddings
                                            and get_pp_group().is_last_rank):
            self.embed_tokens = VocabParallelEmbedding(
                self.vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                quant_config=quant_config,
            )
        else:
            self.embed_tokens = PPMissingLayer(
            )  # placeholder on non-embedding ranks

        # Build decoder layers across pipeline ranks
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: layer_type(config=config,
                                      cache_config=cache_config,
                                      quant_config=quant_config,
                                      prefix=prefix),
            prefix=f"{prefix}.layers",
        )
        # Final RMSNorm on the last pipeline stage
        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

        # For optional capturing of intermediate hidden states
        # (not used by default)
        self.aux_hidden_state_layers: tuple[int, ...] = tuple()

        # Prepare factory for empty intermediate tensors
        # (for pipeline scheduling)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None
    ) -> Union[torch.Tensor, IntermediateTensors, tuple[torch.Tensor,
                                                        list[torch.Tensor]]]:
        # Embedding lookup (on first pipeline rank)
        if get_pp_group().is_first_rank:
            hidden_states = (inputs_embeds if inputs_embeds is not None else
                             self.get_input_embeddings(input_ids))
            residual = None
        else:
            assert intermediate_tensors is not None, (
                "IntermediateTensors must be provided for non-first "
                "pipeline ranks")
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        aux_hidden_states: list[torch.Tensor] = []
        for idx, layer in enumerate(
                self.layers[self.start_layer:self.end_layer]):
            if idx in self.aux_hidden_state_layers:
                aux_hidden_states.append(
                    hidden_states +
                    residual)  # capture pre-layer hidden state if needed
            hidden_states, residual = layer(positions, hidden_states, residual)

        if not get_pp_group().is_last_rank:
            # Send intermediate results to the next pipeline stage
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        # On last rank: apply final layer norm
        hidden_states, _ = self.norm(hidden_states, residual)
        if len(aux_hidden_states) > 0:
            return hidden_states, aux_hidden_states
        return hidden_states


class ArceeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """Arcee Model for causal language modeling, integrated with vLLM
    runtime."""
    # Map fused module names to their sub-module components
    # (for quantization and LoRA)
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    }

    def __init__(self, *, vllm_config, prefix: str = "") -> None:
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.config = config

        # Initialize the inner Transformer model (ArceeModel)
        self.model = ArceeModel(vllm_config=vllm_config,
                                prefix=f"{prefix}.model")
        # On the last pipeline stage, set up the LM head and logits processor
        if get_pp_group().is_last_rank:
            # Determine vocabulary size (including any LoRA extra tokens
            # for padded LM head)
            self.unpadded_vocab_size = config.vocab_size

            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                padding_size=DEFAULT_VOCAB_PADDING_SIZE,
                quant_config=vllm_config.quant_config,
                bias=getattr(config, "lm_head_bias", False),
                prefix=f"{prefix}.lm_head",
            )
            if config.tie_word_embeddings:
                # Tie output weights with input embedding matrix
                self.lm_head = self.lm_head.tie_weights(
                    self.model.embed_tokens)
            logit_scale = getattr(config, "logit_scale", 1.0)
            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                    config.vocab_size,
                                                    logit_scale)
        else:
            # Placeholder for lm_head on non-last ranks
            self.lm_head = PPMissingLayer()
        # Provide a reference to the model's method for generating empty
        # tensors (used in pipeline parallel schedule)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None
    ) -> Union[torch.Tensor, IntermediateTensors]:
        # Forward pass through the Arcee model backbone
        model_output = self.model(input_ids=input_ids,
                                  positions=positions,
                                  intermediate_tensors=intermediate_tensors,
                                  inputs_embeds=inputs_embeds)
        return model_output

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata) -> Optional[torch.Tensor]:
        # Compute final logits from hidden states (last pipeline rank only)
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.get_input_embeddings(input_ids)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load weights into the model (delegates to inner model and handles
        tied embeddings)."""
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."]
                           if self.config.tie_word_embeddings else None),
            skip_substrs=["gate_proj"])
        # AutoWeightsLoader handles weight name remapping, including fusing
        # separate q_proj, k_proj, v_proj into qkv_proj
        return loader.load_weights(weights)
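To make the non-gated feed-forward concrete, here is a minimal standalone sketch (plain PyTorch, without vLLM's tensor-parallel layers or quantization; the class name is illustrative) of the ReLU^2 MLP that `ArceeMLP` implements above:

import torch
from torch import nn


class ReluSquaredMLP(nn.Module):
    """up_proj -> ReLU^2 -> down_proj; no gate projection, unlike LLaMA's SwiGLU."""

    def __init__(self, hidden_size: int, intermediate_size: int) -> None:
        super().__init__()
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.up_proj(x)
        x = torch.relu(x) ** 2  # ReLU^2: square of the rectified pre-activation
        return self.down_proj(x)

The absence of a gate projection is also why `load_weights` above skips any `gate_proj` tensors that might appear in a checkpoint.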

vllm/model_executor/models/registry.py

Lines changed: 1 addition & 0 deletions
@@ -33,6 +33,7 @@
     # [Decoder-only]
     "AquilaModel": ("llama", "LlamaForCausalLM"),
     "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
+    "ArceeForCausalLM": ("arcee", "ArceeForCausalLM"),
     "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
     "MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
     "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),

0 commit comments
