
Commit beae0d7

tdoublep authored and epwalsh committed
[V1] [Hybrid] Enable compile and piecewise CUDA graph for MiniMax-Text models (vllm-project#22589)
Signed-off-by: Thomas Parnell <[email protected]>
1 parent 3f67f3e commit beae0d7

File tree

2 files changed (+98, -137 lines)


vllm/config/compilation.py

Lines changed: 1 addition & 0 deletions

@@ -339,6 +339,7 @@ class CompilationConfig:
         "vllm.mamba_mixer2",
         "vllm.mamba_mixer",
         "vllm.short_conv",
+        "vllm.linear_attention",
     ]

     def compute_hash(self) -> str:
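
This one-line change registers the new vllm.linear_attention custom op as a splitting op, so the piecewise torch.compile backend cuts the FX graph around it and can capture the pieces in between as CUDA graphs while the op itself runs eagerly. A minimal sketch of the related configuration, assuming vLLM is installed; only the splitting_ops field and the op names come from this diff, the level value and the idea of overriding the defaults here are illustrative:

# Sketch only: inspecting/overriding the splitting ops via CompilationConfig.
# The specific values below are placeholders, not a recommended configuration.
from vllm.config import CompilationConfig

compilation_config = CompilationConfig(
    level=3,  # piecewise compilation
    splitting_ops=[
        "vllm.unified_attention",
        "vllm.linear_attention",  # recognized as a graph-split point after this commit
    ],
)
print(compilation_config.splitting_ops)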

vllm/model_executor/models/minimax_text_01.py

Lines changed: 97 additions & 137 deletions
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Inference-only MiniMaxText01 model."""
-import copy
 import math
 from collections.abc import Iterable
 from typing import TYPE_CHECKING, Optional, Union
@@ -19,13 +18,14 @@

 from vllm import envs
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
                          get_current_vllm_config)
 from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
 from vllm.distributed.parallel_state import (
     get_pp_group, get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size)
-from vllm.forward_context import get_forward_context
+from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -43,12 +43,15 @@
     MambaStateDtypeCalculator, MambaStateShapeCalculator)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
+from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.utils import maybe_prefix
 from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
+from vllm.utils import direct_register_custom_op
 from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata

 from .interfaces import HasInnerState, IsHybrid
@@ -143,61 +146,6 @@ def forward(
         return self._forward(x)


-class MiniMaxText01RotaryEmbedding(CustomOp):
-    name = "MiniMaxText01RotaryEmbedding"
-
-    def __init__(
-        self,
-        head_size: int,
-        rotary_dim: int,
-        max_position: int,
-        base: float,
-        is_neox_style: bool,
-        cache_dtype: torch.dtype,
-    ) -> None:
-        super().__init__()
-        self.head_size = head_size
-        self.rotary_dim = rotary_dim
-        self.max_position_embeddings = max_position
-        self.base = base
-        self.is_neox_style = is_neox_style
-        self.cache_dtype = cache_dtype
-        cache = self._compute_cos_sin_cache().to(cache_dtype)
-        self.register_buffer("cos_sin_cache", cache, persistent=False)
-
-    def _compute_inv_freq(self, base: float) -> torch.Tensor:
-        """Compute the inverse frequency."""
-        inv_freq = 1.0 / (base**(torch.arange(
-            0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
-        return inv_freq
-
-    def _compute_cos_sin_cache(self) -> torch.Tensor:
-        """Compute the cos and sin cache."""
-        inv_freq = self._compute_inv_freq(self.base)
-        t = torch.arange(self.max_position_embeddings, dtype=torch.float)
-        freqs = torch.einsum("i,j -> ij", t, inv_freq)
-        cos = freqs.cos()
-        sin = freqs.sin()
-        cache = torch.cat((cos, sin), dim=-1)
-        return cache
-
-    def forward(
-        self,
-        positions: torch.Tensor,
-        query: torch.Tensor,
-        key: torch.Tensor,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
-        from vllm import _custom_ops as ops
-        self.cos_sin_cache = self.cos_sin_cache.to(positions.device)
-        query_cast = query.to(self.cache_dtype)
-        key_cast = key.to(self.cache_dtype)
-        ops.rotary_embedding(positions, query_cast, key_cast, self.head_size,
-                             self.cos_sin_cache, self.is_neox_style)
-        query = query_cast.to(query.dtype)
-        key = key_cast.to(key.dtype)
-        return query, key
-
-
 class MiniMaxText01MLP(nn.Module):

     def __init__(
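
The deleted class above precomputed a rotary cos/sin cache from the usual inverse-frequency formula and applied it through ops.rotary_embedding; the shared implementation returned by get_rope (added further down) builds an equivalent cache internally. For reference, a standalone sketch of that cache computation, with placeholder sizes:

# Standalone sketch of the cos/sin cache the removed class computed.
# rotary_dim, max_position and base are placeholder values.
import torch

rotary_dim, max_position, base = 64, 4096, 10000.0
inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2, dtype=torch.float) /
                         rotary_dim))
t = torch.arange(max_position, dtype=torch.float)
freqs = torch.einsum("i,j -> ij", t, inv_freq)        # (max_position, rotary_dim // 2)
cos_sin_cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1)
print(cos_sin_cache.shape)                            # (max_position, rotary_dim)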
@@ -526,20 +474,40 @@ def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor,
                                            slot_id, 32)
         return hidden

-    def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor,
-                kv_caches: MinimaxCacheParams, **kwargs) -> torch.Tensor:
-        qkv, _ = self.qkv_proj(hidden_states)
+    def forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
+                positions: torch.Tensor,
+                kv_caches: MinimaxCacheParams) -> None:
+        if not envs.VLLM_USE_V1:
+            self._forward(hidden_states, output, positions, kv_caches)
+        else:
+            torch.ops.vllm.linear_attention(
+                hidden_states,
+                output,
+                positions,
+                self.prefix,
+            )
+
+    def _forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
+                 positions: torch.Tensor,
+                 kv_caches: Optional[MinimaxCacheParams]) -> None:
+        forward_context = get_forward_context()
+        attn_metadata: AttentionMetadata = forward_context.attn_metadata
+        if envs.VLLM_USE_V1 and attn_metadata is not None:
+            assert isinstance(attn_metadata, dict)
+            attn_metadata = attn_metadata[self.prefix]
+            assert isinstance(attn_metadata, LinearAttentionMetadata)
+            num_actual_tokens = attn_metadata.num_prefill_tokens + \
+                attn_metadata.num_decode_tokens
+        else:
+            num_actual_tokens = hidden_states.shape[0]
+
+        qkv, _ = self.qkv_proj(hidden_states[:num_actual_tokens])
         qkv32 = qkv.to(torch.float32)
         qkvact = torch.nn.functional.silu(qkv32)
         qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1))
         q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1)
-        forward_context = get_forward_context()
-        attn_metadata = forward_context.attn_metadata
         if envs.VLLM_USE_V1:
             if attn_metadata is not None:
-                assert isinstance(attn_metadata, dict)
-                attn_metadata = attn_metadata[self.prefix]
-                assert isinstance(attn_metadata, LinearAttentionMetadata)
                 kv_cache = self.kv_cache[forward_context.virtual_engine][0]
                 state_indices_tensor = attn_metadata.state_indices_tensor

@@ -578,13 +546,11 @@ def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor,
                 hidden = self._decode_infer(q, k, v, kv_cache,
                                             state_indices_tensor,
                                             attn_metadata)
-
         hidden = self.norm._forward(hidden)
-        gate, _ = self.output_gate(hidden_states)
+        gate, _ = self.output_gate(hidden_states[:num_actual_tokens])
         hidden = F.sigmoid(gate) * hidden
         hidden = hidden.to(hidden_states.dtype)
-        hidden, _ = self.out_proj(hidden)
-        return hidden
+        output[:num_actual_tokens], _ = self.out_proj(hidden)


 class MiniMaxText01Attention(nn.Module):
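
The linear-attention layer's forward no longer returns a tensor: the caller passes in an output buffer, and on V1 the work is routed through the torch.ops.vllm.linear_attention custom op registered at the bottom of this file, so torch.compile sees a single opaque op that mutates output. Below is a self-contained sketch of the same out-parameter pattern using plain torch.library (requires a recent PyTorch with torch.library.custom_op) rather than vLLM's direct_register_custom_op helper; the mylib::toy_linear_attention name and the doubling computation are made up:

# Minimal sketch of an "out-parameter" custom op that a compiled graph treats
# as an opaque, output-mutating call. Names and the math are illustrative only.
import torch

@torch.library.custom_op("mylib::toy_linear_attention", mutates_args=("output",))
def toy_linear_attention(hidden_states: torch.Tensor,
                         output: torch.Tensor) -> None:
    # Eager implementation: write the result into the caller-provided buffer.
    output.copy_(hidden_states * 2.0)

@toy_linear_attention.register_fake
def _(hidden_states: torch.Tensor, output: torch.Tensor) -> None:
    # Fake (meta) implementation: no computation, just lets torch.compile
    # trace through the op while keeping it opaque.
    return

def layer(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)              # caller allocates the output buffer
    torch.ops.mylib.toy_linear_attention(x, out)
    return out

compiled = torch.compile(layer)
print(compiled(torch.randn(4, 8)))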
@@ -652,23 +618,23 @@ def __init__(
             quant_config=quant_config,
             prefix=f"{prefix}.attn",
         )
+        self.rotary_emb = get_rope(
+            head_size=self.head_dim,
+            rotary_dim=rotary_dim,
+            max_position=max_position,
+            base=int(rope_theta),
+            is_neox_style=True,
+            dtype=torch.float32,
+        )
         return

-    def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor,
-                **kwargs) -> torch.Tensor:
-        forward_context = get_forward_context()
-        attn_metadata = forward_context.attn_metadata
+    def forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
+                positions: torch.Tensor, **kwargs) -> None:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
-        if envs.VLLM_USE_V1:
-            if attn_metadata is not None:
-                q, k = attn_metadata[f"{self.prefix}.attn"].rotary_emb(
-                    positions, q, k)
-        else:
-            q, k = attn_metadata.rotary_emb(positions, q, k)
+        q, k = self.rotary_emb(positions, q, k)
         attn_output = self.attn(q, k, v)
-        output, _ = self.o_proj(attn_output)
-        return output
+        output[:], _ = self.o_proj(attn_output)


 class MiniMaxText01DecoderLayer(nn.Module):
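
Instead of smuggling a model-level rotary embedding through attn_metadata, each attention layer now owns a rotary_emb built by the shared get_rope helper and applies it directly to q/k. A small usage sketch, assuming vLLM is installed; only the keyword arguments mirror the diff, the concrete sizes are placeholders:

# Sketch: building and applying the shared rotary embedding.
import torch
from vllm.model_executor.layers.rotary_embedding import get_rope

rotary_emb = get_rope(
    head_size=128,
    rotary_dim=64,
    max_position=8192,
    base=10000,
    is_neox_style=True,
    dtype=torch.float32,
)

positions = torch.arange(16)
q = torch.randn(16, 128)   # (num_tokens, num_heads * head_size), one head here
k = torch.randn(16, 128)
q, k = rotary_emb(positions, q, k)   # replaces the rotary_emb lookup via attn_metadata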
@@ -816,16 +782,15 @@ def forward(self,
                 is_warmup: bool = False,
                 **kwargs) -> tuple[torch.Tensor, torch.Tensor]:

-        forward_context = get_forward_context()
-        attn_metadata = forward_context.attn_metadata
         layernorm_input = hidden_states
         layernorm_output = self.input_layernorm(layernorm_input)
         residual = layernorm_output if self.postnorm else layernorm_input
-        self_attention_output = self.self_attn(
+        self_attention_output = torch.empty_like(layernorm_output)
+        self.self_attn(
             hidden_states=layernorm_output,
+            output=self_attention_output,
             positions=positions,
             kv_caches=kv_caches,
-            attn_metadata=attn_metadata,
         )

         residual = residual * self.layernorm_attention_alpha
@@ -839,8 +804,8 @@ def forward(self,
         if self.expert_num == 1:
             hidden_states = self.mlp(layernorm_output)
         else:
-            moe_hidden_states = self.block_sparse_moe(
-                copy.deepcopy(layernorm_output))
+            moe_layernorm_output = layernorm_output.clone()
+            moe_hidden_states = self.block_sparse_moe(moe_layernorm_output)
             if self.shared_moe:
                 before_moe_dtype = layernorm_output.dtype
                 moe_hidden_fp32 = moe_hidden_states.to(torch.float32)
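
copy.deepcopy on an activation tensor is an opaque Python-level call that torch.compile cannot reliably keep inside the captured graph, whereas Tensor.clone() is an ordinary traceable op that still gives the MoE branch its own copy of the layernorm output. A tiny illustration of the difference; the functions are made up and the exact behaviour under compile may vary by PyTorch version:

# Illustration only: clone() vs copy.deepcopy() under torch.compile.
import copy
import torch

def with_clone(x: torch.Tensor) -> torch.Tensor:
    y = x.clone()          # traceable tensor op, stays inside the graph
    return y * 2

def with_deepcopy(x: torch.Tensor) -> torch.Tensor:
    y = copy.deepcopy(x)   # Python-level copy; typically not captured cleanly
    return y * 2

x = torch.randn(4)
print(torch.compile(with_clone)(x))
print(torch.compile(with_deepcopy)(x))   # still correct, but may fall back / graph-break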
@@ -878,18 +843,16 @@ def shared_moe_coefficient_loader(param: torch.Tensor,
         return


+@support_torch_compile
 class MiniMaxText01Model(nn.Module):

-    def __init__(
-        self,
-        config: MiniMaxConfig,
-        model_config: Optional[ModelConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        cache_config: Optional[CacheConfig] = None,
-        scheduler_config=None,
-        prefix: str = "",
-    ) -> None:
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
+        config: MiniMaxConfig = vllm_config.model_config.hf_config
+        model_config = vllm_config.model_config
+        quant_config = vllm_config.quant_config
+        cache_config = vllm_config.cache_config
+        scheduler_config = vllm_config.scheduler_config

         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
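
Two changes make the model compilable: the @support_torch_compile decorator, and a constructor that takes the whole VllmConfig (the decorator and vLLM's model loader both expect the keyword-only vllm_config/prefix signature). A hedged sketch of that pattern on a toy module; only the decorator and the constructor shape are taken from the diff, everything else is illustrative:

# Sketch of the vllm_config-based constructor pattern; ToyModel is made up.
import torch
import torch.nn as nn
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig


@support_torch_compile
class ToyModel(nn.Module):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        self.embed = nn.Embedding(hf_config.vocab_size, hf_config.hidden_size)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        # With compilation enabled, this forward is compiled once and split
        # piecewise at the registered splitting ops.
        return self.embed(input_ids)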
@@ -976,24 +939,6 @@ def layer_fn(prefix):
         self.minimax_cache = MinimaxCacheManager(
             dtype=torch.float32, cache_shape=self.cache_shape)

-        rope_theta = getattr(config, "rope_theta", 10000)
-        head_dim = getattr(config, "head_dim", None)
-        if head_dim is None:
-            head_dim = config.hidden_size // config.num_attention_heads
-        if hasattr(config, "max_model_len") and isinstance(
-                config.max_model_len, int):
-            max_position_embeddings = min(config.max_position_embeddings,
-                                          config.max_model_len)
-        self.rotary_emb = MiniMaxText01RotaryEmbedding(
-            head_dim,
-            rotary_dim=config.rotary_dim
-            if hasattr(config, "rotary_dim") else head_dim,
-            max_position=max_position_embeddings,
-            base=int(rope_theta),
-            is_neox_style=True,
-            cache_dtype=torch.float32,
-        )
-
         norm_kwargs = {}
         if hasattr(config, "rms_norm_eps"):
             norm_kwargs["eps"] = config.rms_norm_eps
@@ -1043,12 +988,11 @@ def forward(self,
         attn_metadata = forward_context.attn_metadata
         if not envs.VLLM_USE_V1 and attn_metadata is None:
             return None
-        if "request_ids_to_seq_ids" not in kwargs:
-            kwargs["request_ids_to_seq_ids"] = {}
-        if "finished_requests_ids" not in kwargs:
-            kwargs["finished_requests_ids"] = []
-
         if not envs.VLLM_USE_V1:
+            if "request_ids_to_seq_ids" not in kwargs:
+                kwargs["request_ids_to_seq_ids"] = {}
+            if "finished_requests_ids" not in kwargs:
+                kwargs["finished_requests_ids"] = []
             (
                 minimax_cache_tensors,
                 state_indices_tensor,
@@ -1077,16 +1021,6 @@ def forward(self,

         for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
-            if attn_metadata is not None:
-                # TODO (tdoublep): this whole thing with the rotary_emb is
-                # weird. we shouldn't be passing it via attn_metadata imo.
-                if envs.VLLM_USE_V1:
-                    if isinstance(layer.self_attn, MiniMaxText01Attention):
-                        attn_metadata[layer.prefix +
-                                      ".attn"].rotary_emb = self.rotary_emb
-                else:
-                    attn_metadata.rotary_emb = self.rotary_emb
-
             _caches = None
             if not envs.VLLM_USE_V1 and isinstance(
                     layer.self_attn, MiniMaxText01LinearAttention):
@@ -1120,7 +1054,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:

         super().__init__()
         config = vllm_config.model_config.hf_config
-        quant_config = vllm_config.quant_config
         lora_config = vllm_config.lora_config
         self.config = config
         self.lora_config = lora_config
@@ -1133,13 +1066,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         self.unpadded_vocab_size = self.config.vocab_size
         if hasattr(vllm_config.model_config, "max_model_len"):
             self.config.max_model_len = vllm_config.model_config.max_model_len
-        self.model = MiniMaxText01Model(
-            self.config,
-            model_config=vllm_config.model_config,
-            cache_config=vllm_config.cache_config,
-            quant_config=quant_config,
-            scheduler_config=vllm_config.scheduler_config,
-            prefix=maybe_prefix(prefix, "model"))
+        self.model = MiniMaxText01Model(vllm_config=vllm_config,
+                                        prefix=maybe_prefix(prefix, "model"))
         if get_pp_group().is_last_rank:
             self.lm_head = ParallelLMHead(
                 self.unpadded_vocab_size,
@@ -1469,3 +1397,35 @@ def get_mamba_state_shape_from_config(
             tp_size=parallel_config.tensor_parallel_size,
             head_dim=hf_config.head_dim,
         )
+
+
+def linear_attention(
+    hidden_states: torch.Tensor,
+    output: torch.Tensor,
+    positions: torch.Tensor,
+    layer_name: str,
+) -> None:
+    forward_context: ForwardContext = get_forward_context()
+    self = forward_context.no_compile_layers[layer_name]
+    self._forward(hidden_states=hidden_states,
+                  output=output,
+                  positions=positions,
+                  kv_caches=None)
+
+
+def linear_attention_fake(
+    hidden_states: torch.Tensor,
+    output: torch.Tensor,
+    positions: torch.Tensor,
+    layer_name: str,
+) -> None:
+    return
+
+
+direct_register_custom_op(
+    op_name="linear_attention",
+    op_func=linear_attention,
+    mutates_args=["output"],
+    fake_impl=linear_attention_fake,
+    dispatch_key=current_platform.dispatch_key,
+)
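
With the op registered and the model decorated, MiniMax-Text checkpoints can run on V1 with compilation and piecewise CUDA graphs instead of requiring eager mode. A hedged smoke-test sketch; the checkpoint id and flags are assumptions, not part of this commit:

# Smoke-test sketch: serve a MiniMax-Text model with compilation and CUDA
# graph capture enabled. The checkpoint id below is a placeholder.
from vllm import LLM, SamplingParams

llm = LLM(
    model="MiniMaxAI/MiniMax-Text-01",  # placeholder checkpoint
    trust_remote_code=True,
    enforce_eager=False,                # allow piecewise CUDA graph capture
)
out = llm.generate(["Hello"], SamplingParams(max_tokens=8))
print(out[0].outputs[0].text)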
