# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Inference-only MiniMaxText01 model."""
- import copy
import math
from collections.abc import Iterable
from typing import TYPE_CHECKING, Optional, Union
from vllm import envs
from vllm.attention import Attention, AttentionMetadata
+ from vllm.compilation.decorators import support_torch_compile
from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
                         get_current_vllm_config)
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (
    get_pp_group, get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size)
- from vllm.forward_context import get_forward_context
+ from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
    MambaStateDtypeCalculator, MambaStateShapeCalculator)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
+ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import maybe_prefix
from vllm.model_executor.sampling_metadata import SamplingMetadata
+ from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
+ from vllm.utils import direct_register_custom_op
from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata

from .interfaces import HasInnerState, IsHybrid
@@ -143,61 +146,6 @@ def forward(
        return self._forward(x)


- class MiniMaxText01RotaryEmbedding(CustomOp):
-     name = "MiniMaxText01RotaryEmbedding"
-
-     def __init__(
-         self,
-         head_size: int,
-         rotary_dim: int,
-         max_position: int,
-         base: float,
-         is_neox_style: bool,
-         cache_dtype: torch.dtype,
-     ) -> None:
-         super().__init__()
-         self.head_size = head_size
-         self.rotary_dim = rotary_dim
-         self.max_position_embeddings = max_position
-         self.base = base
-         self.is_neox_style = is_neox_style
-         self.cache_dtype = cache_dtype
-         cache = self._compute_cos_sin_cache().to(cache_dtype)
-         self.register_buffer("cos_sin_cache", cache, persistent=False)
-
-     def _compute_inv_freq(self, base: float) -> torch.Tensor:
-         """Compute the inverse frequency."""
-         inv_freq = 1.0 / (base**(torch.arange(
-             0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
-         return inv_freq
-
-     def _compute_cos_sin_cache(self) -> torch.Tensor:
-         """Compute the cos and sin cache."""
-         inv_freq = self._compute_inv_freq(self.base)
-         t = torch.arange(self.max_position_embeddings, dtype=torch.float)
-         freqs = torch.einsum("i,j -> ij", t, inv_freq)
-         cos = freqs.cos()
-         sin = freqs.sin()
-         cache = torch.cat((cos, sin), dim=-1)
-         return cache
-
-     def forward(
-         self,
-         positions: torch.Tensor,
-         query: torch.Tensor,
-         key: torch.Tensor,
-     ) -> tuple[torch.Tensor, torch.Tensor]:
-         from vllm import _custom_ops as ops
-         self.cos_sin_cache = self.cos_sin_cache.to(positions.device)
-         query_cast = query.to(self.cache_dtype)
-         key_cast = key.to(self.cache_dtype)
-         ops.rotary_embedding(positions, query_cast, key_cast, self.head_size,
-                              self.cos_sin_cache, self.is_neox_style)
-         query = query_cast.to(query.dtype)
-         key = key_cast.to(key.dtype)
-         return query, key
-
-
class MiniMaxText01MLP(nn.Module):

    def __init__(
@@ -526,20 +474,40 @@ def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor,
                                            slot_id, 32)
        return hidden

-     def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor,
-                 kv_caches: MinimaxCacheParams, **kwargs) -> torch.Tensor:
-         qkv, _ = self.qkv_proj(hidden_states)
+     def forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
+                 positions: torch.Tensor,
+                 kv_caches: MinimaxCacheParams) -> None:
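+         # V0 runs the linear attention inline; V1 dispatches through the
+         # torch.ops.vllm.linear_attention custom op registered at the bottom
+         # of this file, so the in-place write to `output` stays opaque to
+         # torch.compile.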
+         if not envs.VLLM_USE_V1:
+             self._forward(hidden_states, output, positions, kv_caches)
+         else:
+             torch.ops.vllm.linear_attention(
+                 hidden_states,
+                 output,
+                 positions,
+                 self.prefix,
+             )
+
+     def _forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
+                  positions: torch.Tensor,
+                  kv_caches: Optional[MinimaxCacheParams]) -> None:
+         forward_context = get_forward_context()
+         attn_metadata: AttentionMetadata = forward_context.attn_metadata
+         if envs.VLLM_USE_V1 and attn_metadata is not None:
+             assert isinstance(attn_metadata, dict)
+             attn_metadata = attn_metadata[self.prefix]
+             assert isinstance(attn_metadata, LinearAttentionMetadata)
+             num_actual_tokens = attn_metadata.num_prefill_tokens + \
+                 attn_metadata.num_decode_tokens
+         else:
+             num_actual_tokens = hidden_states.shape[0]
+
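+         # Only the first num_actual_tokens rows are real tokens; anything
+         # beyond that (e.g. padding added for graph capture) is skipped.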
+         qkv, _ = self.qkv_proj(hidden_states[:num_actual_tokens])
        qkv32 = qkv.to(torch.float32)
        qkvact = torch.nn.functional.silu(qkv32)
        qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1))
        q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1)
-         forward_context = get_forward_context()
-         attn_metadata = forward_context.attn_metadata
        if envs.VLLM_USE_V1:
            if attn_metadata is not None:
-                 assert isinstance(attn_metadata, dict)
-                 attn_metadata = attn_metadata[self.prefix]
-                 assert isinstance(attn_metadata, LinearAttentionMetadata)
                kv_cache = self.kv_cache[forward_context.virtual_engine][0]
                state_indices_tensor = attn_metadata.state_indices_tensor
@@ -578,13 +546,11 @@ def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor,
                hidden = self._decode_infer(q, k, v, kv_cache,
                                            state_indices_tensor,
                                            attn_metadata)
-
        hidden = self.norm._forward(hidden)
-         gate, _ = self.output_gate(hidden_states)
+         gate, _ = self.output_gate(hidden_states[:num_actual_tokens])
        hidden = F.sigmoid(gate) * hidden
        hidden = hidden.to(hidden_states.dtype)
-         hidden, _ = self.out_proj(hidden)
-         return hidden
+         output[:num_actual_tokens], _ = self.out_proj(hidden)


class MiniMaxText01Attention(nn.Module):
@@ -652,23 +618,23 @@ def __init__(
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )
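+         # The rotary embedding is now built per layer with get_rope() rather
+         # than being attached to attn_metadata by the model.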
+         self.rotary_emb = get_rope(
+             head_size=self.head_dim,
+             rotary_dim=rotary_dim,
+             max_position=max_position,
+             base=int(rope_theta),
+             is_neox_style=True,
+             dtype=torch.float32,
+         )
        return

-     def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor,
-                 **kwargs) -> torch.Tensor:
-         forward_context = get_forward_context()
-         attn_metadata = forward_context.attn_metadata
+     def forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
+                 positions: torch.Tensor, **kwargs) -> None:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
-         if envs.VLLM_USE_V1:
-             if attn_metadata is not None:
-                 q, k = attn_metadata[f"{self.prefix}.attn"].rotary_emb(
-                     positions, q, k)
-         else:
-             q, k = attn_metadata.rotary_emb(positions, q, k)
+         q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
-         output, _ = self.o_proj(attn_output)
-         return output
+         output[:], _ = self.o_proj(attn_output)


class MiniMaxText01DecoderLayer(nn.Module):
@@ -816,16 +782,15 @@ def forward(self,
                is_warmup: bool = False,
                **kwargs) -> tuple[torch.Tensor, torch.Tensor]:

-         forward_context = get_forward_context()
-         attn_metadata = forward_context.attn_metadata
        layernorm_input = hidden_states
        layernorm_output = self.input_layernorm(layernorm_input)
        residual = layernorm_output if self.postnorm else layernorm_input
-         self_attention_output = self.self_attn(
+         self_attention_output = torch.empty_like(layernorm_output)
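+         # The attention layers now fill a preallocated output buffer in place
+         # instead of returning a new tensor.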
+         self.self_attn(
            hidden_states=layernorm_output,
+             output=self_attention_output,
            positions=positions,
            kv_caches=kv_caches,
-             attn_metadata=attn_metadata,
        )

        residual = residual * self.layernorm_attention_alpha
@@ -839,8 +804,8 @@ def forward(self,
        if self.expert_num == 1:
            hidden_states = self.mlp(layernorm_output)
        else:
-             moe_hidden_states = self.block_sparse_moe(
-                 copy.deepcopy(layernorm_output))
+             moe_layernorm_output = layernorm_output.clone()
+             moe_hidden_states = self.block_sparse_moe(moe_layernorm_output)
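+             # A tensor clone() replaces the earlier copy.deepcopy(): the MoE
+             # works on its own buffer while layernorm_output stays intact for
+             # the shared-expert branch below.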
            if self.shared_moe:
                before_moe_dtype = layernorm_output.dtype
                moe_hidden_fp32 = moe_hidden_states.to(torch.float32)
@@ -878,18 +843,16 @@ def shared_moe_coefficient_loader(param: torch.Tensor,
    return


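+ # @support_torch_compile expects the standard vLLM module constructor
+ # (vllm_config plus prefix), which is why __init__ below unpacks its
+ # sub-configs from vllm_config.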
+ @support_torch_compile
class MiniMaxText01Model(nn.Module):

-     def __init__(
-         self,
-         config: MiniMaxConfig,
-         model_config: Optional[ModelConfig] = None,
-         quant_config: Optional[QuantizationConfig] = None,
-         cache_config: Optional[CacheConfig] = None,
-         scheduler_config=None,
-         prefix: str = "",
-     ) -> None:
+     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
+         config: MiniMaxConfig = vllm_config.model_config.hf_config
+         model_config = vllm_config.model_config
+         quant_config = vllm_config.quant_config
+         cache_config = vllm_config.cache_config
+         scheduler_config = vllm_config.scheduler_config

        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
@@ -976,24 +939,6 @@ def layer_fn(prefix):
        self.minimax_cache = MinimaxCacheManager(
            dtype=torch.float32, cache_shape=self.cache_shape)

-         rope_theta = getattr(config, "rope_theta", 10000)
-         head_dim = getattr(config, "head_dim", None)
-         if head_dim is None:
-             head_dim = config.hidden_size // config.num_attention_heads
-         if hasattr(config, "max_model_len") and isinstance(
-                 config.max_model_len, int):
-             max_position_embeddings = min(config.max_position_embeddings,
-                                           config.max_model_len)
-         self.rotary_emb = MiniMaxText01RotaryEmbedding(
-             head_dim,
-             rotary_dim=config.rotary_dim
-             if hasattr(config, "rotary_dim") else head_dim,
-             max_position=max_position_embeddings,
-             base=int(rope_theta),
-             is_neox_style=True,
-             cache_dtype=torch.float32,
-         )
-
        norm_kwargs = {}
        if hasattr(config, "rms_norm_eps"):
            norm_kwargs["eps"] = config.rms_norm_eps
@@ -1043,12 +988,11 @@ def forward(self,
        attn_metadata = forward_context.attn_metadata
        if not envs.VLLM_USE_V1 and attn_metadata is None:
            return None
-         if "request_ids_to_seq_ids" not in kwargs:
-             kwargs["request_ids_to_seq_ids"] = {}
-         if "finished_requests_ids" not in kwargs:
-             kwargs["finished_requests_ids"] = []
-
        if not envs.VLLM_USE_V1:
+             if "request_ids_to_seq_ids" not in kwargs:
+                 kwargs["request_ids_to_seq_ids"] = {}
+             if "finished_requests_ids" not in kwargs:
+                 kwargs["finished_requests_ids"] = []
            (
                minimax_cache_tensors,
                state_indices_tensor,
@@ -1077,16 +1021,6 @@ def forward(self,
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
-             if attn_metadata is not None:
-                 # TODO (tdoublep): this whole thing with the rotary_emb is
-                 # weird. we shouldn't be passing it via attn_metadata imo.
-                 if envs.VLLM_USE_V1:
-                     if isinstance(layer.self_attn, MiniMaxText01Attention):
-                         attn_metadata[layer.prefix +
-                                       ".attn"].rotary_emb = self.rotary_emb
-                 else:
-                     attn_metadata.rotary_emb = self.rotary_emb
-
            _caches = None
            if not envs.VLLM_USE_V1 and isinstance(
                    layer.self_attn, MiniMaxText01LinearAttention):
@@ -1120,7 +1054,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()
        config = vllm_config.model_config.hf_config
-         quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        self.config = config
        self.lora_config = lora_config
@@ -1133,13 +1066,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        self.unpadded_vocab_size = self.config.vocab_size
        if hasattr(vllm_config.model_config, "max_model_len"):
            self.config.max_model_len = vllm_config.model_config.max_model_len
-         self.model = MiniMaxText01Model(
-             self.config,
-             model_config=vllm_config.model_config,
-             cache_config=vllm_config.cache_config,
-             quant_config=quant_config,
-             scheduler_config=vllm_config.scheduler_config,
-             prefix=maybe_prefix(prefix, "model"))
+         self.model = MiniMaxText01Model(vllm_config=vllm_config,
+                                         prefix=maybe_prefix(prefix, "model"))
        if get_pp_group().is_last_rank:
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
@@ -1469,3 +1397,35 @@ def get_mamba_state_shape_from_config(
            tp_size=parallel_config.tensor_parallel_size,
            head_dim=hf_config.head_dim,
        )
+
+
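+ # Module-level wrapper called through torch.ops.vllm.linear_attention: it
+ # looks the layer up by its prefix in the forward context and runs the eager
+ # _forward implementation.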
+ def linear_attention(
+         hidden_states: torch.Tensor,
+         output: torch.Tensor,
+         positions: torch.Tensor,
+         layer_name: str,
+ ) -> None:
+     forward_context: ForwardContext = get_forward_context()
+     self = forward_context.no_compile_layers[layer_name]
+     self._forward(hidden_states=hidden_states,
+                   output=output,
+                   positions=positions,
+                   kv_caches=None)
+
+
+ def linear_attention_fake(
+         hidden_states: torch.Tensor,
+         output: torch.Tensor,
+         positions: torch.Tensor,
+         layer_name: str,
+ ) -> None:
+     return
+
+
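+ # Register the wrapper as torch.ops.vllm.linear_attention. mutates_args marks
+ # the in-place write to `output`, and the fake impl keeps tracing a no-op.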
+ direct_register_custom_op(
+     op_name="linear_attention",
+     op_func=linear_attention,
+     mutates_args=["output"],
+     fake_impl=linear_attention_fake,
+     dispatch_key=current_platform.dispatch_key,
+ )