@@ -827,7 +827,8 @@ def _prepare_inputs(
         # Prepare encoder attention metadata separately
         # (encoder layers are not in KV cache groups)
         if self.is_encoder_only_model:
-            common_attn_metadata, encoder_attn_metadata = \
+
+            per_layer_metadata = \
                 self._build_encoder_only_attn_metadata(
                     scheduler_output)
 
@@ -836,6 +837,8 @@ def _prepare_inputs(
                 self.vllm_config, Attention)
             for layer_name, attn_module in attention_layers.items():
                 if attn_module.attn_type == AttentionType.ENCODER_ONLY:
+                    common_attn_metadata, encoder_attn_metadata = \
+                        per_layer_metadata[layer_name]
                     attn_metadata[layer_name] = encoder_attn_metadata
 
         # Prepare the attention metadata for each KV cache group and make layers
@@ -2684,30 +2687,41 @@ def create_attn_groups(
         # Check if model is encoder-only
         block_size = self.vllm_config.cache_config.block_size
         use_mla = self.vllm_config.model_config.use_mla
-        attn_specs = list[AttentionSpec]()
-        for attn_module in attn_layers.values():
+        attn_specs: dict[AttentionSpec, list[str]] = defaultdict(list)
+        for layer_name, attn_module in attn_layers.items():
 
             if attn_module.attn_type == AttentionType.ENCODER_ONLY:
-                assert attn_module.sliding_window is None, "Sliding "
-                "window attention is not supported for encoder-only models"
-
-                attn_specs.append(
-                    FullAttentionSpec(block_size=block_size,
-                                      num_kv_heads=attn_module.num_kv_heads,
-                                      head_size=attn_module.head_size,
-                                      dtype=self.kv_cache_dtype,
-                                      use_mla=use_mla))
+                if attn_module.sliding_window is None:
+                    attn_spec: AttentionSpec = FullAttentionSpec(
+                        block_size=block_size,
+                        num_kv_heads=attn_module.num_kv_heads,
+                        head_size=attn_module.head_size,
+                        dtype=self.kv_cache_dtype,
+                        use_mla=use_mla)
+                else:
+                    attn_spec = SlidingWindowSpec(
+                        block_size=block_size,
+                        num_kv_heads=attn_module.num_kv_heads,
+                        head_size=attn_module.head_size,
+                        dtype=self.kv_cache_dtype,
+                        sliding_window=attn_module.sliding_window,
+                        use_mla=use_mla)
+                attn_specs[attn_spec].append(layer_name)
+
             else:
                 raise ValueError("Expected only encoder-only layers")
 
         if len(attn_specs) > 0:
-            assert len(attn_specs) == len(attn_layers), \
-                "All or none of the layers are expected to be encoder-only"
+            total_layers = 0
+            for attn_spec, layer_names in attn_specs.items():
 
-            attn_backends = get_attn_backends_for_layers(attn_layers.keys())
+                attn_backends = get_attn_backends_for_layers(layer_names)
+                total_layers += len(layer_names)
 
-            self.attn_groups.append(
-                create_attn_groups(attn_backends, attn_specs[0]))
+                self.attn_groups.append(
+                    create_attn_groups(attn_backends, attn_spec))
+            assert total_layers == len(attn_layers), \
+                "All or none of the layers are expected to be encoder-only"
             self.is_encoder_only_model = True
 
     def calculate_reorder_batch_threshold(self) -> None:
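
For readers unfamiliar with the grouping pattern introduced above, here is a minimal standalone sketch (not vLLM code): it assumes the spec objects are hashable, e.g. frozen dataclasses, so that layers sharing an identical spec collapse into a single dictionary bucket and one attention group is created per distinct spec. FrozenSpec and the layer names below are hypothetical placeholders.

from collections import defaultdict
from dataclasses import dataclass
from typing import Optional

# Hypothetical stand-in for a hashable attention spec.
@dataclass(frozen=True)
class FrozenSpec:
    block_size: int
    sliding_window: Optional[int]

# Hypothetical layer -> spec mapping (names are illustrative only).
layer_specs = {
    "encoder.layer.0": FrozenSpec(block_size=16, sliding_window=None),
    "encoder.layer.1": FrozenSpec(block_size=16, sliding_window=128),
    "encoder.layer.2": FrozenSpec(block_size=16, sliding_window=None),
}

# Group layers that share an identical spec, mirroring defaultdict(list) above.
groups: dict[FrozenSpec, list[str]] = defaultdict(list)
for layer_name, spec in layer_specs.items():
    groups[spec].append(layer_name)

# Every layer lands in exactly one group, which is what the total_layers
# assertion above checks against len(attn_layers).
assert sum(len(names) for names in groups.values()) == len(layer_specs)
print(dict(groups))
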
@@ -3080,7 +3094,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
 
     def _build_encoder_only_attn_metadata(
             self, scheduler_output: "SchedulerOutput") -> \
-            tuple[CommonAttentionMetadata, Any]:
+            dict[str, tuple[CommonAttentionMetadata, Any]]:
         """Prepare encoder attention metadata for encoder-only models.
 
         Args:
@@ -3097,33 +3111,45 @@ def _build_encoder_only_attn_metadata(
         tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
         max_num_scheduled_tokens = max(tokens)
 
-        # Use the first attention metadata builder
-        # to create encoder attention metadata
-        builder = self.attn_groups[0][0].metadata_builder
-
         dummy_block_table = torch.zeros((num_reqs, 1),
                                         dtype=torch.int32,
                                         device=self.device)
         dummy_slot_mapping = torch.zeros((total_num_scheduled_tokens, ),
                                          dtype=torch.int32,
                                          device=self.device)
 
-        common_metadata = CommonAttentionMetadata(
-            query_start_loc=self.query_start_loc[:num_reqs + 1],
-            query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
-            seq_lens=self.seq_lens[:num_reqs],
-            seq_lens_cpu=self.seq_lens_cpu[:num_reqs],
-            num_computed_tokens_cpu=self.input_batch.
-            num_computed_tokens_cpu_tensor[:num_reqs],
-            num_reqs=num_reqs,
-            num_actual_tokens=total_num_scheduled_tokens,
-            max_query_len=max_num_scheduled_tokens,
-            block_table_tensor=dummy_block_table,
-            slot_mapping=dummy_slot_mapping,
-            causal=False,
-        )
+        group_metadata = dict[str, tuple[CommonAttentionMetadata, Any]]()
 
-        return common_metadata, builder.build(
-            common_prefix_len=0,  # No cascade for encoder
-            common_attn_metadata=common_metadata,
-        )
+        for attn_group_list in self.attn_groups:
+
+            assert len(attn_group_list) == 1
+            attn_group = attn_group_list[0]
+
+            # Use the first attention metadata builder
+            # to create encoder attention metadata
+            builder = attn_group.metadata_builder
+
+            common_metadata = CommonAttentionMetadata(
+                query_start_loc=self.query_start_loc[:num_reqs + 1],
+                query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
+                seq_lens=self.seq_lens[:num_reqs],
+                seq_lens_cpu=self.seq_lens_cpu[:num_reqs],
+                num_computed_tokens_cpu=self.input_batch.
+                num_computed_tokens_cpu_tensor[:num_reqs],
+                num_reqs=num_reqs,
+                num_actual_tokens=total_num_scheduled_tokens,
+                max_query_len=max_num_scheduled_tokens,
+                block_table_tensor=dummy_block_table,
+                slot_mapping=dummy_slot_mapping,
+                causal=False,
+            )
+
+            metadata = builder.build(
+                common_prefix_len=0,  # No cascade for encoder
+                common_attn_metadata=common_metadata,
+            )
+
+            for layer_name in attn_group.layer_names:
+                group_metadata[layer_name] = (common_metadata, metadata)
+
+        return group_metadata
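
To make the shape of the new return value concrete, here is a minimal sketch (hypothetical names, not the vLLM interfaces) of the build-once-per-group, look-up-per-layer pattern that _build_encoder_only_attn_metadata now follows; the caller in _prepare_inputs then does a plain dict lookup per layer name.

from typing import Any

def fan_out_group_metadata(
        groups: dict[str, list[str]]) -> dict[str, tuple[str, Any]]:
    """Hypothetical sketch: build one metadata object per attention group and
    map every layer name in that group to it, mirroring the loop above."""
    per_layer: dict[str, tuple[str, Any]] = {}
    for group_id, layer_names in groups.items():
        common = f"common-metadata-{group_id}"  # stand-in for CommonAttentionMetadata(...)
        built = f"built-metadata-{group_id}"    # stand-in for builder.build(...)
        for layer_name in layer_names:
            # Every layer reuses the result built for its group.
            per_layer[layer_name] = (common, built)
    return per_layer

# Example: a full-attention group and a sliding-window group fan out to three layers.
print(fan_out_group_metadata({
    "full_attn": ["encoder.layer.0", "encoder.layer.2"],
    "sliding_window": ["encoder.layer.1"],
}))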