@@ -72,6 +72,8 @@ class SequenceInfo:
     - pages_per_seq: [ps_0, ps_1, ..., ps_{b-1}] where ps_i is the number of pages allocated for
       sequence i. Note that, for example, cache_loc[ps_0 : ps_0 + ps_1] will correspond to the
       pages associated with sequence 1 in the batch.
+    - slot_idx: [s_0, s_1, ..., s_{b-1}]
+      where s_i is the slot index of sequence i in the batch.
 
     ################################################################################################
 
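For concreteness, a small worked example (all values invented, not from the commit) of how these batch-layout fields relate to each other, including the new `slot_idx`:

```python
# Hypothetical two-sequence batch with page_size = 4:
#   sequence 0: 6 tokens -> 2 pages; sequence 1: 3 tokens -> 1 page
seq_len = [6, 3]
pages_per_seq = [2, 1]  # ps_0 = 2, ps_1 = 1
cache_loc = [0, 1, 2]   # cache_loc[0:2] -> seq 0, cache_loc[2:3] -> seq 1
slot_idx = [0, 1]       # s_i: slot occupied by sequence i

# pages for sequence 1 start right after the ps_0 pages of sequence 0
start = sum(pages_per_seq[:1])
assert cache_loc[start : start + pages_per_seq[1]] == [2]
```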
@@ -134,7 +136,8 @@ def __init__(
         self._num_pages = max(
             self.max_batch_size,
             (self.max_num_tokens) // self.page_size  # floored number of pages
-            + (self.max_num_tokens % self.page_size > 0) * self.max_batch_size,  # +1 per sequence
+            + (self.max_num_tokens / self.max_batch_size % self.page_size > 0)  # check for overflow
+            * self.max_batch_size,  # +1 page per sequence if overflow is required
         )
         # sanity check
         assert self.num_pages >= self.max_batch_size, "num_pages can't be less than max_batch_size"
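To sanity-check the revised formula, a quick standalone computation with made-up sizes (the old expression checked `max_num_tokens % page_size` once globally; the new one checks the per-sequence token count for a partial page):

```python
max_num_tokens, max_batch_size, page_size = 1000, 8, 64

num_pages = max(
    max_batch_size,
    max_num_tokens // page_size                          # 15 full pages
    + (max_num_tokens / max_batch_size % page_size > 0)  # 125 tokens/seq -> partial page
    * max_batch_size,                                    # one extra page per sequence
)
assert num_pages == 15 + 8  # 23 >= max_batch_size, so the sanity check holds
```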
@@ -164,6 +167,7 @@ def __init__(
             "input_pos": torch.empty(self.max_batch_size, dtype=torch.int),
             "cache_loc": torch.empty(self.num_pages, dtype=torch.int),
             "pages_per_seq": torch.empty(self.max_batch_size, dtype=torch.int),
+            "slot_idx": torch.empty(self.max_batch_size, dtype=torch.int),
             # OTHER FIELDS WHERE WE NEED EFFICIENT HOST<>DEVICE TRANSFER
             "_gather_idx": torch.empty(self.max_num_tokens, dtype=torch.int),
         }
@@ -172,7 +176,8 @@ def __init__(
         }
         # NOTE: order of keys is relevant here!
         self._uncached_arg_names = ("input_ids", "position_ids")
-        self._cached_arg_names = ("seq_len", "input_pos", "cache_loc", "pages_per_seq")
+        self._cached_arg_names = ("seq_len", "input_pos", "cache_loc", "pages_per_seq", "slot_idx")
+        self._cached_constants = ("page_size",)
         ############################################################################################
 
         # EXTRA TENSOR FIELDS ######################################################################
@@ -296,7 +301,7 @@ def const_args_for_prepare_metadata(self) -> Tuple:
         ``insert_cached_attention`` to extract the constant arguments and add them to the
         ``prepare_metadata`` node/op.
         """
-        return (self.page_size,)
+        return tuple(getattr(self, k) for k in self._cached_constants)
 
     @property
     def named_dynamic_shapes(self) -> Dict[str, Dict[str, Dim]]:
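The refactored property now derives its return value from `_cached_constants`, so registering a further constant only means extending that tuple. A minimal standalone sketch of the pattern (class name and attribute values are hypothetical):

```python
class _Sketch:
    def __init__(self) -> None:
        self.page_size = 64
        self._cached_constants = ("page_size",)  # extend to expose more constants

    @property
    def const_args_for_prepare_metadata(self) -> tuple:
        # behaves exactly like the old hard-coded `(self.page_size,)`
        return tuple(getattr(self, k) for k in self._cached_constants)

assert _Sketch().const_args_for_prepare_metadata == (64,)
```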
@@ -311,6 +316,7 @@ def named_dynamic_shapes(self) -> Dict[str, Dict[str, Dim]]:
         if self.max_batch_size > 1:
             bs_seq_len_shape[0] = Dim("batch_size", max=self.max_batch_size)
         bs_seq_len_shape[1] = Dim("seq_len", max=self.max_seq_len)
+        # bs_seq_len_shape[1] = Dim.AUTO
         self._dynamic_shapes = {k: bs_seq_len_shape for k in self._uncached_arg_names}
         # cached args are static
         self._dynamic_shapes.update({k: {} for k in self._cached_arg_names})
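For context, `Dim` here appears to be `torch.export.Dim`. A minimal sketch (bounds are hypothetical) of the dynamic-shape spec this method builds, where only the uncached args get dynamic dimensions and the cached args, including the new `slot_idx`, stay static:

```python
from torch.export import Dim

max_batch_size, max_seq_len = 8, 1024

bs_seq_len_shape = {}
if max_batch_size > 1:
    bs_seq_len_shape[0] = Dim("batch_size", max=max_batch_size)
bs_seq_len_shape[1] = Dim("seq_len", max=max_seq_len)

# uncached args are dynamic; cached args get an empty (fully static) spec
dynamic_shapes = {k: bs_seq_len_shape for k in ("input_ids", "position_ids")}
dynamic_shapes.update(
    {k: {} for k in ("seq_len", "input_pos", "cache_loc", "pages_per_seq", "slot_idx")}
)
```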
@@ -522,11 +528,15 @@ def set_example_sequence(
         cache_loc = list(range(sum(pages_per_seq)))
         page_assignments = self._get_page_assignments(cache_loc, pages_per_seq)
 
+        # vanilla slot indices
+        slot_idx = list(range(len(input_ids)))
+
         self.nest_sequences(
             input_ids,
             position_ids,  # will be auto-inferred if None
             input_pos=0,  # no cache history
             page_assignments=page_assignments,  # vanilla page assignments
+            slot_idx=slot_idx,  # vanilla slot indices
             **extra_args,
         )
 
@@ -613,6 +623,7 @@ def nest_sequences(
         position_ids: Optional[Sequence[Sequence[int]]] = None,
         input_pos: Optional[Union[Sequence[int], int]] = None,
         page_assignments: Optional[Sequence[Sequence[int]]] = None,
+        slot_idx: Optional[Sequence[int]] = None,
         **extra_args: Dict[str, Union[torch.Tensor, Sequence[torch.Tensor]]],
     ) -> None:
         """Create and store sequence information for the next forward pass.
@@ -622,6 +633,7 @@ def nest_sequences(
             position_ids: List of sequences of position_ids for each token.
             input_pos: Absolute starting position in the cache for each sequence.
             page_assignments: List of sequences of page assignments for each sequence.
+            slot_idx: List of slot indices for each sequence.
             extra_args: Extra arguments to be stored in the interface.
 
         This i/f will ensure that all sequence info args are updated accordingly.
@@ -648,6 +660,10 @@ def nest_sequences(
             self._store_arg("cache_loc", cache_loc, reset=True)
             self._store_arg("pages_per_seq", pages_per_seq, reset=True)
 
+        # check for updated slot_idx
+        if slot_idx is not None:
+            self._store_arg("slot_idx", slot_idx)
+
         ### UPDATE MAIN INPUTS #####################################################################
         # set new input_ids and make sure to flatten it
         self._store_arg("input_ids", self._flatten(input_ids))
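A hypothetical host-side call illustrating the new keyword (all values invented); note that `slot_idx` is only stored when explicitly provided, so omitting it keeps the previously stored indices:

```python
# `si` is assumed to be a SequenceInfo instance
si.nest_sequences(
    [[1, 2, 3], [4, 5]],          # input_ids: two sequences
    input_pos=[0, 0],             # no cache history
    page_assignments=[[0], [1]],  # one page per sequence
    slot_idx=[0, 1],              # slot occupied by each sequence
)
```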
@@ -749,6 +765,7 @@ def __call__(
         input_pos: torch.Tensor,
         cache_loc: torch.Tensor,
         pages_per_seq: torch.Tensor,
+        slot_idx: torch.Tensor,
         page_size: int,
     ) -> List[torch.Tensor]: ...
 
@@ -834,6 +851,9 @@ def prepare_metadata(
         seq_len: torch.Tensor,
         input_pos: torch.Tensor,
         cache_loc: torch.Tensor,
+        pages_per_seq: torch.Tensor,
+        slot_idx: torch.Tensor,
+        page_size: int,
     ) -> List[torch.Tensor]: ...
     ```
     The metadata should contain all necessary global information required for the underlying
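For illustration, a minimal sketch of an op conforming to the updated interface. The body is invented (real implementations are backend-specific), and the leading `input_ids`/`position_ids` parameters sit above the diff context, so they are an assumption here:

```python
from typing import List

import torch


def prepare_metadata_sketch(
    input_ids: torch.Tensor,     # assumed: elided above the diff context
    position_ids: torch.Tensor,  # assumed: elided above the diff context
    seq_len: torch.Tensor,
    input_pos: torch.Tensor,
    cache_loc: torch.Tensor,
    pages_per_seq: torch.Tensor,
    slot_idx: torch.Tensor,
    page_size: int,
) -> List[torch.Tensor]:
    # illustrative only: start offset of each sequence in the flat token buffer
    seq_start = torch.cumsum(seq_len, 0) - seq_len
    # page_size arrives via const_args_for_prepare_metadata; unused in this sketch
    return [seq_start, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx]
```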