
Commit 19421f4

lucaslie authored and dominicshanshan committed
[None][feat] AutoDeploy: VLMs with subgraphs + cudagraph/compile (NVIDIA#8203)
Signed-off-by: Lucas Liebenwein <[email protected]>
1 parent 4f80961 commit 19421f4


43 files changed: +820 −588 lines

tensorrt_llm/_torch/auto_deploy/config/default.yaml

Lines changed: 8 additions & 4 deletions
@@ -6,16 +6,14 @@ transforms:
   ############################################################################################
   build_model:
     stage: factory
+    run_per_gm: false
     device: meta
-    # nothing to clean up
-    run_graph_cleanup: false
     requires_clean_graph: false
   export_to_gm:
     stage: export
     clone_state_dict: false
     strict: false
-    # nothing to clean up
-    run_graph_cleanup: false
+    run_per_gm: false
     requires_clean_graph: false
   cleanup_noop_slice:
     stage: post_export
@@ -35,6 +33,7 @@ transforms:
     run_shape_prop: true
   match_eager_attention:
     stage: pattern_matcher
+    requires_shape_prop: true
   match_grouped_attention:
     stage: pattern_matcher
   match_attention_layout:
@@ -87,8 +86,10 @@ transforms:
   ############################################################################################
   load_weights:
     stage: weight_load
+    run_per_gm: false
   move_inputs_to_device:
     stage: weight_load
+    run_per_gm: false
   ############################################################################################
   # RUN POST-LOAD FUSION AND OPTIMIZATIONS
   ############################################################################################
@@ -138,10 +139,13 @@ transforms:
     attn_backend: cuda_causal_conv
   initialize_cache:
     stage: cache_init
+    run_per_gm: false
   resize_kv_cache:
     stage: cache_init
+    run_per_gm: false
   ############################################################################################
   # COMPILE MODEL
   ############################################################################################
   compile_model:
     stage: compile
+    run_per_gm: false
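
The new `run_per_gm` flag marks transforms that must run once on the whole model rather than once per exported subgraph, which matters now that a VLM pipeline can carry several fx.GraphModule submodules (e.g. a vision encoder plus a language model). Below is a minimal sketch of how such a flag could be consumed; the `run_transform` dispatcher is a hypothetical illustration, not the actual AutoDeploy transform-pipeline code:

```
from typing import Callable

import torch.fx as fx
import torch.nn as nn


def run_transform(
    model: nn.Module,
    transform: Callable[[nn.Module], None],
    run_per_gm: bool = True,
) -> None:
    """Illustrative dispatcher: apply `transform` to every fx.GraphModule submodule
    when run_per_gm is True, or once to the full model otherwise."""
    if not run_per_gm:
        # e.g. weight loading, cache resize, device moves act on the whole model
        transform(model)
        return
    # per-GM transforms visit each exported subgraph individually
    for _, submod in model.named_modules():
        if isinstance(submod, fx.GraphModule):
            transform(submod)
```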

tensorrt_llm/_torch/auto_deploy/config/transformers.yaml

Lines changed: 6 additions & 0 deletions
@@ -6,23 +6,29 @@ transforms:
   ############################################################################################
   build_and_load_factory_model:
     stage: factory
+    run_per_gm: false
   ############################################################################################
   # MOVE ARGUMENTS TO DEVICE
   ############################################################################################
   move_inputs_to_device:
     stage: weight_load
+    run_per_gm: false
   ############################################################################################
   # SWITCH TO CACHED+FLATTENED ATTENTION + INITIALIZE CACHES
   ############################################################################################
   detect_hf_attn_layers:
     stage: cache_init
+    run_per_gm: false
   transformers_replace_cached_attn:
     stage: cache_init
     attn_backend: flashinfer
+    run_per_gm: false
   initialize_cache:
     stage: cache_init
+    run_per_gm: false
   resize_kv_cache:
     stage: cache_init
+    run_per_gm: false
   ############################################################################################
   # COMPILE MODEL
   ############################################################################################

tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py

Lines changed: 19 additions & 136 deletions
@@ -11,32 +11,16 @@
 
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import (
-    Callable,
-    Dict,
-    List,
-    Literal,
-    Optional,
-    Protocol,
-    Sequence,
-    Set,
-    Tuple,
-    Type,
-    Union,
-)
+from typing import Dict, List, Literal, Optional, Protocol, Sequence, Set, Tuple, Type, Union
 
 import torch
 from torch._ops import OpOverloadPacket
-from torch.export import Dim
 from torch.fx import Node
 from torch.types import Number
 
 from ...._utils import nvtx_range
 from ..utils.logger import ad_logger
 
-DynamicShape = Dict[int, Dim]  # indicating the dynamic shape in tensor dimension
-DynamicShapeCallback = Callable[[], DynamicShape]
-
 Constant = Union[int, float, str, None]
 
 
@@ -67,12 +51,6 @@ class SequenceInfo:
    ### EXTRA ARGUMENTS PROVIDED TO THE INTERFACE ##################################################
    Those are extra arguments that can be provided to the interface and they are stored as follows:
    - _extra_args: dictionary of extra arguments with currently active values.
-    - _extra_none_inputs: dictionary of none inputs to the extra arguments.
-      NOTE: we assume that extra arguments are *optional* arguments to the model. However, we
-      cannot represent them via `None` since fx graphs require a fixed input type. Instead,
-      we require a special placeholder tensor to represent the `None` input.
-    - _extra_dynamic_shapes_callbacks: dictionary of callbacks to initialize the dynamic shapes of
-      the extra arguments.
 
    ### CACHE ARGUMENTS NEEDED FOR ATTENTION OPERATORS FOR FLATTENED SEQUENCES + CACHES ############
    - seq_len: [s_0, s_1, ..., s_{b-1}] such that s_total = sum(s_i)
@@ -175,12 +153,6 @@ def __init__(
        # indicator if extra args are activated that are needed for cached attention backends
        self._is_cached_attn = False
 
-        # indicator how to handle the "None" input for extra args
-        self._use_strict_args = True
-
-        # container for dynamic shapes
-        self._dynamic_shapes: Optional[Dict[str, DynamicShape]] = None
-
        # TENSOR FIELDS ############################################################################
        self._args_device: Dict[str, torch.Tensor] = {
            # TENSOR FIELDS FOR UNCACHED ATTENTION
@@ -206,9 +178,6 @@ def __init__(
 
        # EXTRA TENSOR FIELDS ######################################################################
        self._extra_args: Dict[str, Optional[torch.Tensor]] = {}
-        self._extra_none_inputs: Dict[str, torch.Tensor] = {}
-        self._extra_dynamic_shapes: Optional[Dict[str, DynamicShape]] = None
-        self._extra_dynamic_shapes_callbacks: Dict[str, DynamicShapeCallback] = {}
        ############################################################################################
 
        # call reset once to set a consistent initial state
@@ -218,33 +187,6 @@ def __init__(
    def device(self) -> torch.device:
        return self._args_device["input_ids"].device
 
-    @property
-    def use_strict_args(self) -> bool:
-        return self._use_strict_args
-
-    @use_strict_args.setter
-    def use_strict_args(self, val: bool) -> None:
-        """Configure whether to use strict graph arguments only.
-
-        Args:
-            val: strict graph arguments only or not.
-
-        In strict arguments mode,
-        * only stock arguments (like input_ids, position_ids, etc.) or extra
-          arguments that are explicitly added via the ``add_extra_arg`` interface are allowed.
-          Other arguments that are provided in ``nest_sequences`` will be rejected and throw an
-          error.
-        * registered extra arguments that are not provided to ``nest_sequences`` will be added to
-          the argument list automatically using the registered None-like tensor.
-
-        In non-strict argument mode,
-        * all arguments including all **kwargs that are provided to ``nest_sequences`` and will
-          simply be passed to the model in the order received.
-        * registered extra arguments that are not provided to ``nest_sequences`` will be added
-          _not_ be added to the argument list.
-        """
-        self._use_strict_args = val
-
    def _shape_for_forward(self, tnsr: torch.Tensor) -> torch.Tensor:
        """Shape the tensor for the forward pass based on the current attention mode.
 
@@ -325,7 +267,11 @@ def args_for_prepare_metadata(self) -> Tuple[str, ...]:
        like ``insert_cached_attention`` to extract the constant arguments and add them to the
        ``prepare_metadata`` node/op.
        """
-        return tuple(self.named_standard_args.keys())
+        # NOTE: for now we do _not_ include input_ids since we are not guaranteed that input_ids
+        # is part of the graph, e.g., in situations where the graph is a submodule of the overall
+        # model. In such instances, the graph usually sees inputs_embeds. However, we assume for
+        # now that position_ids is always part of the graph.
+        return ("position_ids",) + self._cached_arg_names
 
    @property
    def const_args_for_prepare_metadata(self) -> Tuple[Constant, ...]:
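
As the NOTE above explains, metadata extraction is now keyed on `position_ids` plus the cached-argument names, because a subgraph may receive `inputs_embeds` rather than `input_ids`. A hedged sketch of how a transform could resolve those names to graph placeholders before wiring up a `prepare_metadata` call; the helper below is illustrative only, the real `insert_cached_attention` transform does considerably more:

```
from typing import Dict, Tuple

import torch.fx as fx


def collect_metadata_placeholders(
    gm: fx.GraphModule, arg_names: Tuple[str, ...]
) -> Dict[str, fx.Node]:
    """Map names such as ("position_ids", "seq_len", "input_pos", ...) to the
    matching placeholder nodes of the (sub)graph. Illustrative helper only."""
    placeholders = {n.target: n for n in gm.graph.nodes if n.op == "placeholder"}
    missing = [name for name in arg_names if name not in placeholders]
    if missing:
        # input_ids is intentionally not expected here; subgraphs may only see
        # inputs_embeds, while position_ids is assumed to always be present.
        raise ValueError(f"graph is missing expected inputs: {missing}")
    return {name: placeholders[name] for name in arg_names}
```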
@@ -343,36 +289,6 @@ def const_args_for_prepare_metadata(self) -> Tuple[Constant, ...]:
        """
        return tuple(getattr(self, k) for k in self._cached_constants)
 
-    @property
-    def named_dynamic_shapes(self) -> Dict[str, DynamicShape]:
-        """Return dynamic shapes of sequence info tensors.
-
-        NOTE: will be lazily initialized since the Dim object is not picklable for multi-processing.
-        """
-        # lazy initialization of dynamic shapes with Dim objects
-        if self._dynamic_shapes is None:
-            # set up shape for uncached args (same for all, i.e., batch_size and seq_len)
-            bs_seq_len_shape: DynamicShape = {}
-            if self.max_batch_size > 1:
-                bs_seq_len_shape[0] = Dim("batch_size", max=self.max_batch_size)
-            bs_seq_len_shape[1] = Dim("seq_len", max=self.max_seq_len)
-            # bs_seq_len_shape[1] = Dim.AUTO
-            self._dynamic_shapes = {k: bs_seq_len_shape for k in self._uncached_arg_names}
-            # cached args are static
-            self._dynamic_shapes.update({k: {} for k in self._cached_arg_names})
-
-        for k, callback in self._extra_dynamic_shapes_callbacks.items():
-            if k not in self._dynamic_shapes:
-                self._dynamic_shapes[k] = callback()
-
-        # return dynamic shapes according to currently active named_args with consistent order
-        return {k: self._dynamic_shapes[k] for k in self.named_args.keys()}
-
-    @property
-    def dynamic_shapes(self) -> Tuple[DynamicShape, ...]:
-        """Return dynamic shapes of sequence info tensors."""
-        return tuple(self.named_dynamic_shapes.values())
-
    @property
    def seq_len(self) -> List[int]:
        return self._args_host["seq_len"].copy()
@@ -466,7 +382,9 @@ def _get_cache_locations_and_pages_per_sequence(
        return cache_loc_flat, pages_per_seq
 
    @classmethod
-    def _get_sanitized_seq_len(cls, input_ids: torch.Tensor, seq_len: torch.Tensor) -> torch.Tensor:
+    def _get_sanitized_seq_len(
+        cls, input_or_position_ids: torch.Tensor, seq_len: torch.Tensor
+    ) -> torch.Tensor:
        """Sanitize sequence lengths.
 
        We want to cover the following scenarios with this function:
@@ -499,22 +417,24 @@ def _get_sanitized_seq_len(cls, input_ids: torch.Tensor, seq_len: torch.Tensor)
        # valid cache location in the batch. This would ensure that the dummy sequences just
        # repeats valid computation...
        """
-        _, s = input_ids.shape[:2]
-        num_seq = cls._get_sanitized_num_sequences(input_ids, seq_len)
+        _, s = input_or_position_ids.shape[:2]
+        num_seq = cls._get_sanitized_num_sequences(input_or_position_ids, seq_len)
        if s > 1:
            return seq_len[:num_seq].detach().clone()
        else:
            return torch.ones(num_seq, dtype=seq_len.dtype, device=seq_len.device)
 
    @staticmethod
-    def _get_sanitized_num_sequences(input_ids: torch.Tensor, seq_len: torch.Tensor) -> int:
+    def _get_sanitized_num_sequences(
+        input_or_position_ids: torch.Tensor, seq_len: torch.Tensor
+    ) -> int:
        """Get number of sequences.
 
        We makes sure that this function is compatible with both torch graph capture and cudagraph.
        Both can be a bit temparamental when trying to extract the number of sequences from a tensor
        with max_batch_size or max_batch_size*max_seq_len.
        """
-        b, s = input_ids.shape[:2]
+        b, s = input_or_position_ids.shape[:2]
        if s > 1:
            num_seq = torch.sum(seq_len > 0)
            assert seq_len[num_seq:].sum() == 0, "seq_len should be zero-padded"
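
Both sanitization helpers now key off `position_ids` (or any per-token tensor) instead of `input_ids`, so they keep working when a subgraph never receives token IDs. A standalone sketch of the behavior, with the decode branch simplified since the diff only shows the prefill path explicitly:

```
import torch


def sanitize_seq_len(position_ids: torch.Tensor, seq_len: torch.Tensor) -> torch.Tensor:
    """Toy version of SequenceInfo._get_sanitized_seq_len keyed on position_ids."""
    _, s = position_ids.shape[:2]
    if s > 1:
        # prefill: count the non-zero entries of the zero-padded seq_len buffer
        num_seq = int(torch.sum(seq_len > 0))
        assert seq_len[num_seq:].sum() == 0, "seq_len should be zero-padded"
        return seq_len[:num_seq].detach().clone()
    # decode (simplified assumption): one token per active sequence, so the batch
    # dimension of position_ids is taken as the number of sequences
    return torch.ones(position_ids.shape[0], dtype=seq_len.dtype, device=seq_len.device)


# prefill: two real sequences (lengths 3 and 2) in a padded buffer of size 4
print(sanitize_seq_len(torch.arange(5).unsqueeze(0), torch.tensor([3, 2, 0, 0])))  # tensor([3, 2])
# decode: batch of 4 single-token sequences
print(sanitize_seq_len(torch.zeros(4, 1, dtype=torch.long), torch.tensor([3, 2, 0, 0])))  # tensor([1, 1, 1, 1])
```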
@@ -547,12 +467,11 @@ def _move_dict(d: Dict[str, torch.Tensor]) -> None:
 
        _move_dict(self._args_device)
        _move_dict(self._extra_args)
-        _move_dict(self._extra_none_inputs)
 
    def set_example_sequence(
        self,
-        input_ids: Sequence[Sequence[int]] = None,
-        position_ids: Optional[torch.Tensor] = None,
+        input_ids: Optional[Sequence[Sequence[int]]] = None,
+        position_ids: Optional[Sequence[Sequence[int]]] = None,
        **extra_args,
    ) -> None:
        """Set an example sequence useful for testing and export purposes without cache history."""
@@ -652,8 +571,6 @@ def _store_extra_arg(
            else:
                tnsr_like = tnsr_like[0]
            self._extra_args[name] = tnsr_like.to(self.device, non_blocking=True)
-        elif self.use_strict_args:
-            self._extra_args[name] = self._extra_none_inputs[name]
        else:
            self._extra_args[name] = None
 
@@ -736,15 +653,8 @@ def nest_sequences(
 
        ### UPDATE EXTRA INPUTS ####################################################################
        self._extra_args = {}
-        # in strict argument mode, we only accept registered extra arguments
-        if self.use_strict_args:
-            for name in self._extra_none_inputs.keys():
-                self._store_extra_arg(name, extra_args.pop(name, None))
-            assert not extra_args, f"Extra arguments {extra_args.keys()} not found"
-        # otherwise, we simply pass in all extra arguments
-        else:
-            for key, value in extra_args.items():
-                self._store_extra_arg(key, value)
+        for key, value in extra_args.items():
+            self._store_extra_arg(key, value)
 
    @nvtx_range("ad_rescatter_input_ids")
    def rescatter_input_ids(
@@ -778,31 +688,6 @@ def unnest_sequences(self, t_nested: torch.Tensor) -> List[torch.Tensor]:
        t_squeezed = t_nested.squeeze(1) if self.is_generate else t_nested.squeeze(0)
        return list(torch.split(t_squeezed, self.seq_len))
 
-    def add_extra_arg(
-        self,
-        name: str,
-        none_input: torch.Tensor,
-        dynamic_shape_callback: Optional[DynamicShapeCallback] = None,
-    ) -> None:
-        """Add an extra argument to the sequence info object.
-
-        Args:
-            name: The name of the extra argument.
-            none_input: None input value of the extra argument.
-            dynamic_shape_callback: The callback to get the dynamic shape of the extra argument.
-
-        Note that the extra argument is expected to be a tensor.
-        """
-        assert name not in self._named_args().keys(), f"Extra argument {name} already exists"
-
-        self._extra_args[name] = none_input.to(self.device)
-        self._extra_none_inputs[name] = self._extra_args[name]
-
-        if dynamic_shape_callback is None:
-            self._extra_dynamic_shapes_callbacks[name] = lambda: {}
-        else:
-            self._extra_dynamic_shapes_callbacks[name] = dynamic_shape_callback
-
 
 class MHACallable(Protocol):
    def __call__(
@@ -814,7 +699,6 @@ def __call__(
 class PrepareMetadataCallable(Protocol):
    def __call__(
        self,
-        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        seq_len: torch.Tensor,
        input_pos: torch.Tensor,
@@ -901,7 +785,6 @@ def get_prepare_metadata_op(cls) -> Tuple[PrepareMetadataCallable, int]:
 
        ```
        def prepare_metadata(
-            input_ids: torch.Tensor,
            position_ids: torch.Tensor,
            seq_len: torch.Tensor,
            input_pos: torch.Tensor,
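
With `add_extra_arg` and the strict-argument mode removed, any keyword handed to `nest_sequences` (for example a VLM's `pixel_values`) is simply stored and forwarded, and unset optional inputs stay `None` instead of being replaced by registered placeholder tensors. A self-contained toy version of that pass-through, assuming tensors are moved to the target device as in `_store_extra_arg`:

```
from typing import Dict, Optional

import torch


class ExtraArgStore:
    """Toy stand-in for the simplified extra-argument handling in SequenceInfo."""

    def __init__(self, device: str = "cpu") -> None:
        self.device = torch.device(device)
        self._extra_args: Dict[str, Optional[torch.Tensor]] = {}

    def nest_sequences(self, **extra_args) -> None:
        # every provided keyword is accepted; no prior registration required
        self._extra_args = {}
        for key, value in extra_args.items():
            if isinstance(value, torch.Tensor):
                self._extra_args[key] = value.to(self.device, non_blocking=True)
            else:
                self._extra_args[key] = None  # absent optional inputs stay None


store = ExtraArgStore()
store.nest_sequences(pixel_values=torch.randn(1, 3, 224, 224), image_sizes=None)
print(list(store._extra_args))  # ['pixel_values', 'image_sizes']
```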

tensorrt_llm/_torch/auto_deploy/custom_ops/cuda_backend_causal_conv.py

Lines changed: 3 additions & 4 deletions
@@ -54,7 +54,6 @@ def _build_conv_state_from_sequence(input_bt_c: torch.Tensor, kernel_size: int)
 # ---------------------------------------------------------------
 @torch.library.custom_op("auto_deploy::cuda_causal_conv_prepare_metadata", mutates_args=())
 def cuda_causal_conv_prepare_metadata(
-    input_ids: torch.Tensor,
     position_ids: torch.Tensor,
     seq_len: torch.Tensor,
     input_pos: torch.Tensor,
@@ -67,7 +66,7 @@ def cuda_causal_conv_prepare_metadata(
 
    Returns a tuple of (seq_len_sanitized, seq_start, slot_idx_sanitized).
    """
-    seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(input_ids, seq_len)
+    seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(position_ids, seq_len)
    num_seq = len(seq_len_sanitized)
 
    seq_start = torch.zeros_like(seq_len_sanitized)
@@ -81,9 +80,9 @@ def cuda_causal_conv_prepare_metadata(
 
 @cuda_causal_conv_prepare_metadata.register_fake
 def cuda_causal_conv_prepare_metadata_fake(
-    input_ids, position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size
+    position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size
 ):
-    seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(input_ids, seq_len)
+    seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(position_ids, seq_len)
    num_seq = len(seq_len_sanitized)
    return (
        torch.empty_like(seq_len_sanitized),
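
Any backend that registers its own `prepare_metadata` custom op follows the same convention after this change: the signature leads with `position_ids` and no longer takes `input_ids`. A toy registration in that style, where the op name, the reduced argument list, and the returned metadata are made up for illustration (a real fake kernel would mirror the sanitized shapes more carefully):

```
import torch
from torch import Tensor


@torch.library.custom_op("toy_backend::prepare_metadata", mutates_args=())
def toy_prepare_metadata(position_ids: Tensor, seq_len: Tensor, input_pos: Tensor) -> Tensor:
    # number of live sequences: non-zero entries of the zero-padded seq_len buffer
    if position_ids.shape[1] > 1:
        num_seq = int(torch.sum(seq_len > 0))
    else:
        num_seq = position_ids.shape[0]
    return seq_len[:num_seq].clone()


@toy_prepare_metadata.register_fake
def toy_prepare_metadata_fake(position_ids, seq_len, input_pos):
    # meta/fake kernel: only shapes and dtypes matter here
    return torch.empty_like(seq_len)
```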
