@@ -826,7 +826,8 @@ def _prepare_inputs(
         # Prepare encoder attention metadata separately
         # (encoder layers are not in KV cache groups)
         if self.is_encoder_only_model:
-            common_attn_metadata, encoder_attn_metadata = \
+
+            per_layer_metadata = \
                 self._build_encoder_only_attn_metadata(
                     scheduler_output)
 
@@ -835,6 +836,8 @@ def _prepare_inputs(
                 self.vllm_config, Attention)
             for layer_name, attn_module in attention_layers.items():
                 if attn_module.attn_type == AttentionType.ENCODER_ONLY:
+                    common_attn_metadata, encoder_attn_metadata = \
+                        per_layer_metadata[layer_name]
                     attn_metadata[layer_name] = encoder_attn_metadata
 
         # Prepare the attention metadata for each KV cache group and make layers
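
The two hunks above change `_prepare_inputs` so that each encoder-only layer looks up its own metadata by layer name instead of sharing a single tuple, which is what lets full-attention and sliding-window encoder layers coexist. A minimal, self-contained sketch of that lookup pattern, using invented stand-in names (`_FakeCommonMetadata`, `assign_encoder_metadata`) rather than vLLM's actual classes:

```python
from typing import Any, NamedTuple


class _FakeCommonMetadata(NamedTuple):
    """Stand-in for CommonAttentionMetadata; illustrative only."""
    num_reqs: int
    causal: bool


def assign_encoder_metadata(
    layer_types: dict[str, str],
    per_layer_metadata: dict[str, tuple[_FakeCommonMetadata, Any]],
) -> dict[str, Any]:
    """Toy version of the per-layer lookup added to _prepare_inputs."""
    attn_metadata: dict[str, Any] = {}
    for layer_name, attn_type in layer_types.items():
        if attn_type == "ENCODER_ONLY":
            # Each encoder layer fetches the metadata built for its own
            # attention group instead of one tuple shared by all layers.
            _common, encoder_attn_metadata = per_layer_metadata[layer_name]
            attn_metadata[layer_name] = encoder_attn_metadata
    return attn_metadata
```

The diff's `per_layer_metadata` plays the role of the second argument here.
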
@@ -2686,30 +2689,41 @@ def create_attn_groups(
         # Check if model is encoder-only
         block_size = self.vllm_config.cache_config.block_size
         use_mla = self.vllm_config.model_config.use_mla
-        attn_specs = list[AttentionSpec]()
-        for attn_module in attn_layers.values():
+        attn_specs: dict[AttentionSpec, list[str]] = defaultdict(list)
+        for layer_name, attn_module in attn_layers.items():
 
             if attn_module.attn_type == AttentionType.ENCODER_ONLY:
-                assert attn_module.sliding_window is None, "Sliding "
-                "window attention is not supported for encoder-only models"
-
-                attn_specs.append(
-                    FullAttentionSpec(block_size=block_size,
-                                      num_kv_heads=attn_module.num_kv_heads,
-                                      head_size=attn_module.head_size,
-                                      dtype=self.kv_cache_dtype,
-                                      use_mla=use_mla))
+                if attn_module.sliding_window is None:
+                    attn_spec: AttentionSpec = FullAttentionSpec(
+                        block_size=block_size,
+                        num_kv_heads=attn_module.num_kv_heads,
+                        head_size=attn_module.head_size,
+                        dtype=self.kv_cache_dtype,
+                        use_mla=use_mla)
+                else:
+                    attn_spec = SlidingWindowSpec(
+                        block_size=block_size,
+                        num_kv_heads=attn_module.num_kv_heads,
+                        head_size=attn_module.head_size,
+                        dtype=self.kv_cache_dtype,
+                        sliding_window=attn_module.sliding_window,
+                        use_mla=use_mla)
+                attn_specs[attn_spec].append(layer_name)
+
             else:
                 raise ValueError("Expected only encoder-only layers")
 
         if len(attn_specs) > 0:
-            assert len(attn_specs) == len(attn_layers), \
-                "All or none of the layers are expected to be encoder-only"
+            total_layers = 0
+            for attn_spec, layer_names in attn_specs.items():
 
-            attn_backends = get_attn_backends_for_layers(attn_layers.keys())
+                attn_backends = get_attn_backends_for_layers(layer_names)
+                total_layers += len(layer_names)
 
-            self.attn_groups.append(
-                create_attn_groups(attn_backends, attn_specs[0]))
+                self.attn_groups.append(
+                    create_attn_groups(attn_backends, attn_spec))
+            assert total_layers == len(attn_layers), \
+                "All or none of the layers are expected to be encoder-only"
             self.is_encoder_only_model = True
 
     def calculate_reorder_batch_threshold(self) -> None:
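
The grouping above only works if `AttentionSpec` instances hash and compare by value, so that layers with identical specs collapse into one `defaultdict` bucket and each distinct spec becomes its own attention group. A standalone sketch of that pattern under the assumption of a hashable (frozen) spec dataclass; `ToySpec` and the layer names are invented for illustration, not vLLM's actual classes:

```python
from collections import defaultdict
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class ToySpec:
    """Hashable stand-in for FullAttentionSpec / SlidingWindowSpec."""
    block_size: int
    num_kv_heads: int
    head_size: int
    sliding_window: Optional[int] = None


def group_layers_by_spec(
        layer_specs: dict[str, ToySpec]) -> dict[ToySpec, list[str]]:
    """Group layer names by identical spec, one attention group per bucket."""
    groups: dict[ToySpec, list[str]] = defaultdict(list)
    for layer_name, spec in layer_specs.items():
        groups[spec].append(layer_name)
    return groups


# Two full-attention layers share one group; the sliding-window layer
# lands in its own group, mirroring the assert on the total layer count.
layers = {
    "layers.0.attn": ToySpec(16, 8, 64),
    "layers.1.attn": ToySpec(16, 8, 64),
    "layers.2.attn": ToySpec(16, 8, 64, sliding_window=128),
}
grouped = group_layers_by_spec(layers)
assert len(grouped) == 2
assert sum(len(names) for names in grouped.values()) == len(layers)
```
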
@@ -3074,7 +3088,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
 
     def _build_encoder_only_attn_metadata(
             self, scheduler_output: "SchedulerOutput") -> \
-                tuple[CommonAttentionMetadata, Any]:
+                dict[str, tuple[CommonAttentionMetadata, Any]]:
         """Prepare encoder attention metadata for encoder-only models.
 
         Args:
@@ -3091,33 +3105,45 @@ def _build_encoder_only_attn_metadata(
         tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
         max_num_scheduled_tokens = max(tokens)
 
-        # Use the first attention metadata builder
-        # to create encoder attention metadata
-        builder = self.attn_groups[0][0].metadata_builder
-
         dummy_block_table = torch.zeros((num_reqs, 1),
                                         dtype=torch.int32,
                                         device=self.device)
         dummy_slot_mapping = torch.zeros((total_num_scheduled_tokens, ),
                                          dtype=torch.int32,
                                          device=self.device)
 
-        common_metadata = CommonAttentionMetadata(
-            query_start_loc=self.query_start_loc[:num_reqs + 1],
-            query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
-            seq_lens=self.seq_lens[:num_reqs],
-            seq_lens_cpu=self.seq_lens_cpu[:num_reqs],
-            num_computed_tokens_cpu=self.input_batch.
-            num_computed_tokens_cpu_tensor[:num_reqs],
-            num_reqs=num_reqs,
-            num_actual_tokens=total_num_scheduled_tokens,
-            max_query_len=max_num_scheduled_tokens,
-            block_table_tensor=dummy_block_table,
-            slot_mapping=dummy_slot_mapping,
-            causal=False,
-        )
+        group_metadata = dict[str, tuple[CommonAttentionMetadata, Any]]()
 
-        return common_metadata, builder.build(
-            common_prefix_len=0,  # No cascade for encoder
-            common_attn_metadata=common_metadata,
-        )
+        for attn_group_list in self.attn_groups:
+
+            assert len(attn_group_list) == 1
+            attn_group = attn_group_list[0]
+
+            # Use this group's metadata builder to create
+            # the encoder attention metadata
+            builder = attn_group.metadata_builder
+
+            common_metadata = CommonAttentionMetadata(
+                query_start_loc=self.query_start_loc[:num_reqs + 1],
+                query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
+                seq_lens=self.seq_lens[:num_reqs],
+                seq_lens_cpu=self.seq_lens_cpu[:num_reqs],
+                num_computed_tokens_cpu=self.input_batch.
+                num_computed_tokens_cpu_tensor[:num_reqs],
+                num_reqs=num_reqs,
+                num_actual_tokens=total_num_scheduled_tokens,
+                max_query_len=max_num_scheduled_tokens,
+                block_table_tensor=dummy_block_table,
+                slot_mapping=dummy_slot_mapping,
+                causal=False,
+            )
+
+            metadata = builder.build(
+                common_prefix_len=0,  # No cascade for encoder
+                common_attn_metadata=common_metadata,
+            )
+
+            for layer_name in attn_group.layer_names:
+                group_metadata[layer_name] = (common_metadata, metadata)
+
+        return group_metadata
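
Taken together with the `_prepare_inputs` hunks, the new flow is: build one metadata object per attention group, fan it out to every layer in that group, and let callers index the result by layer name. A hedged, self-contained toy of that producer side; the builder, group, and layer names here are invented for illustration, not vLLM's API:

```python
from typing import Any


class ToyBuilder:
    """Stand-in for a per-group attention metadata builder."""

    def __init__(self, tag: str) -> None:
        self.tag = tag

    def build(self, common: dict[str, Any]) -> dict[str, Any]:
        # A real builder would derive backend-specific metadata here.
        return {"backend": self.tag, **common}


def build_per_layer_metadata(
    groups: list[tuple[ToyBuilder, list[str]]],
    common: dict[str, Any],
) -> dict[str, tuple[dict[str, Any], dict[str, Any]]]:
    """Build metadata once per group, then map it onto each member layer."""
    per_layer: dict[str, tuple[dict[str, Any], dict[str, Any]]] = {}
    for builder, layer_names in groups:
        metadata = builder.build(common)
        for layer_name in layer_names:
            per_layer[layer_name] = (common, metadata)
    return per_layer


# One full-attention group and one sliding-window group.
groups = [
    (ToyBuilder("full"), ["layers.0.attn", "layers.1.attn"]),
    (ToyBuilder("sliding"), ["layers.2.attn"]),
]
per_layer = build_per_layer_metadata(groups, {"num_reqs": 4, "causal": False})
assert per_layer["layers.2.attn"][1]["backend"] == "sliding"
```
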