@@ -1006,6 +1006,21 @@ def build_active_slices(self, batch_size: int):
         graph_scratch_space = torch.cumsum(self.active_request_query_lengths[:batch_size], dim=0)
         self.active_request_last_token_idxs[:batch_size].copy_(graph_scratch_space - 1)
 
+    def pad_active_slices(self):
+        """Pad the active slices of specific tensors."""
+        # Some tensors need to be padded at the token level.
+        padding_token_slice = slice(self.active_token_count, self.padded_active_token_count)
+
+        self.token_to_block_idx[padding_token_slice] = self.kv_block_allocator.dummy_block_idx
+        self.token_to_local_position_within_kv_block[padding_token_slice] = 0
+        self.token_to_position_in_request[padding_token_slice] = 0
+
+        # Other tensors need to be padded at the request level.
+        padding_request_slice = slice(
+            self.total_request_count - self.paused_request_count,
+            self.padded_active_request_count,
+        )
+
     def append_key_value_cache(self, layer_number: int, key: Tensor, value: Tensor) -> None:
         """Append to KV cache.
 
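The token-level padding that the new `pad_active_slices` helper performs can be illustrated in isolation. A minimal, runnable sketch with hypothetical sizes and a stand-in `dummy_block_idx` (in the real code these come from `kv_block_allocator` and the padded batch dimensions):

```python
import torch

# Hypothetical sizes: 5 real tokens, padded up to a graph-friendly 8.
active_token_count, padded_active_token_count = 5, 8
dummy_block_idx = 0  # stand-in for kv_block_allocator.dummy_block_idx

token_to_block_idx = torch.randint(1, 10, (padded_active_token_count,))
token_to_position_in_request = torch.arange(padded_active_token_count)

# Same pattern as pad_active_slices: point the padding tokens at the
# dummy KV block and zero their positions, so the padding slots hold
# valid, harmless values rather than whatever was left there before.
padding_token_slice = slice(active_token_count, padded_active_token_count)
token_to_block_idx[padding_token_slice] = dummy_block_idx
token_to_position_in_request[padding_token_slice] = 0
```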
@@ -1620,23 +1635,6 @@ def initialize_attention_state(
             prefill_req_count=padded_prefill_req_count,
             decode_req_count=padded_decode_req_count,
         )
-        self.padded_active_token_count = self.padded_batch_dimensions.token_count
-        self.padded_active_request_count = self.padded_batch_dimensions.req_count
-        self.padding_slice = slice(self.active_token_count, self.padded_active_token_count)
-
-        self.build_active_slices(self.padded_active_request_count)
-        batch_size = self.total_request_count - self.paused_request_count
-
-        # Update token position indexes.
-        self.token_to_block_idx[self.active_token_count : self.padded_active_token_count] = (
-            self.kv_block_allocator.dummy_block_idx
-        )
-        self.token_to_local_position_within_kv_block[
-            self.active_token_count : self.padded_active_token_count
-        ] = 0
-        self.token_to_position_in_request[
-            self.active_token_count : self.padded_active_token_count
-        ] = 0
 
         self.active_attn_metadata = (
             self.graph_attn_metadata  # type: ignore[assignment]
@@ -1657,6 +1655,14 @@ def initialize_attention_state(
             decode_req_count=adjusted_decode_req_count,
         )
 
+        self.padded_active_token_count = self.padded_batch_dimensions.token_count
+        self.padded_active_request_count = self.padded_batch_dimensions.req_count
+        self.padding_slice = slice(self.active_token_count, self.padded_active_token_count)
+
+        self.build_active_slices(self.padded_active_request_count)
+        self.pad_active_slices()
+
+        batch_size = self.total_request_count - self.paused_request_count
         assert self.active_attn_metadata is not None
         self.active_attn_metadata["mha_metadata"].update(
             request_query_lengths=self.active_request_query_lengths[:batch_size],
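Note that the block removed in the second hunk is not dropped: it reappears in the third hunk, below the code that computes the adjusted request counts, with the inline token padding replaced by a call to the new `pad_active_slices` helper. A condensed, runnable sketch of the resulting control flow; the class, branch bodies, and numbers here are hypothetical stand-ins, not the verbatim method:

```python
from types import SimpleNamespace

class AttentionStateSketch:
    """Hypothetical condensed model of the reordered initialize_attention_state."""

    def __init__(self, use_graph: bool):
        self.use_graph = use_graph

    def initialize_attention_state(self):
        if self.use_graph:
            # Graph path: counts rounded up to a captured CUDA-graph size.
            self.padded_batch_dimensions = SimpleNamespace(token_count=8, req_count=4)
        else:
            # Eager path: counts may still be adjusted in this branch.
            self.padded_batch_dimensions = SimpleNamespace(token_count=6, req_count=3)

        # Padding now runs once, after both branches have settled the final
        # padded dimensions, instead of before the second branch executes.
        self.padded_active_token_count = self.padded_batch_dimensions.token_count
        self.padded_active_request_count = self.padded_batch_dimensions.req_count
        self.build_active_slices(self.padded_active_request_count)
        self.pad_active_slices()

    def build_active_slices(self, batch_size: int) -> None:  # stub
        print(f"build_active_slices(batch_size={batch_size})")

    def pad_active_slices(self) -> None:  # stub
        print("pad_active_slices()")

AttentionStateSketch(use_graph=True).initialize_attention_state()
```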