Commit dcfd3ef

Authored by: lucaslie, 2ez4bz, nvchenghaoz, Fridah-nv, suyoggupta
[#4593][feat] AutoDeploy: Linear Attention Support (SSM + causal_conv + Bamba + Nemotron-H) (#8068)
Signed-off-by: William Zhang <[email protected]>
Signed-off-by: Lucas Liebenwein <[email protected]>
Signed-off-by: Chenghao Zhang <[email protected]>
Signed-off-by: Frida Hou <[email protected]>
Signed-off-by: Suyog Gupta <[email protected]>
Co-authored-by: William Zhang <[email protected]>
Co-authored-by: Chenghao Zhang <[email protected]>
Co-authored-by: Frida Hou <[email protected]>
Co-authored-by: Suyog Gupta <[email protected]>
1 parent 62010c0 commit dcfd3ef

34 files changed: +3094, -39 lines

tensorrt_llm/_torch/auto_deploy/config/default.yaml

Lines changed: 6 additions & 0 deletions
```diff
@@ -127,6 +127,12 @@ transforms:
   insert_cached_mla_attention:
     stage: cache_init
     attn_backend: MultiHeadLatentAttention
+  insert_cached_ssm_attention:
+    stage: cache_init
+    attn_backend: triton_ssm
+  insert_cached_causal_conv:
+    stage: cache_init
+    attn_backend: cuda_causal_conv
   initialize_cache:
     stage: cache_init
   resize_kv_cache:
```
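
The two new entries wire the SSM and causal-conv cache transforms to their respective backends at the `cache_init` stage. A minimal sketch (assuming PyYAML is available; the fragment mirrors only the hunk above, not the full default.yaml) of what the parsed entries look like:

```python
# Minimal sketch: parse a config fragment shaped like the hunk above and list
# which cache_init transforms request which attention backend.
import yaml

fragment = """
transforms:
  insert_cached_ssm_attention:
    stage: cache_init
    attn_backend: triton_ssm
  insert_cached_causal_conv:
    stage: cache_init
    attn_backend: cuda_causal_conv
"""

cfg = yaml.safe_load(fragment)
for name, opts in cfg["transforms"].items():
    if opts.get("stage") == "cache_init":
        print(f"{name}: backend={opts.get('attn_backend')}")
```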
Lines changed: 8 additions & 18 deletions

```diff
@@ -1,20 +1,10 @@
 """Custom ops and make sure they are all registered."""
 
-from ._triton_attention_internal import *
-from .dist import *
-from .flashinfer_attention import *
-from .flashinfer_rope import *
-from .linear import *
-from .mla import *
-from .mxfp4_moe import *
-from .quant import *
-from .rms_norm import *
-from .torch_attention import *
-from .torch_backend_attention import *
-from .torch_moe import *
-from .torch_quant import *
-from .torch_rope import *
-from .torch_router import *
-from .triton_attention import *
-from .triton_rope import *
-from .trtllm_moe import *
+import importlib
+import pkgutil
+
+__all__ = []
+
+for _, module_name, is_pkg in pkgutil.iter_modules(__path__):
+    __all__.append(module_name)
+    importlib.import_module(f"{__name__}.{module_name}")
```
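
The rewritten package init replaces the hand-maintained import list with auto-discovery, so any op module dropped into the package (such as the new SSM and causal-conv ops) is imported and registered automatically. A self-contained sketch of the same stdlib idiom, run against the `json` package purely as a stand-in:

```python
# Standalone sketch of the auto-discovery idiom above: import every top-level module
# in a package for its side effects (here, registration of custom ops).
import importlib
import pkgutil

import json  # any package works; json is only a stand-in for the custom_ops package

discovered = []
for _, module_name, _is_pkg in pkgutil.iter_modules(json.__path__):
    discovered.append(module_name)
    importlib.import_module(f"{json.__name__}.{module_name}")

print(discovered)  # e.g. ['decoder', 'encoder', 'scanner', 'tool']
```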

tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py

Lines changed: 23 additions & 3 deletions
```diff
@@ -72,6 +72,8 @@ class SequenceInfo:
     - pages_per_seq: [ps_0, ps_1, ..., ps_{b-1}] where ps_i is the number of pages allocated for
       sequence i. Note that, for example, cache_loc[p_0:p_1] will correspond to the pages associated
       with sequence 1 in the batch.
+    - slot_idx: [s_0, s_1, ..., s_{b-1}]
+      Corresponds to the slot index of each sequence in the batch.
 
     ################################################################################################
```
```diff
@@ -134,7 +136,8 @@ def __init__(
         self._num_pages = max(
             self.max_batch_size,
             (self.max_num_tokens) // self.page_size  # floored number of pages
-            + (self.max_num_tokens % self.page_size > 0) * self.max_batch_size,  # +1 per sequence
+            + (self.max_num_tokens / self.max_batch_size % self.page_size > 0)  # check for overflow
+            * self.max_batch_size,  # +1 page per sequence if overflow is required
         )
         # sanity check
         assert self.num_pages >= self.max_batch_size, "num_pages can't be less than max_batch_size"
```
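
The revised formula now checks for per-sequence overflow against tokens-per-sequence rather than against the total token count. A rough sanity check of the arithmetic (the numbers are illustrative, not taken from the PR):

```python
# Rough sanity check of the revised page-count formula above (values are made up).
def num_pages(max_num_tokens: int, max_batch_size: int, page_size: int) -> int:
    return max(
        max_batch_size,
        max_num_tokens // page_size  # floored number of pages
        + (max_num_tokens / max_batch_size % page_size > 0)  # per-sequence overflow?
        * max_batch_size,  # +1 page per sequence if overflow is required
    )

print(num_pages(8192, 16, 64))  # 128: 512 tokens/seq fills pages exactly, no extra pages
print(num_pages(1000, 8, 64))   # 23: 125 tokens/seq leaves a partial page, +1 page per sequence
```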
```diff
@@ -164,6 +167,7 @@ def __init__(
             "input_pos": torch.empty(self.max_batch_size, dtype=torch.int),
             "cache_loc": torch.empty(self.num_pages, dtype=torch.int),
             "pages_per_seq": torch.empty(self.max_batch_size, dtype=torch.int),
+            "slot_idx": torch.empty(self.max_batch_size, dtype=torch.int),
             # OTHER FIELDS WHERE WE NEED EFFICIENT HOST<>DEVICE TRANSFER
             "_gather_idx": torch.empty(self.max_num_tokens, dtype=torch.int),
         }
```
```diff
@@ -172,7 +176,8 @@ def __init__(
         }
         # NOTE: order of keys is relevant here!
         self._uncached_arg_names = ("input_ids", "position_ids")
-        self._cached_arg_names = ("seq_len", "input_pos", "cache_loc", "pages_per_seq")
+        self._cached_arg_names = ("seq_len", "input_pos", "cache_loc", "pages_per_seq", "slot_idx")
+        self._cached_constants = ("page_size",)
         ############################################################################################
 
         # EXTRA TENSOR FIELDS ######################################################################
```
```diff
@@ -296,7 +301,7 @@ def const_args_for_prepare_metadata(self) -> Tuple:
         ``insert_cached_attention`` to extract the constant arguments and add them to the
         ``prepare_metadata`` node/op.
         """
-        return (self.page_size,)
+        return tuple(getattr(self, k) for k in self._cached_constants)
 
     @property
     def named_dynamic_shapes(self) -> Dict[str, Dict[str, Dim]]:
```
```diff
@@ -311,6 +316,7 @@ def named_dynamic_shapes(self) -> Dict[str, Dict[str, Dim]]:
         if self.max_batch_size > 1:
             bs_seq_len_shape[0] = Dim("batch_size", max=self.max_batch_size)
         bs_seq_len_shape[1] = Dim("seq_len", max=self.max_seq_len)
+        # bs_seq_len_shape[1] = Dim.AUTO
         self._dynamic_shapes = {k: bs_seq_len_shape for k in self._uncached_arg_names}
         # cached args are static
         self._dynamic_shapes.update({k: {} for k in self._cached_arg_names})
```
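
These `Dim` objects feed the `dynamic_shapes` mapping consumed by `torch.export`: dynamic batch/seq dims for the uncached args, empty (static) specs for the cached args. A toy, self-contained sketch of the same pattern (module, names, and sizes are illustrative, not AutoDeploy code):

```python
# Illustrative toy export: uncached args get dynamic batch/seq dims, cached args stay static.
import torch
from torch.export import Dim, export


class Toy(torch.nn.Module):
    def forward(self, input_ids: torch.Tensor, seq_len: torch.Tensor) -> torch.Tensor:
        # seq_len is only reduced, so its static shape places no constraint on input_ids
        return input_ids * 2 + seq_len.sum()


bs_seq_len_shape = {0: Dim("batch_size", max=8), 1: Dim("seq_len", max=128)}
dynamic_shapes = {"input_ids": bs_seq_len_shape, "seq_len": {}}  # cached args are static

ep = export(
    Toy(),
    (torch.zeros(2, 16, dtype=torch.long), torch.zeros(4, dtype=torch.long)),
    dynamic_shapes=dynamic_shapes,
)
print(ep)
```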
```diff
@@ -522,11 +528,15 @@ def set_example_sequence(
         cache_loc = list(range(sum(pages_per_seq)))
         page_assignments = self._get_page_assignments(cache_loc, pages_per_seq)
 
+        # vanilla slot indices
+        slot_idx = list(range(len(input_ids)))
+
         self.nest_sequences(
             input_ids,
             position_ids,  # will be auto-inferred if None
             input_pos=0,  # no cache history
             page_assignments=page_assignments,  # vanilla page assignments
+            slot_idx=slot_idx,  # vanilla slot indices
             **extra_args,
         )
 
```
```diff
@@ -613,6 +623,7 @@ def nest_sequences(
         position_ids: Optional[Sequence[Sequence[int]]] = None,
         input_pos: Optional[Union[Sequence[int], int]] = None,
         page_assignments: Optional[Sequence[Sequence[int]]] = None,
+        slot_idx: Optional[Sequence[int]] = None,
         **extra_args: Dict[str, Union[torch.Tensor, Sequence[torch.Tensor]]],
     ) -> None:
         """Create and store sequence information for the next forward pass.
```
```diff
@@ -622,6 +633,7 @@ def nest_sequences(
             position_ids: List of sequences of position_ids for each token.
             input_pos: Absolute starting position in the cache for each sequence.
             page_assignments: List of sequences of page assignments for each sequence.
+            slot_idx: List of slot indices for each sequence.
             extra_args: Extra arguments to be stored in the interface.
 
         This i/f will ensure that all sequence info args are updated accordingly.
```
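
For orientation, the "vanilla" defaults from `set_example_sequence` line up with the new argument as follows; the values are placeholders and the commented call only sketches intended usage:

```python
# Placeholder values mirroring the "vanilla" defaults built in set_example_sequence above.
input_ids = [[1, 2, 3], [4, 5]]                # two example sequences (token ids are made up)
pages_per_seq = [1, 1]                         # one page per sequence in this toy case
cache_loc = list(range(sum(pages_per_seq)))    # [0, 1]
slot_idx = list(range(len(input_ids)))         # [0, 1]: sequence i occupies state slot i

# With a SequenceInfo instance `seq_info`, the call would thread slot_idx through, e.g.:
# seq_info.nest_sequences(input_ids, input_pos=0, page_assignments=..., slot_idx=slot_idx)
print(cache_loc, slot_idx)
```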
```diff
@@ -648,6 +660,10 @@ def nest_sequences(
         self._store_arg("cache_loc", cache_loc, reset=True)
         self._store_arg("pages_per_seq", pages_per_seq, reset=True)
 
+        # check for updated slot_idx
+        if slot_idx is not None:
+            self._store_arg("slot_idx", slot_idx)
+
         ### UPDATE MAIN INPUTS #####################################################################
         # set new input_ids and make sure to flatten it
         self._store_arg("input_ids", self._flatten(input_ids))
```
```diff
@@ -749,6 +765,7 @@ def __call__(
         input_pos: torch.Tensor,
         cache_loc: torch.Tensor,
         pages_per_seq: torch.Tensor,
+        slot_idx: torch.Tensor,
         page_size: int,
     ) -> List[torch.Tensor]: ...
 
```
````diff
@@ -834,6 +851,9 @@ def prepare_metadata(
             seq_len: torch.Tensor,
             input_pos: torch.Tensor,
             cache_loc: torch.Tensor,
+            pages_per_seq: torch.Tensor,
+            slot_idx: torch.Tensor,
+            page_size: int,
         ) -> List[torch.Tensor]: ...
         ```
         The metadata should contain all necessary global information required for the underlying
````
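
To make the documented interface concrete, here is a hedged stub matching the argument names visible in this excerpt (the real op may take further arguments not shown here, and a real backend would compute backend-specific metadata rather than plain slices):

```python
# Hedged sketch only; not the library implementation. Assumes seq_len is zero-padded
# beyond the active batch and simply slices the host-side metadata down to that batch.
from typing import List

import torch


def prepare_metadata_sketch(
    seq_len: torch.Tensor,
    input_pos: torch.Tensor,
    cache_loc: torch.Tensor,
    pages_per_seq: torch.Tensor,
    slot_idx: torch.Tensor,
    page_size: int,
) -> List[torch.Tensor]:
    num_seq = int((seq_len > 0).sum())  # assumption: unused entries are zero-padded
    return [seq_len[:num_seq], input_pos[:num_seq], slot_idx[:num_seq]]


metadata = prepare_metadata_sketch(
    seq_len=torch.tensor([3, 2, 0, 0]),
    input_pos=torch.tensor([0, 0, 0, 0]),
    cache_loc=torch.arange(4),
    pages_per_seq=torch.tensor([1, 1, 0, 0]),
    slot_idx=torch.tensor([0, 1, 0, 0]),
    page_size=64,
)
print(metadata)  # [tensor([3, 2]), tensor([0, 0]), tensor([0, 1])]
```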
