1010"""
1111
1212from abc import ABC , abstractmethod
13- from typing import (
14- Callable ,
15- Dict ,
16- List ,
17- Literal ,
18- Optional ,
19- Protocol ,
20- Sequence ,
21- Set ,
22- Tuple ,
23- Type ,
24- Union ,
25- )
13+ from typing import Dict , List , Literal , Optional , Protocol , Sequence , Set , Tuple , Type , Union
2614
2715import torch
2816from pydantic import BaseModel , ConfigDict , Field , field_validator
@@ -36,6 +24,10 @@
Constant = Union[int, float, str, None]


+class PrepareMetadataHostCallable(Protocol):
+    def __call__(self, **sequence_info_args: torch.Tensor) -> None: ...
+
+
class InputBuffer:
    """Manages contiguous memory buffers for efficient host-to-device transfers.

@@ -388,6 +380,9 @@ class SequenceInfo:
    - _mask_scatter_indices: [m_0, m_1, ..., m_{s_total-1}]
      Mask scatter indices used by the overlap scheduler to scatter results back.

+    NOTE: all tensors are also accessible as host tensors with the suffix "_host". For example,
+    the tensor "batch_info" is accessible as "batch_info_host" on the host.
+
    ################################################################################################

    Here are a couple of notes to emphasize this notation:
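A minimal sketch of the "_host" naming convention introduced in the NOTE above, assuming an already-constructed `SequenceInfo` instance `seq_info` (construction details omitted); the assertions are illustrative only:

```python
# Hedged sketch: every buffer tensor is exposed twice via SequenceInfo.available_args,
# once as a device view under its plain name and once as a pinned-host view under
# "<name>_host" (see the _available_args set built in __init__ below).
assert "batch_info" in seq_info.available_args
assert "batch_info_host" in seq_info.available_args

# Host-prepare callbacks (see register_host_prepare_for_attention_forward further down)
# can therefore request e.g. "batch_info_host" to read metadata directly on the host.
```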
@@ -508,24 +503,25 @@ def __init__(
        # Create the InputBuffer that manages contiguous host and device memory
        # Starts on default device; use to() to move to target device
        self._input_buffer = InputBuffer(tensor_specs)
+        self._available_args = set(self._input_buffer.tensor_names) | {
+            f"{name}_host" for name in self._input_buffer.tensor_names
+        }

        # Initialize args_list from tensor specs
        self._args_list: Dict[str, List[int]] = {
            name: [0] * numel for name, numel, _ in tensor_specs
        }

        self._active_args = ("input_ids", "position_ids")
-        self._shapeable_args = ("input_ids", "position_ids")
-        # Args that should be returned from host (pinned memory) instead of device in _named_args
-        self._host_return_args = ("batch_info", "logits_gather_info")
+        self._shapeable_args = ("input_ids", "position_ids", "input_ids_host", "position_ids_host")
        ############################################################################################

        # EXTRA TENSOR FIELDS ######################################################################
        self._extra_args: Dict[str, Optional[torch.Tensor]] = {}
        ############################################################################################

        # HOST PREPARE FOR ATTENTION FORWARD #######################################################
-        self._host_prepare_functions: set[Callable[[SequenceInfo], None]] = set()
+        self._host_prepare_functions: List[Tuple[PrepareMetadataHostCallable, List[str]]] = []

        # call reset once to set a consistent initial state
        self.reset()
@@ -558,14 +554,13 @@ def _shape_for_forward(self, tnsr: torch.Tensor) -> torch.Tensor:

    def _get_arg(self, name: str) -> torch.Tensor:
        """Get the argument from the input buffer either on device or host."""
-        if name in self._host_return_args:
-            arg = self._input_buffer.get_host_view(name)
+        if name.endswith("_host"):
+            arg = self._input_buffer.get_host_view(name.replace("_host", ""))
        else:
            arg = self._input_buffer.get_view(name)
        return self._shape_for_forward(arg) if name in self._shapeable_args else arg

    def _named_args(self, include_extra_args: bool = True) -> Dict[str, torch.Tensor]:
-        # Build args dict, using host views for _host_return_args, device views otherwise
        args = {k: self._get_arg(k) for k in self._active_args}

        # check other args to include
@@ -577,7 +572,7 @@ def _named_args(self, include_extra_args: bool = True) -> Dict[str, torch.Tensor
    @property
    def available_args(self) -> Set[str]:
        """Return a list of available arguments."""
-        return set(self._input_buffer.tensor_names)
+        return self._available_args

    @property
    def named_args(self) -> Dict[str, torch.Tensor]:
@@ -697,68 +692,6 @@ def _get_cache_locations_and_pages_per_sequence(
        pages_per_seq = [len(p) for p in page_assignments]
        return cache_loc_flat, pages_per_seq

-    # TODO: remove after updating all cached backends
-    @classmethod
-    def _get_sanitized_seq_len(
-        cls, input_or_position_ids: torch.Tensor, seq_len: torch.Tensor
-    ) -> torch.Tensor:
-        """Sanitize sequence lengths.
-
-        We want to cover the following scenarios with this function:
-
-        1. Pre-fill:
-            input_ids: [1, s_total, ...]
-            seq_len: [s_0, s_1, ..., s_{b-1}, 0, 0, ..., 0]
-            ---> returns [s_0, s_1, ..., s_{b-1}]
-        2. Decode:
-            input_ids: [b, 1, ...]
-            seq_len: [1, 1, ..., 1, 0, 0, ..., ..., ..., ..., 0]
-                     |---- b ----|--- (max_batch_size - b) ---|
-            --> returns [1,] * b
-        3. Decode in Cudagraph:
-            input_ids: [b_cudagraph, 1, ...]
-            seq_len: [1, 1, ..., 1, 0, 0, ..., ..., ..., ..., 0]
-                     |---- b ----|--- (max_batch_size - b) ---|
-
-            --> returns [1,] * b_cudagraph
-            Here b <= b_cudagraph. We want to make sure that the seq_len is one-padded to
-            b_cudagraph.
-
-        # TODO: I could see one possible issue with this approach in the future.
-        # If we have b < b_cudagraph we now one-pad. However, we don't pad the cache location
-        # information. What could happen is that the for the padded sequences the cache location
-        # tensors point to allocated pages. This could lead to a situation where we write into
-        # allocated cache pages polluting the cache of other sequences. Now this is not an issue
-        # if we write the dummy sequences into unallocated cache pages... One fix could be to
-        # pad not only the seq len but also pad the cache locations by just repeating the last
-        # valid cache location in the batch. This would ensure that the dummy sequences just
-        # repeats valid computation...
-        """
-        _, s = input_or_position_ids.shape[:2]
-        num_seq = cls._get_sanitized_num_sequences(input_or_position_ids, seq_len)
-        if s > 1:
-            return seq_len[:num_seq].clone()
-        else:
-            return torch.ones(num_seq, dtype=seq_len.dtype, device=seq_len.device)
-
-    @staticmethod
-    def _get_sanitized_num_sequences(
-        input_or_position_ids: torch.Tensor, seq_len: torch.Tensor
-    ) -> int:
-        """Get number of sequences.
-
-        We makes sure that this function is compatible with both torch graph capture and cudagraph.
-        Both can be a bit temparamental when trying to extract the number of sequences from a tensor
-        with max_batch_size or max_batch_size*max_seq_len.
-        """
-        b, s = input_or_position_ids.shape[:2]
-        if s > 1:
-            num_seq = torch.sum(seq_len > 0)
-            assert seq_len[num_seq:].sum() == 0, "seq_len should be zero-padded"
-        else:
-            num_seq = b
-        return num_seq
-
    def activate_arg(self, arg_name: str) -> bool:
        """Activate a desired argument.

@@ -869,7 +802,7 @@ def _store_arg(
        self._args_list[name] = tnsr_like.copy()

        # Only store to buffer when the argument is active or force_copy is True
-        if not (name in self._active_args or force_copy):
+        if not (name in self._active_args or f"{name}_host" in self._active_args or force_copy):
            return

        # Store to the InputBuffer's pinned host memory
@@ -1090,12 +1023,12 @@ def rescatter_input_ids(self, ungathered_input_ids: torch.Tensor):
    def maybe_gather_and_squeeze_logits(self, logits: torch.Tensor) -> torch.Tensor:
        """Maybe gather the logits if logits have not been gathered yet."""
        num_tokens = logits.shape[0] * logits.shape[1]
-        num_tokens_to_gather, gather_required = self._get_arg("logits_gather_info").tolist()
+        num_tokens_to_gather, gather_required = self._get_arg("logits_gather_info_host").tolist()
        if gather_required and num_tokens_to_gather < num_tokens:
            logits = torch.ops.auto_deploy.gather_logits_before_lm_head(
                logits,
                self._get_arg("logits_gather_indices"),
-                self._get_arg("logits_gather_info"),
+                self._get_arg("logits_gather_info_host"),
            )
        return logits.squeeze(int(self.is_generate))

@@ -1105,13 +1038,13 @@ def unnest_sequences(self, t_nested: torch.Tensor) -> List[torch.Tensor]:
        return list(torch.split(t_squeezed, self.seq_len))

    def register_host_prepare_for_attention_forward(
-        self, host_function: Callable[["SequenceInfo"], None]
+        self, host_function: PrepareMetadataHostCallable, args: List[str]
    ):
-        self._host_prepare_functions.add(host_function)
+        self._host_prepare_functions.append((host_function, args))

    def run_host_prepare_for_attention_forward(self) -> None:
-        for host_function in self._host_prepare_functions:
-            host_function(self)
+        for host_function, args in self._host_prepare_functions:
+            host_function(**{arg: self._get_arg(arg) for arg in args})


class MHACallable(Protocol):
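The two methods above define the new registration flow. A hedged usage sketch, assuming an already-constructed `SequenceInfo` instance `seq_info`; the requested argument names are illustrative (any name from `available_args`, including the `_host` variants, could be used):

```python
# Hedged sketch: the callback matches the PrepareMetadataHostCallable protocol and is
# registered together with the SequenceInfo argument names it wants to receive as kwargs.
def my_host_prepare(**sequence_info_args: torch.Tensor) -> None:
    # backend-specific host-side bookkeeping before the attention forward pass
    for name, tensor in sequence_info_args.items():
        assert isinstance(tensor, torch.Tensor)

seq_info.register_host_prepare_for_attention_forward(
    my_host_prepare, ["seq_len_host", "input_pos_host"]  # illustrative argument names
)

# run_host_prepare_for_attention_forward() then calls each registered function with
# exactly the tensors it asked for, passed as keyword arguments.
seq_info.run_host_prepare_for_attention_forward()
```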
@@ -1123,14 +1056,7 @@ def __call__(

class PrepareMetadataCallable(Protocol):
    def __call__(
-        self,
-        position_ids: torch.Tensor,
-        seq_len: torch.Tensor,
-        input_pos: torch.Tensor,
-        cache_loc: torch.Tensor,
-        pages_per_seq: torch.Tensor,
-        slot_idx: torch.Tensor,
-        page_size: int,
+        self, *sequence_info_args_and_constants: Union[torch.Tensor, Constant]
    ) -> List[torch.Tensor]: ...

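To make the relaxed protocol above concrete, here is a hedged sketch of a function that would satisfy the new `PrepareMetadataCallable` signature; the body is illustrative only (real backends define their own argument layout and return tensors), and it assumes the imports already present in this module:

```python
# Hedged sketch of a PrepareMetadataCallable-conforming function. The protocol now takes
# a flat positional mix of SequenceInfo tensors and constants instead of a fixed argument
# list; how each backend interprets that layout is backend-specific and not shown here.
def example_prepare_metadata(
    *sequence_info_args_and_constants: Union[torch.Tensor, Constant],
) -> List[torch.Tensor]:
    # Split tensors from plain constants; a real backend would unpack by position instead.
    tensors = [a for a in sequence_info_args_and_constants if isinstance(a, torch.Tensor)]
    return tensors
```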
@@ -1291,13 +1217,14 @@ def get_constants(cls, source_attn_node: Node) -> List[Constant]:
        return []

    @classmethod
-    def host_prepare_for_forward(cls, sequence_info: SequenceInfo):
-        """Perform host-side preparation for the forward pass for the attention op.
+    def get_host_prepare_metadata_function(cls) -> Optional[PrepareMetadataHostCallable]:
+        """Get function that performs host-side prep for the forward pass for the attention op.

        This method is responsible for preparing the attention op for the forward pass.
-        This function is not expected to be graph capturable or compatible with cuda graphs.
+        This function is not expected to be graph capturable or compatible with cuda graphs. It can
+        use any argument from the SequenceInfo interface as an input argument.
12991226 """
1300- return
1227+ return None
13011228
13021229
13031230class AttentionRegistry :
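Finally, a hedged sketch of how an attention backend might implement the new classmethod; `MyBackend` is hypothetical and shown standalone here (a real backend would subclass the attention descriptor interface this hunk belongs to), and only the method signature and the `PrepareMetadataHostCallable` protocol come from this diff:

```python
class MyBackend:  # hypothetical backend; shown without the real base class
    @classmethod
    def get_host_prepare_metadata_function(cls) -> Optional[PrepareMetadataHostCallable]:
        def _host_prepare(**sequence_info_args: torch.Tensor) -> None:
            # Host-only preparation; not expected to be CUDA-graph capturable. The kwargs
            # are whatever SequenceInfo arguments this callable gets registered with.
            ...

        return _host_prepare
```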