@@ -75,11 +75,8 @@ class PlaceholderAttentionMetadata(AttentionMetadata):
7575 # Maximum query length in the batch.
7676 max_query_len : Optional [int ]
7777
78- # Number of query tokens for each request in the batch.
79- # Currently, we require that all requests have the same number of query
80- # tokens during the decoding phase. When speculavie decoding is enabled,
81- # decode_query_len might be greater than 1. In all other cases, it is 1.
82- decode_query_len : Optional [int ]
78+ # Max number of query tokens among requests in the batch.
79+ max_decode_query_len : Optional [int ]
8380
8481 # Maximum sequence length among prefill batch. 0 if there are decoding
8582 # requests only.
@@ -140,7 +137,7 @@ def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
140137 slot_mapping = slot_mapping ,
141138 seq_lens = self .seq_lens [:self .num_prefills ],
142139 seq_lens_tensor = self .seq_lens_tensor [:self .num_prefills ],
143- decode_query_len = 0 ,
140+ max_decode_query_len = 0 ,
144141 max_query_len = self .max_query_len ,
145142 max_prefill_seq_len = self .max_prefill_seq_len ,
146143 max_decode_seq_len = 0 ,
@@ -172,7 +169,7 @@ def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
172169 slot_mapping = slot_mapping ,
173170 seq_lens = None ,
174171 seq_lens_tensor = self .seq_lens_tensor [self .num_prefills :],
175- decode_query_len = self .decode_query_len ,
172+ max_decode_query_len = self .max_decode_query_len ,
176173 max_query_len = None ,
177174 max_prefill_seq_len = 0 ,
178175 max_decode_seq_len = self .max_decode_seq_len ,
@@ -256,9 +253,9 @@ def build(self, seq_lens: List[int], query_lens: List[int],
256253 max_query_len = max (query_lens )
257254 decode_query_lens = query_lens [self .num_prefills :]
258255 if len (decode_query_lens ) > 0 :
259- decode_query_len = max (decode_query_lens )
256+ max_decode_query_len = max (decode_query_lens )
260257 else :
261- decode_query_len = 1
258+ max_decode_query_len = 1
262259 max_prefill_seq_len = max (self .prefill_seq_lens , default = 0 )
263260 max_decode_seq_len = max (self .curr_seq_lens , default = 0 )
264261 num_decode_tokens = self .num_decode_tokens
@@ -304,7 +301,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
304301 seq_lens = seq_lens ,
305302 seq_lens_tensor = seq_lens_tensor ,
306303 max_query_len = max_query_len ,
307- decode_query_len = decode_query_len ,
304+ max_decode_query_len = max_decode_query_len ,
308305 max_prefill_seq_len = max_prefill_seq_len ,
309306 max_decode_seq_len = max_decode_seq_len ,
310307 query_start_loc = query_start_loc ,
0 commit comments