@@ -11,32 +11,16 @@

 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import (
-    Callable,
-    Dict,
-    List,
-    Literal,
-    Optional,
-    Protocol,
-    Sequence,
-    Set,
-    Tuple,
-    Type,
-    Union,
-)
+from typing import Dict, List, Literal, Optional, Protocol, Sequence, Set, Tuple, Type, Union

 import torch
 from torch._ops import OpOverloadPacket
-from torch.export import Dim
 from torch.fx import Node
 from torch.types import Number

 from ...._utils import nvtx_range
 from ..utils.logger import ad_logger

-DynamicShape = Dict[int, Dim]  # indicating the dynamic shape in tensor dimension
-DynamicShapeCallback = Callable[[], DynamicShape]
-
 Constant = Union[int, float, str, None]

@@ -67,12 +51,6 @@ class SequenceInfo:
     ### EXTRA ARGUMENTS PROVIDED TO THE INTERFACE ##################################################
     These are extra arguments that can be provided to the interface; they are stored as follows:
     - _extra_args: dictionary of extra arguments with currently active values.
-    - _extra_none_inputs: dictionary of none inputs to the extra arguments.
-      NOTE: we assume that extra arguments are *optional* arguments to the model. However, we
-      cannot represent them via `None` since fx graphs require a fixed input type. Instead,
-      we require a special placeholder tensor to represent the `None` input.
-    - _extra_dynamic_shapes_callbacks: dictionary of callbacks to initialize the dynamic shapes of
-      the extra arguments.

     ### CACHE ARGUMENTS NEEDED FOR ATTENTION OPERATORS FOR FLATTENED SEQUENCES + CACHES ############
     - seq_len: [s_0, s_1, ..., s_{b-1}] such that s_total = sum(s_i)
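> Editorial note: the removed NOTE above captures a real constraint worth remembering: fx-traced graphs require a fixed input signature, so an optional tensor argument cannot alternate between `None` and a tensor across calls. A minimal self-contained sketch of that placeholder-tensor pattern, assuming a zero-element sentinel (the actual tensor the code registered may have differed):

```python
import torch

NONE_SENTINEL = torch.empty(0)  # assumed stand-in for a `None` input

def forward_with_optional(x: torch.Tensor, extra: torch.Tensor) -> torch.Tensor:
    # The input type is always Tensor, so the traced signature never changes;
    # "absent" is encoded as a zero-element tensor instead of None.
    if extra.numel() == 0:
        return x
    return x + extra

print(forward_with_optional(torch.ones(2), NONE_SENTINEL))             # extra absent
print(forward_with_optional(torch.ones(2), torch.tensor([1.0, 2.0])))  # extra present
```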
@@ -175,12 +153,6 @@ def __init__(
         # indicator if extra args are activated that are needed for cached attention backends
         self._is_cached_attn = False

-        # indicator how to handle the "None" input for extra args
-        self._use_strict_args = True
-
-        # container for dynamic shapes
-        self._dynamic_shapes: Optional[Dict[str, DynamicShape]] = None
-
         # TENSOR FIELDS ############################################################################
         self._args_device: Dict[str, torch.Tensor] = {
             # TENSOR FIELDS FOR UNCACHED ATTENTION
@@ -206,9 +178,6 @@ def __init__(

         # EXTRA TENSOR FIELDS ######################################################################
         self._extra_args: Dict[str, Optional[torch.Tensor]] = {}
-        self._extra_none_inputs: Dict[str, torch.Tensor] = {}
-        self._extra_dynamic_shapes: Optional[Dict[str, DynamicShape]] = None
-        self._extra_dynamic_shapes_callbacks: Dict[str, DynamicShapeCallback] = {}
         ############################################################################################

         # call reset once to set a consistent initial state
@@ -218,33 +187,6 @@ def __init__(
     def device(self) -> torch.device:
         return self._args_device["input_ids"].device

-    @property
-    def use_strict_args(self) -> bool:
-        return self._use_strict_args
-
-    @use_strict_args.setter
-    def use_strict_args(self, val: bool) -> None:
-        """Configure whether to use strict graph arguments only.
-
-        Args:
-            val: whether to restrict the interface to strict graph arguments only.
-
-        In strict arguments mode,
-          * only stock arguments (like input_ids, position_ids, etc.) or extra
-            arguments that are explicitly added via the ``add_extra_arg`` interface are allowed.
-            Other arguments that are provided to ``nest_sequences`` will be rejected with an
-            error.
-          * registered extra arguments that are not provided to ``nest_sequences`` will be added to
-            the argument list automatically using the registered None-like tensor.
-
-        In non-strict argument mode,
-          * all arguments, including all **kwargs, that are provided to ``nest_sequences`` will
-            simply be passed to the model in the order received.
-          * registered extra arguments that are not provided to ``nest_sequences`` will
-            _not_ be added to the argument list.
-        """
-        self._use_strict_args = val
-
     def _shape_for_forward(self, tnsr: torch.Tensor) -> torch.Tensor:
         """Shape the tensor for the forward pass based on the current attention mode.

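> Editorial note: for readers tracking what behavior disappears with this deletion, here is a toy, self-contained sketch of the two modes the removed setter documented (names like `registered_none_inputs` are illustrative, not the class's internals):

```python
from typing import Any, Dict

def collect_extra_args(
    provided: Dict[str, Any], registered_none_inputs: Dict[str, Any], strict: bool
) -> Dict[str, Any]:
    provided = dict(provided)  # work on a copy
    if strict:
        # Only registered names are accepted; missing ones fall back to their
        # registered None-like placeholder, and unknown names raise.
        out = {k: provided.pop(k, v) for k, v in registered_none_inputs.items()}
        assert not provided, f"Extra arguments {provided.keys()} not found"
        return out
    # Non-strict: pass everything through in the order received.
    return provided

print(collect_extra_args({"mask": 1}, {"mask": 0, "bias": 0}, strict=True))
# {'mask': 1, 'bias': 0}
print(collect_extra_args({"anything": 2}, {"mask": 0}, strict=False))
# {'anything': 2}
```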
@@ -325,7 +267,11 @@ def args_for_prepare_metadata(self) -> Tuple[str, ...]:
         like ``insert_cached_attention`` to extract the constant arguments and add them to the
         ``prepare_metadata`` node/op.
         """
-        return tuple(self.named_standard_args.keys())
+        # NOTE: for now we do _not_ include input_ids since we are not guaranteed that input_ids
+        # is part of the graph, e.g., in situations where the graph is a submodule of the overall
+        # model. In such instances, the graph usually sees inputs_embeds. However, we assume for
+        # now that position_ids is always part of the graph.
+        return ("position_ids",) + self._cached_arg_names

     @property
     def const_args_for_prepare_metadata(self) -> Tuple[Constant, ...]:
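> Editorial note: the NOTE spells out the contract change: consumers may no longer rely on `input_ids` being a graph input, only `position_ids`. Purely for illustration, if `_cached_arg_names` held, say, `("seq_len", "input_pos", "cache_loc", "pages_per_seq")` (a hypothetical value, not the actual field contents), the property would now yield:

```python
cached_arg_names = ("seq_len", "input_pos", "cache_loc", "pages_per_seq")  # hypothetical
print(("position_ids",) + cached_arg_names)
# ('position_ids', 'seq_len', 'input_pos', 'cache_loc', 'pages_per_seq')
```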
@@ -343,36 +289,6 @@ def const_args_for_prepare_metadata(self) -> Tuple[Constant, ...]:
343289 """
344290 return tuple (getattr (self , k ) for k in self ._cached_constants )
345291
346- @property
347- def named_dynamic_shapes (self ) -> Dict [str , DynamicShape ]:
348- """Return dynamic shapes of sequence info tensors.
349-
350- NOTE: will be lazily initialized since the Dim object is not picklable for multi-processing.
351- """
352- # lazy initialization of dynamic shapes with Dim objects
353- if self ._dynamic_shapes is None :
354- # set up shape for uncached args (same for all, i.e., batch_size and seq_len)
355- bs_seq_len_shape : DynamicShape = {}
356- if self .max_batch_size > 1 :
357- bs_seq_len_shape [0 ] = Dim ("batch_size" , max = self .max_batch_size )
358- bs_seq_len_shape [1 ] = Dim ("seq_len" , max = self .max_seq_len )
359- # bs_seq_len_shape[1] = Dim.AUTO
360- self ._dynamic_shapes = {k : bs_seq_len_shape for k in self ._uncached_arg_names }
361- # cached args are static
362- self ._dynamic_shapes .update ({k : {} for k in self ._cached_arg_names })
363-
364- for k , callback in self ._extra_dynamic_shapes_callbacks .items ():
365- if k not in self ._dynamic_shapes :
366- self ._dynamic_shapes [k ] = callback ()
367-
368- # return dynamic shapes according to currently active named_args with consistent order
369- return {k : self ._dynamic_shapes [k ] for k in self .named_args .keys ()}
370-
371- @property
372- def dynamic_shapes (self ) -> Tuple [DynamicShape , ...]:
373- """Return dynamic shapes of sequence info tensors."""
374- return tuple (self .named_dynamic_shapes .values ())
375-
376292 @property
377293 def seq_len (self ) -> List [int ]:
378294 return self ._args_host ["seq_len" ].copy ()
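> Editorial note: for reference, the removed property boiled down to the construction below; the lazy caching existed because `Dim` objects are not picklable for multi-processing, as the removed NOTE says. A self-contained sketch, with argument names assumed for illustration:

```python
from typing import Dict
from torch.export import Dim

max_batch_size, max_seq_len = 8, 1024
uncached_arg_names = ("input_ids", "position_ids")  # assumed for illustration
cached_arg_names = ("seq_len", "input_pos")         # assumed for illustration

# Batched, variable-length args share one {dim_index: Dim} spec ...
bs_seq_len_shape: Dict[int, Dim] = {}
if max_batch_size > 1:
    bs_seq_len_shape[0] = Dim("batch_size", max=max_batch_size)
bs_seq_len_shape[1] = Dim("seq_len", max=max_seq_len)

dynamic_shapes = {k: bs_seq_len_shape for k in uncached_arg_names}
# ... while cache-related args are static (empty spec).
dynamic_shapes.update({k: {} for k in cached_arg_names})
print(dynamic_shapes)
```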
@@ -466,7 +382,9 @@ def _get_cache_locations_and_pages_per_sequence(
         return cache_loc_flat, pages_per_seq

     @classmethod
-    def _get_sanitized_seq_len(cls, input_ids: torch.Tensor, seq_len: torch.Tensor) -> torch.Tensor:
+    def _get_sanitized_seq_len(
+        cls, input_or_position_ids: torch.Tensor, seq_len: torch.Tensor
+    ) -> torch.Tensor:
         """Sanitize sequence lengths.

         We want to cover the following scenarios with this function:
@@ -499,22 +417,24 @@ def _get_sanitized_seq_len(cls, input_ids: torch.Tensor, seq_len: torch.Tensor)
         # valid cache location in the batch. This would ensure that the dummy sequences just
         # repeat valid computation...
         """
-        _, s = input_ids.shape[:2]
-        num_seq = cls._get_sanitized_num_sequences(input_ids, seq_len)
+        _, s = input_or_position_ids.shape[:2]
+        num_seq = cls._get_sanitized_num_sequences(input_or_position_ids, seq_len)
         if s > 1:
             return seq_len[:num_seq].detach().clone()
         else:
             return torch.ones(num_seq, dtype=seq_len.dtype, device=seq_len.device)

     @staticmethod
-    def _get_sanitized_num_sequences(input_ids: torch.Tensor, seq_len: torch.Tensor) -> int:
+    def _get_sanitized_num_sequences(
+        input_or_position_ids: torch.Tensor, seq_len: torch.Tensor
+    ) -> int:
         """Get number of sequences.

         We make sure that this function is compatible with both torch graph capture and cudagraph.
         Both can be a bit temperamental when trying to extract the number of sequences from a tensor
         with max_batch_size or max_batch_size*max_seq_len.
         """
-        b, s = input_ids.shape[:2]
+        b, s = input_or_position_ids.shape[:2]
         if s > 1:
             num_seq = torch.sum(seq_len > 0)
             assert seq_len[num_seq:].sum() == 0, "seq_len should be zero-padded"
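> Editorial note: to make the renamed helpers concrete, here is a standalone sketch reproducing the sanitization logic on plain tensors. The `s == 1` branch taking the batch dimension as the sequence count is an assumption read from the surrounding code, since that branch is not shown in the hunk:

```python
import torch

def sanitize_seq_len(ids: torch.Tensor, seq_len: torch.Tensor) -> torch.Tensor:
    b, s = ids.shape[:2]
    if s > 1:
        # Context phase: count non-zero entries; the tail must be zero padding.
        num_seq = int(torch.sum(seq_len > 0))
        assert seq_len[num_seq:].sum() == 0, "seq_len should be zero-padded"
        return seq_len[:num_seq].detach().clone()
    # Generate phase: one new token per sequence in the batch (assumed).
    return torch.ones(b, dtype=seq_len.dtype, device=seq_len.device)

# Context phase: two real sequences flattened into one row of 3 + 4 = 7 tokens.
print(sanitize_seq_len(torch.zeros(1, 7, dtype=torch.long), torch.tensor([3, 4, 0, 0])))
# tensor([3, 4])
# Generate phase: four sequences, one token each.
print(sanitize_seq_len(torch.zeros(4, 1, dtype=torch.long), torch.tensor([3, 4, 0, 0])))
# tensor([1, 1, 1, 1])
```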
@@ -547,12 +467,11 @@ def _move_dict(d: Dict[str, torch.Tensor]) -> None:

         _move_dict(self._args_device)
         _move_dict(self._extra_args)
-        _move_dict(self._extra_none_inputs)

     def set_example_sequence(
         self,
-        input_ids: Sequence[Sequence[int]] = None,
-        position_ids: Optional[torch.Tensor] = None,
+        input_ids: Optional[Sequence[Sequence[int]]] = None,
+        position_ids: Optional[Sequence[Sequence[int]]] = None,
         **extra_args,
     ) -> None:
         """Set an example sequence useful for testing and export purposes without cache history."""
@@ -652,8 +571,6 @@ def _store_extra_arg(
             else:
                 tnsr_like = tnsr_like[0]
             self._extra_args[name] = tnsr_like.to(self.device, non_blocking=True)
-        elif self.use_strict_args:
-            self._extra_args[name] = self._extra_none_inputs[name]
         else:
             self._extra_args[name] = None

@@ -736,15 +653,8 @@ def nest_sequences(

         ### UPDATE EXTRA INPUTS ####################################################################
         self._extra_args = {}
-        # in strict argument mode, we only accept registered extra arguments
-        if self.use_strict_args:
-            for name in self._extra_none_inputs.keys():
-                self._store_extra_arg(name, extra_args.pop(name, None))
-            assert not extra_args, f"Extra arguments {extra_args.keys()} not found"
-        # otherwise, we simply pass in all extra arguments
-        else:
-            for key, value in extra_args.items():
-                self._store_extra_arg(key, value)
+        for key, value in extra_args.items():
+            self._store_extra_arg(key, value)

     @nvtx_range("ad_rescatter_input_ids")
     def rescatter_input_ids(
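> Editorial note: with strict mode gone, `nest_sequences` forwards every keyword argument unconditionally. A minimal standalone sketch of the simplified pass-through; `_store_extra_arg` below is a stub standing in for the real method:

```python
import torch

class ExtraArgStore:
    def __init__(self) -> None:
        self._extra_args = {}

    def _store_extra_arg(self, name, value) -> None:  # stub for the real method
        self._extra_args[name] = value

    def nest_extra_args(self, **extra_args) -> None:
        self._extra_args = {}
        # No registered-name filtering: everything is stored in the order received.
        for key, value in extra_args.items():
            self._store_extra_arg(key, value)

store = ExtraArgStore()
store.nest_extra_args(pixel_values=torch.zeros(1, 3), image_sizes=None)
print(list(store._extra_args))  # ['pixel_values', 'image_sizes']
```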
@@ -778,31 +688,6 @@ def unnest_sequences(self, t_nested: torch.Tensor) -> List[torch.Tensor]:
         t_squeezed = t_nested.squeeze(1) if self.is_generate else t_nested.squeeze(0)
         return list(torch.split(t_squeezed, self.seq_len))

-    def add_extra_arg(
-        self,
-        name: str,
-        none_input: torch.Tensor,
-        dynamic_shape_callback: Optional[DynamicShapeCallback] = None,
-    ) -> None:
-        """Add an extra argument to the sequence info object.
-
-        Args:
-            name: The name of the extra argument.
-            none_input: None input value of the extra argument.
-            dynamic_shape_callback: The callback to get the dynamic shape of the extra argument.
-
-        Note that the extra argument is expected to be a tensor.
-        """
-        assert name not in self._named_args().keys(), f"Extra argument {name} already exists"
-
-        self._extra_args[name] = none_input.to(self.device)
-        self._extra_none_inputs[name] = self._extra_args[name]
-
-        if dynamic_shape_callback is None:
-            self._extra_dynamic_shapes_callbacks[name] = lambda: {}
-        else:
-            self._extra_dynamic_shapes_callbacks[name] = dynamic_shape_callback
-

 class MHACallable(Protocol):
     def __call__(
@@ -814,7 +699,6 @@ def __call__(
 class PrepareMetadataCallable(Protocol):
     def __call__(
         self,
-        input_ids: torch.Tensor,
         position_ids: torch.Tensor,
         seq_len: torch.Tensor,
         input_pos: torch.Tensor,
@@ -901,7 +785,6 @@ def get_prepare_metadata_op(cls) -> Tuple[PrepareMetadataCallable, int]:

         ```
         def prepare_metadata(
-            input_ids: torch.Tensor,
             position_ids: torch.Tensor,
             seq_len: torch.Tensor,
             input_pos: torch.Tensor,
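> Editorial note: a hedged sketch of a callable matching the updated protocol, with the trailing cache and constant parameters elided exactly as in the excerpt above; the body and return layout are illustrative only, not the real op:

```python
from typing import List
import torch

def prepare_metadata(
    position_ids: torch.Tensor,
    seq_len: torch.Tensor,
    input_pos: torch.Tensor,
    # ... remaining cache and constant arguments elided, as in the protocol
) -> List[torch.Tensor]:
    # Metadata is now derived from position_ids rather than input_ids.
    if position_ids.shape[1] > 1:
        num_seq = int(torch.sum(seq_len > 0))  # context phase
    else:
        num_seq = position_ids.shape[0]        # generate phase (assumed)
    return [seq_len[:num_seq].clone(), input_pos[:num_seq].clone()]

meta = prepare_metadata(
    torch.zeros(1, 7, dtype=torch.long),  # flattened context-phase positions
    torch.tensor([3, 4, 0, 0]),
    torch.tensor([0, 0, 0, 0]),
)
print(meta)  # [tensor([3, 4]), tensor([0, 0])]
```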