Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions tensorrt_llm/_torch/auto_deploy/custom_ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,9 @@
for _, module_name, is_pkg in pkgutil.iter_modules(__path__):
__all__.append(module_name)
importlib.import_module(f"{__name__}.{module_name}")

# Recursively import subpackages and modules so their side-effects (e.g.,
# op registrations) are applied even when nested in subdirectories.
for _, full_name, _ in pkgutil.walk_packages(__path__, prefix=f"{__name__}."):
__all__.append(full_name)
importlib.import_module(full_name)
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
from tensorrt_llm._torch.modules.mamba import PAD_SLOT_ID
from tensorrt_llm._torch.modules.mamba.causal_conv1d import causal_conv1d_fn, causal_conv1d_update

from ..utils.node_utils import extract_op_args
from .attention_interface import (
from ...utils.node_utils import extract_op_args
from ..attention_interface import (
AttentionDescriptor,
AttentionLayout,
AttentionRegistry,
Expand Down Expand Up @@ -74,8 +74,9 @@ def cuda_causal_conv_prepare_metadata(
seq_start[1:] = torch.cumsum(seq_len_sanitized[:-1], 0)

slot_idx_sanitized = slot_idx[:num_seq].clone().to(torch.long)

return (seq_len_sanitized, seq_start, slot_idx_sanitized)
# This is only used during prefill to determine if we should use the initial states from the cache.
use_initial_states = input_pos > 0
return (seq_len_sanitized, seq_start, slot_idx_sanitized, use_initial_states)


@cuda_causal_conv_prepare_metadata.register_fake
Expand All @@ -88,6 +89,7 @@ def cuda_causal_conv_prepare_metadata_fake(
torch.empty_like(seq_len_sanitized),
torch.empty_like(seq_len_sanitized),
torch.empty(num_seq, dtype=torch.long, device=slot_idx.device),
torch.empty(num_seq, dtype=torch.bool, device=slot_idx.device),
)


Expand All @@ -101,6 +103,7 @@ def _cuda_cached_causal_conv1d(
seq_len: torch.Tensor, # [num_seq]
seq_start: torch.Tensor, # [num_seq]
slot_idx: torch.Tensor, # [num_seq]
use_initial_states: torch.Tensor, # [num_seq]
# CACHES
conv_state_cache: torch.Tensor, # [max_batch_size, c_in, k-1]
# CONSTANTS
Expand Down Expand Up @@ -161,7 +164,7 @@ def _cuda_cached_causal_conv1d(
dim=0,
).contiguous()
cache_indices = slot_idx[:num_prefill].to(torch.int32).contiguous()
has_initial_state = torch.zeros(num_prefill, dtype=torch.bool, device=input.device)
has_initial_state = use_initial_states[:num_prefill].to(torch.bool)

# Run varlen conv; updates conv_state_cache in-place per cache_indices
y_varlen = causal_conv1d_fn(
Expand Down Expand Up @@ -215,6 +218,7 @@ def _cuda_cached_causal_conv1d_fake(
seq_len: torch.Tensor,
seq_start: torch.Tensor,
slot_idx: torch.Tensor,
use_initial_states: torch.Tensor, # [num_seq]
# CACHES
conv_state_cache: torch.Tensor,
# CONSTANTS
Expand Down Expand Up @@ -256,8 +260,8 @@ def get_cached_attention_op(cls) -> MHACallable:

@classmethod
def get_prepare_metadata_op(cls) -> Tuple[PrepareMetadataCallable, int]:
# Returns (seq_len, seq_start, slot_idx)
return torch.ops.auto_deploy.cuda_causal_conv_prepare_metadata, 3
# Returns (seq_len, seq_start, slot_idx, use_initial_states)
return torch.ops.auto_deploy.cuda_causal_conv_prepare_metadata, 4

@classmethod
def get_cache_initializers(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
from torch._ops import OpOverloadPacket
from torch.fx import Node

from ..utils.node_utils import extract_op_args
from .attention_interface import (
from ...utils.node_utils import extract_op_args
from ..attention_interface import (
AttentionDescriptor,
AttentionLayout,
AttentionRegistry,
Expand Down Expand Up @@ -317,6 +317,8 @@ def get_cached_attention_op(cls) -> MHACallable:

@classmethod
def get_prepare_metadata_op(cls) -> Tuple[PrepareMetadataCallable, int]:
# TODO(https://github.com/NVIDIA/TensorRT-LLM/issues/8170): update torch
# reference implementation to support chunked prefill.
# Returns (seq_len, seq_start, slot_idx)
return torch.ops.auto_deploy.torch_causal_conv_prepare_metadata, 3

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
from torch._ops import OpOverloadPacket
from torch.fx import Node

from ..utils.node_utils import extract_op_args
from .attention_interface import (
from ...utils.node_utils import extract_op_args
from ..attention_interface import (
AttentionDescriptor,
AttentionLayout,
AttentionRegistry,
Expand All @@ -25,10 +25,10 @@
PrepareMetadataCallable,
SequenceInfo,
)
from .torch_mamba import _torch_ssm_transform_prefill
from .torch_mamba import _torch_ssm_prefill


def _torch_cached_ssm_transform_decode(
def _torch_cached_ssm_decode(
hidden_states: torch.Tensor,
A: torch.Tensor,
B: torch.Tensor,
Expand Down Expand Up @@ -136,7 +136,11 @@ def _torch_ssm_prepare_metadata(
# Truncate slot indices to match active sequences
slot_idx_sanitized = slot_idx[:num_seq].clone().to(torch.long)

return (seq_len_sanitized, seq_start, slot_idx_sanitized)
# Determine whether to use initial states.
# This is only used during prefill to determine if we should use the initial states from the cache.
use_initial_states = input_pos > 0

return (seq_len_sanitized, seq_start, slot_idx_sanitized, use_initial_states)


@_torch_ssm_prepare_metadata.register_fake
Expand All @@ -150,11 +154,12 @@ def _torch_ssm_prepare_metadata_fake(
torch.empty_like(seq_len_sanitized),
torch.empty_like(seq_len_sanitized),
torch.empty(num_seq, dtype=torch.long, device=slot_idx.device),
torch.empty(num_seq, dtype=torch.bool, device=slot_idx.device),
)


@torch.library.custom_op("auto_deploy::torch_cached_ssm_transform", mutates_args={})
def _torch_cached_ssm_transform(
@torch.library.custom_op("auto_deploy::torch_cached_ssm", mutates_args={})
def _torch_cached_ssm(
# INPUTS (dense but may be flattened across sequences)
hidden_states: torch.Tensor, # [b, s, num_heads, head_dim]
A: torch.Tensor, # [num_heads]
Expand Down Expand Up @@ -188,7 +193,7 @@ def _torch_cached_ssm_transform(
slot_idx_long = slot_idx.to(torch.long)
ssm_batch = ssm_state_cache.index_select(dim=0, index=slot_idx_long)

y, updated_state = _torch_cached_ssm_transform_decode(
y, updated_state = _torch_cached_ssm_decode(
hidden_states,
A,
B,
Expand Down Expand Up @@ -244,7 +249,7 @@ def _torch_cached_ssm_transform(
dt_seq = dt_flat.index_select(0, idx_i).unsqueeze(0)

# Run prefill and obtain final SSM state for this sequence
y_seq, ssm_state_i = _torch_ssm_transform_prefill(
y_seq, ssm_state_i = _torch_ssm_prefill(
hs_seq, A, B_seq, C_seq, D, dt_seq, dt_bias, time_step_limit, chunk_size
)

Expand All @@ -258,8 +263,8 @@ def _torch_cached_ssm_transform(
return y


@_torch_cached_ssm_transform.register_fake
def _torch_cached_ssm_transform_fake(
@_torch_cached_ssm.register_fake
def _torch_cached_ssm_fake(
# INPUTS
hidden_states: torch.Tensor,
A: torch.Tensor,
Expand Down Expand Up @@ -304,16 +309,16 @@ def get_num_qkv_args(cls) -> int:

@classmethod
def get_source_attention_op(cls) -> OpOverloadPacket:
return torch.ops.auto_deploy.torch_ssm_transform
return torch.ops.auto_deploy.torch_ssm_prefill

@classmethod
def get_cached_attention_op(cls) -> MHACallable:
return torch.ops.auto_deploy.torch_cached_ssm_transform
return torch.ops.auto_deploy.torch_cached_ssm

@classmethod
def get_prepare_metadata_op(cls) -> Tuple[PrepareMetadataCallable, int]:
# Returns (seq_len, seq_start, slot_idx)
return torch.ops.auto_deploy.torch_ssm_prepare_metadata, 3
# Returns (seq_len, seq_start, slot_idx, use_initial_states)
return torch.ops.auto_deploy.torch_ssm_prepare_metadata, 4

@classmethod
def get_cache_initializers(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def _segment_sum(input_tensor):
return tensor_segsum


def _torch_ssm_transform_prefill(
def _torch_ssm_prefill(
hidden_states: torch.Tensor,
A: torch.Tensor,
B: torch.Tensor,
Expand Down Expand Up @@ -162,8 +162,8 @@ def _torch_ssm_transform_prefill(
return y, ssm_state


@torch.library.custom_op("auto_deploy::torch_ssm_transform", mutates_args={})
def _torch_ssm_transform(
@torch.library.custom_op("auto_deploy::torch_ssm", mutates_args={})
def _torch_ssm(
hidden_states: torch.Tensor,
A: torch.Tensor,
B: torch.Tensor,
Expand All @@ -176,14 +176,12 @@ def _torch_ssm_transform(
], # NOTE: `torch` custom ops do not like `Tuple` inputs. Using `List` is the suggested WAR.
chunk_size: int,
) -> torch.Tensor:
y, _ = _torch_ssm_transform_prefill(
hidden_states, A, B, C, D, dt, dt_bias, time_step_limit, chunk_size
)
y, _ = _torch_ssm_prefill(hidden_states, A, B, C, D, dt, dt_bias, time_step_limit, chunk_size)
return y


@_torch_ssm_transform.register_fake
def _torch_ssm_transform_meta(
@_torch_ssm.register_fake
def _torch_ssm_meta(
hidden_states: torch.Tensor,
A: torch.Tensor,
B: torch.Tensor,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@
from torch.fx import Node

# Triton kernels
from tensorrt_llm._torch.modules.mamba.mamba2_metadata import cu_seqlens_to_chunk_indices_offsets
from tensorrt_llm._torch.modules.mamba.selective_state_update import selective_state_update
from tensorrt_llm._torch.modules.mamba.ssd_combined import mamba_chunk_scan_combined

from ..utils.node_utils import extract_op_args
from .attention_interface import (
from ...utils.node_utils import extract_op_args
from ..attention_interface import (
AttentionDescriptor,
AttentionLayout,
AttentionRegistry,
Expand All @@ -23,8 +24,8 @@
)


@torch.library.custom_op("auto_deploy::triton_cached_ssm_transform", mutates_args={})
def _triton_cached_ssm_transform(
@torch.library.custom_op("auto_deploy::triton_cached_ssm", mutates_args={})
def _triton_cached_ssm(
# INPUTS (dense but may be flattened across sequences)
hidden_states: torch.Tensor, # [b, s, num_heads, head_dim]
A: torch.Tensor, # [num_heads]
Expand All @@ -37,6 +38,7 @@ def _triton_cached_ssm_transform(
seq_len: torch.Tensor, # [num_seq]
seq_start: torch.Tensor, # [num_seq]
slot_idx: torch.Tensor, # [num_seq]
use_initial_states: torch.Tensor, # [num_seq]
# CACHES
ssm_state_cache: torch.Tensor, # [max_batch_size, num_heads, head_dim, ssm_state_size]
# CONSTANTS
Expand All @@ -51,7 +53,6 @@ def _triton_cached_ssm_transform(
"""
b, s = hidden_states.shape[:2]
num_seq = seq_len.shape[0]

# Flatten tokens for indexing/scatter
bs = b * s
device = hidden_states.device
Expand Down Expand Up @@ -96,6 +97,16 @@ def _triton_cached_ssm_transform(
seq_ids = torch.arange(num_prefill, device=device, dtype=torch.int32)
seq_idx_prefill = torch.repeat_interleave(seq_ids, seq_len_prefill).view(1, -1)

initial_states = chunk_indices = chunk_offsets = None
if torch.any(use_initial_states[:num_prefill]):
initial_states = torch.where(
use_initial_states[:num_prefill, None, None, None],
ssm_state_cache[slot_idx[:num_prefill]],
0,
)
chunk_indices, chunk_offsets = cu_seqlens_to_chunk_indices_offsets(
cu_seqlens, chunk_size
)
y_prefill, varlen_states = mamba_chunk_scan_combined(
hs_prefill,
dt_prefill,
Expand All @@ -106,10 +117,10 @@ def _triton_cached_ssm_transform(
D=D,
z=None,
dt_bias=dt_bias,
initial_states=None,
initial_states=initial_states,
seq_idx=seq_idx_prefill,
chunk_indices=None,
chunk_offsets=None,
chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets,
cu_seqlens=cu_seqlens,
dt_softplus=True,
dt_limit=(time_step_limit[0], time_step_limit[1]),
Expand Down Expand Up @@ -159,8 +170,8 @@ def _triton_cached_ssm_transform(
return y


@_triton_cached_ssm_transform.register_fake
def _triton_cached_ssm_transform_fake(
@_triton_cached_ssm.register_fake
def _triton_cached_ssm_fake(
# INPUTS (dense but may be flattened across sequences)
hidden_states: torch.Tensor, # [b, s, num_heads, head_dim]
A: torch.Tensor, # [num_heads]
Expand All @@ -173,6 +184,7 @@ def _triton_cached_ssm_transform_fake(
seq_len: torch.Tensor, # [num_seq]
seq_start: torch.Tensor, # [num_seq]
slot_idx: torch.Tensor, # [num_seq]
use_initial_states: torch.Tensor, # [num_seq]
# CACHES
ssm_state_cache: torch.Tensor, # [max_batch_size, num_heads, head_dim, ssm_state_size]
# CONSTANTS
Expand Down Expand Up @@ -209,16 +221,16 @@ def get_num_qkv_args(cls) -> int:
@classmethod
def get_source_attention_op(cls) -> OpOverloadPacket:
# Keep source op unchanged (used for uncached pre-export)
return torch.ops.auto_deploy.torch_ssm_transform
return torch.ops.auto_deploy.torch_ssm

@classmethod
def get_cached_attention_op(cls) -> MHACallable:
return torch.ops.auto_deploy.triton_cached_ssm_transform
return torch.ops.auto_deploy.triton_cached_ssm

@classmethod
def get_prepare_metadata_op(cls) -> Tuple[PrepareMetadataCallable, int]:
# Returns (seq_len, seq_start, slot_idx)
return torch.ops.auto_deploy.torch_ssm_prepare_metadata, 3
# Returns (seq_len, seq_start, slot_idx, use_initial_states)
return torch.ops.auto_deploy.torch_ssm_prepare_metadata, 4

@classmethod
def get_cache_initializers(
Expand Down
4 changes: 2 additions & 2 deletions tensorrt_llm/_torch/auto_deploy/models/patches/bamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def _bamba_mixer_torch_forward(

if use_caching:
# Use new flattened cached op for both cache updates and outputs
y = torch.ops.auto_deploy.torch_cached_ssm_transform(
y = torch.ops.auto_deploy.torch_cached_ssm(
# INPUTS
hidden_states=hidden_states.view(batch_size, seq_len, -1, self.head_dim),
A=A,
Expand All @@ -120,7 +120,7 @@ def _bamba_mixer_torch_forward(
chunk_size=self.chunk_size,
)
else:
y = torch.ops.auto_deploy.torch_ssm_transform(
y = torch.ops.auto_deploy.torch_ssm(
hidden_states=hidden_states.view(batch_size, seq_len, -1, self.head_dim),
A=A,
B=B.view(batch_size, seq_len, -1, self.ssm_state_size),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -326,5 +326,4 @@ def _apply_to_full_model(
info = TransformInfo(
skipped=False, num_matches=num_caches, is_clean=True, has_valid_shapes=True
)

return mod, info
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def get_default_kwargs(self, enable_chunked_prefill=False):
"free_mem_ratio": 0.7
},
"compile_model": {
"backend": "torch-opt",
"backend": "torch-cudagraph",
"cuda_graph_batch_sizes": [1, 2, 4, 8, 16, 32, 64, 128],
},
},
Expand All @@ -140,9 +140,6 @@ def get_default_sampling_params(self):
@pytest.mark.skip_less_device_memory(32000)
@pytest.mark.parametrize("enable_chunked_prefill", [False, True])
def test_auto_dtype(self, enable_chunked_prefill):
if enable_chunked_prefill:
pytest.skip(
"see https://github.com/NVIDIA/TensorRT-LLM/issues/8272")
kwargs = self.get_default_kwargs(enable_chunked_prefill)
sampling_params = self.get_default_sampling_params()
with AutoDeployLLM(model=self.MODEL_PATH,
Expand Down
Loading