|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project |
| 3 | +# Copyright 2023-2025 vLLM Team |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# You may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 |
| 7 | +# |
| 8 | +# Inference-only Arcee (AFM) model – adds support for ReLU^2 feed-forward |
| 9 | +# activation. |
| 10 | + |
| 11 | +from collections.abc import Iterable |
| 12 | +from typing import Any, Optional, Union |
| 13 | + |
| 14 | +import torch |
| 15 | +from torch import nn |
| 16 | +from transformers import LlamaConfig |
| 17 | + |
| 18 | +from vllm.compilation.decorators import support_torch_compile |
| 19 | +from vllm.distributed import get_pp_group |
| 20 | +from vllm.model_executor.layers.activation import ReLUSquaredActivation |
| 21 | +from vllm.model_executor.layers.layernorm import RMSNorm |
| 22 | +from vllm.model_executor.layers.linear import (ColumnParallelLinear, |
| 23 | + RowParallelLinear) |
| 24 | +from vllm.model_executor.layers.logits_processor import LogitsProcessor |
| 25 | +from vllm.model_executor.layers.vocab_parallel_embedding import ( |
| 26 | + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) |
| 27 | +from vllm.sequence import IntermediateTensors |
| 28 | + |
| 29 | +from .interfaces import SupportsLoRA, SupportsPP |
| 30 | +from .utils import (AutoWeightsLoader, PPMissingLayer, |
| 31 | + make_empty_intermediate_tensors_factory, make_layers) |
| 32 | + |
| 33 | + |
| 34 | +class ArceeMLP(nn.Module): |
| 35 | + """Feed-forward layer for Arcee using ReLU^2 activation |
| 36 | + (no gating as in LLaMA).""" |
| 37 | + |
| 38 | + def __init__(self, |
| 39 | + hidden_size: int, |
| 40 | + intermediate_size: int, |
| 41 | + hidden_act: str, |
| 42 | + quant_config: Optional[Any] = None, |
| 43 | + bias: bool = False, |
| 44 | + prefix: str = "", |
| 45 | + reduce_results: bool = True) -> None: |
| 46 | + super().__init__() |
| 47 | + # Single linear projection up to intermediate size |
| 48 | + # (no separate gate projection) |
| 49 | + self.up_proj = ColumnParallelLinear( |
| 50 | + input_size=hidden_size, |
| 51 | + output_size=intermediate_size, |
| 52 | + bias=bias, |
| 53 | + quant_config=quant_config, |
| 54 | + prefix=f"{prefix}.up_proj", |
| 55 | + ) |
| 56 | + # Down projection back to hidden size |
| 57 | + self.down_proj = RowParallelLinear( |
| 58 | + input_size=intermediate_size, |
| 59 | + output_size=hidden_size, |
| 60 | + bias=bias, |
| 61 | + quant_config=quant_config, |
| 62 | + reduce_results=reduce_results, |
| 63 | + prefix=f"{prefix}.down_proj", |
| 64 | + ) |
| 65 | + if hidden_act != "relu2": |
| 66 | + raise ValueError(f"Unsupported activation: {hidden_act}. " |
| 67 | + "Only 'relu2' is supported for AFM.") |
| 68 | + # Define ReLU^2 activation: (ReLU(x))^2 elementwise |
| 69 | + self.act_fn = ReLUSquaredActivation() |
| 70 | + |
| 71 | + def forward(self, x: torch.Tensor) -> torch.Tensor: |
| 72 | + x, _ = self.up_proj(x) # Project to intermediate size |
| 73 | + x = self.act_fn(x) # Apply ReLU^2 activation elementwise |
| 74 | + x, _ = self.down_proj(x) # Project back down to hidden size |
| 75 | + return x |
| 76 | + |
| 77 | + |
| 78 | +class ArceeDecoderLayer(nn.Module): |
| 79 | + """Transformer decoder block for Arcee, with self-attention and |
| 80 | + ReLU^2 MLP.""" |
| 81 | + |
| 82 | + def __init__(self, |
| 83 | + config: LlamaConfig, |
| 84 | + cache_config: Optional[Any] = None, |
| 85 | + quant_config: Optional[Any] = None, |
| 86 | + prefix: str = "") -> None: |
| 87 | + super().__init__() |
| 88 | + self.hidden_size = config.hidden_size |
| 89 | + # Rotary embedding parameters (reuse LLaMA defaults) |
| 90 | + rope_theta = getattr(config, "rope_theta", 10000) |
| 91 | + rope_scaling = getattr(config, "rope_scaling", None) |
| 92 | + if rope_scaling is not None and getattr( |
| 93 | + config, "original_max_position_embeddings", None): |
| 94 | + rope_scaling["original_max_position_embeddings"] = ( |
| 95 | + config.original_max_position_embeddings) |
| 96 | + max_position_embeddings = getattr(config, "max_position_embeddings", |
| 97 | + 8192) |
| 98 | + # Determine if attention bias is needed (some variants use bias terms) |
| 99 | + attention_bias = getattr(config, "attention_bias", False) or getattr( |
| 100 | + config, "bias", False) |
| 101 | + bias_o_proj = attention_bias |
| 102 | + if hasattr(config, "qkv_bias"): |
| 103 | + attention_bias = config.qkv_bias |
| 104 | + |
| 105 | + # Self-Attention (using LLaMA's attention structure) |
| 106 | + from vllm.model_executor.models.llama import ( |
| 107 | + LlamaAttention) # import here to avoid circular import |
| 108 | + self.self_attn = LlamaAttention( |
| 109 | + config=config, |
| 110 | + hidden_size=self.hidden_size, |
| 111 | + num_heads=config.num_attention_heads, |
| 112 | + num_kv_heads=getattr(config, "num_key_value_heads", |
| 113 | + config.num_attention_heads), |
| 114 | + rope_theta=rope_theta, |
| 115 | + rope_scaling=rope_scaling, |
| 116 | + max_position_embeddings=max_position_embeddings, |
| 117 | + quant_config=quant_config, |
| 118 | + bias=attention_bias, |
| 119 | + bias_o_proj=bias_o_proj, |
| 120 | + cache_config=cache_config, |
| 121 | + prefix=f"{prefix}.self_attn", |
| 122 | + attn_type=getattr( |
| 123 | + config, "attn_type", |
| 124 | + "decoder"), # assume decoder (causal) unless specified |
| 125 | + ) |
| 126 | + # MLP with ReLU^2 activation |
| 127 | + self.mlp = ArceeMLP( |
| 128 | + hidden_size=self.hidden_size, |
| 129 | + intermediate_size=config.intermediate_size, |
| 130 | + hidden_act=config.hidden_act, |
| 131 | + quant_config=quant_config, |
| 132 | + bias=getattr(config, "mlp_bias", False), |
| 133 | + prefix=f"{prefix}.mlp", |
| 134 | + ) |
| 135 | + # Layer normalization layers (RMSNorm as in LLaMA) |
| 136 | + self.input_layernorm = RMSNorm(config.hidden_size, |
| 137 | + eps=config.rms_norm_eps) |
| 138 | + self.post_attention_layernorm = RMSNorm(config.hidden_size, |
| 139 | + eps=config.rms_norm_eps) |
| 140 | + |
| 141 | + def forward( |
| 142 | + self, positions: torch.Tensor, hidden_states: torch.Tensor, |
| 143 | + residual: Optional[torch.Tensor] |
| 144 | + ) -> tuple[torch.Tensor, torch.Tensor]: |
| 145 | + # Self-Attention block |
| 146 | + if residual is None: |
| 147 | + residual = hidden_states |
| 148 | + hidden_states = self.input_layernorm(hidden_states) |
| 149 | + else: |
| 150 | + # Fused residual add + layernorm if supported |
| 151 | + hidden_states, residual = self.input_layernorm( |
| 152 | + hidden_states, residual) |
| 153 | + hidden_states = self.self_attn(positions=positions, |
| 154 | + hidden_states=hidden_states) |
| 155 | + # Feed-forward block |
| 156 | + hidden_states, residual = self.post_attention_layernorm( |
| 157 | + hidden_states, residual) |
| 158 | + hidden_states = self.mlp(hidden_states) |
| 159 | + return hidden_states, residual |
| 160 | + |
| 161 | + |
| 162 | +@support_torch_compile |
| 163 | +class ArceeModel(nn.Module): |
| 164 | + """The transformer model backbone for Arcee (embedding layer + stacked |
| 165 | + decoder blocks + final norm).""" |
| 166 | + |
| 167 | + def __init__(self, |
| 168 | + *, |
| 169 | + vllm_config, |
| 170 | + prefix: str = "", |
| 171 | + layer_type: type[nn.Module] = ArceeDecoderLayer) -> None: |
| 172 | + super().__init__() |
| 173 | + config: LlamaConfig = vllm_config.model_config.hf_config |
| 174 | + cache_config = vllm_config.cache_config |
| 175 | + quant_config = vllm_config.quant_config |
| 176 | + self.quant_config = quant_config |
| 177 | + self.config = config |
| 178 | + self.vocab_size = config.vocab_size |
| 179 | + self.org_vocab_size = config.vocab_size |
| 180 | + |
| 181 | + # Word embeddings (parallelized if using pipeline parallel) |
| 182 | + if get_pp_group().is_first_rank or (config.tie_word_embeddings |
| 183 | + and get_pp_group().is_last_rank): |
| 184 | + self.embed_tokens = VocabParallelEmbedding( |
| 185 | + self.vocab_size, |
| 186 | + config.hidden_size, |
| 187 | + org_num_embeddings=config.vocab_size, |
| 188 | + quant_config=quant_config, |
| 189 | + ) |
| 190 | + else: |
| 191 | + self.embed_tokens = PPMissingLayer( |
| 192 | + ) # placeholder on non-embedding ranks |
| 193 | + |
| 194 | + # Build decoder layers across pipeline ranks |
| 195 | + self.start_layer, self.end_layer, self.layers = make_layers( |
| 196 | + config.num_hidden_layers, |
| 197 | + lambda prefix: layer_type(config=config, |
| 198 | + cache_config=cache_config, |
| 199 | + quant_config=quant_config, |
| 200 | + prefix=prefix), |
| 201 | + prefix=f"{prefix}.layers", |
| 202 | + ) |
| 203 | + # Final RMSNorm on the last pipeline stage |
| 204 | + if get_pp_group().is_last_rank: |
| 205 | + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
| 206 | + else: |
| 207 | + self.norm = PPMissingLayer() |
| 208 | + |
| 209 | + # For optional capturing of intermediate hidden states |
| 210 | + # (not used by default) |
| 211 | + self.aux_hidden_state_layers: tuple[int, ...] = tuple() |
| 212 | + |
| 213 | + # Prepare factory for empty intermediate tensors |
| 214 | + # (for pipeline scheduling) |
| 215 | + self.make_empty_intermediate_tensors = ( |
| 216 | + make_empty_intermediate_tensors_factory( |
| 217 | + ["hidden_states", "residual"], config.hidden_size)) |
| 218 | + |
| 219 | + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: |
| 220 | + return self.embed_tokens(input_ids) |
| 221 | + |
| 222 | + def forward( |
| 223 | + self, |
| 224 | + input_ids: Optional[torch.Tensor], |
| 225 | + positions: torch.Tensor, |
| 226 | + intermediate_tensors: Optional[IntermediateTensors], |
| 227 | + inputs_embeds: Optional[torch.Tensor] = None |
| 228 | + ) -> Union[torch.Tensor, IntermediateTensors, tuple[torch.Tensor, |
| 229 | + list[torch.Tensor]]]: |
| 230 | + # Embedding lookup (on first pipeline rank) |
| 231 | + if get_pp_group().is_first_rank: |
| 232 | + hidden_states = (inputs_embeds if inputs_embeds is not None else |
| 233 | + self.get_input_embeddings(input_ids)) |
| 234 | + residual = None |
| 235 | + else: |
| 236 | + assert intermediate_tensors is not None, ( |
| 237 | + "IntermediateTensors must be provided for non-first " |
| 238 | + "pipeline ranks") |
| 239 | + hidden_states = intermediate_tensors["hidden_states"] |
| 240 | + residual = intermediate_tensors["residual"] |
| 241 | + |
| 242 | + aux_hidden_states: list[torch.Tensor] = [] |
| 243 | + for idx, layer in enumerate( |
| 244 | + self.layers[self.start_layer:self.end_layer]): |
| 245 | + if idx in self.aux_hidden_state_layers: |
| 246 | + aux_hidden_states.append( |
| 247 | + hidden_states + |
| 248 | + residual) # capture pre-layer hidden state if needed |
| 249 | + hidden_states, residual = layer(positions, hidden_states, residual) |
| 250 | + |
| 251 | + if not get_pp_group().is_last_rank: |
| 252 | + # Send intermediate results to the next pipeline stage |
| 253 | + return IntermediateTensors({ |
| 254 | + "hidden_states": hidden_states, |
| 255 | + "residual": residual |
| 256 | + }) |
| 257 | + # On last rank: apply final layer norm |
| 258 | + hidden_states, _ = self.norm(hidden_states, residual) |
| 259 | + if len(aux_hidden_states) > 0: |
| 260 | + return hidden_states, aux_hidden_states |
| 261 | + return hidden_states |
| 262 | + |
| 263 | + |
| 264 | +class ArceeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): |
| 265 | + """Arcee Model for causal language modeling, integrated with vLLM |
| 266 | + runtime.""" |
| 267 | + # Map fused module names to their sub-module components |
| 268 | + # (for quantization and LoRA) |
| 269 | + packed_modules_mapping = { |
| 270 | + "qkv_proj": ["q_proj", "k_proj", "v_proj"], |
| 271 | + } |
| 272 | + |
| 273 | + def __init__(self, *, vllm_config, prefix: str = "") -> None: |
| 274 | + super().__init__() |
| 275 | + config = vllm_config.model_config.hf_config |
| 276 | + self.config = config |
| 277 | + |
| 278 | + # Initialize the inner Transformer model (ArceeModel) |
| 279 | + self.model = ArceeModel(vllm_config=vllm_config, |
| 280 | + prefix=f"{prefix}.model") |
| 281 | + # On the last pipeline stage, set up the LM head and logits processor |
| 282 | + if get_pp_group().is_last_rank: |
| 283 | + # Determine vocabulary size (including any LoRA extra tokens |
| 284 | + # for padded LM head) |
| 285 | + self.unpadded_vocab_size = config.vocab_size |
| 286 | + |
| 287 | + self.lm_head = ParallelLMHead( |
| 288 | + self.unpadded_vocab_size, |
| 289 | + config.hidden_size, |
| 290 | + org_num_embeddings=config.vocab_size, |
| 291 | + padding_size=DEFAULT_VOCAB_PADDING_SIZE, |
| 292 | + quant_config=vllm_config.quant_config, |
| 293 | + bias=getattr(config, "lm_head_bias", False), |
| 294 | + prefix=f"{prefix}.lm_head", |
| 295 | + ) |
| 296 | + if config.tie_word_embeddings: |
| 297 | + # Tie output weights with input embedding matrix |
| 298 | + self.lm_head = self.lm_head.tie_weights( |
| 299 | + self.model.embed_tokens) |
| 300 | + logit_scale = getattr(config, "logit_scale", 1.0) |
| 301 | + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, |
| 302 | + config.vocab_size, |
| 303 | + logit_scale) |
| 304 | + else: |
| 305 | + # Placeholder for lm_head on non-last ranks |
| 306 | + self.lm_head = PPMissingLayer() |
| 307 | + # Provide a reference to the model's method for generating empty |
| 308 | + # tensors (used in pipeline parallel schedule) |
| 309 | + self.make_empty_intermediate_tensors = ( |
| 310 | + self.model.make_empty_intermediate_tensors) |
| 311 | + |
| 312 | + def forward( |
| 313 | + self, |
| 314 | + input_ids: torch.Tensor, |
| 315 | + positions: torch.Tensor, |
| 316 | + intermediate_tensors: Optional[IntermediateTensors] = None, |
| 317 | + inputs_embeds: Optional[torch.Tensor] = None |
| 318 | + ) -> Union[torch.Tensor, IntermediateTensors]: |
| 319 | + # Forward pass through the Arcee model backbone |
| 320 | + model_output = self.model(input_ids=input_ids, |
| 321 | + positions=positions, |
| 322 | + intermediate_tensors=intermediate_tensors, |
| 323 | + inputs_embeds=inputs_embeds) |
| 324 | + return model_output |
| 325 | + |
| 326 | + def compute_logits(self, hidden_states: torch.Tensor, |
| 327 | + sampling_metadata) -> Optional[torch.Tensor]: |
| 328 | + # Compute final logits from hidden states (last pipeline rank only) |
| 329 | + logits = self.logits_processor(self.lm_head, hidden_states, |
| 330 | + sampling_metadata) |
| 331 | + return logits |
| 332 | + |
| 333 | + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: |
| 334 | + return self.model.get_input_embeddings(input_ids) |
| 335 | + |
| 336 | + def load_weights(self, weights: Iterable[tuple[str, |
| 337 | + torch.Tensor]]) -> set[str]: |
| 338 | + """Load weights into the model (delegates to inner model and handles |
| 339 | + tied embeddings).""" |
| 340 | + loader = AutoWeightsLoader( |
| 341 | + self, |
| 342 | + skip_prefixes=(["lm_head."] |
| 343 | + if self.config.tie_word_embeddings else None), |
| 344 | + skip_substrs=["gate_proj"]) |
| 345 | + # AutoWeightLoader handles weight name remapping, including fusing |
| 346 | + # separate q_proj, k_proj, v_proj into qkv_proj |
| 347 | + return loader.load_weights(weights) |
0 commit comments