Skip to content
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions tensorrt_llm/_torch/auto_deploy/custom_ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,9 @@
for _, module_name, is_pkg in pkgutil.iter_modules(__path__):
__all__.append(module_name)
importlib.import_module(f"{__name__}.{module_name}")

# Recursively import subpackages and modules so their side-effects (e.g.,
# op registrations) are applied even when nested in subdirectories.
for _, full_name, _ in pkgutil.walk_packages(__path__, prefix=f"{__name__}."):
__all__.append(full_name)
importlib.import_module(full_name)
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
from tensorrt_llm._torch.modules.mamba import PAD_SLOT_ID
from tensorrt_llm._torch.modules.mamba.causal_conv1d import causal_conv1d_fn, causal_conv1d_update

from ..utils.node_utils import extract_op_args
from .attention_interface import (
from ...utils.node_utils import extract_op_args
from ..attention_interface import (
AttentionDescriptor,
AttentionLayout,
AttentionRegistry,
Expand Down Expand Up @@ -74,8 +74,9 @@ def cuda_causal_conv_prepare_metadata(
seq_start[1:] = torch.cumsum(seq_len_sanitized[:-1], 0)

slot_idx_sanitized = slot_idx[:num_seq].clone().to(torch.long)

return (seq_len_sanitized, seq_start, slot_idx_sanitized)
# This is only used during prefill to determine if we should use the initial states from the cache.
use_initial_states = input_pos > 0
return (seq_len_sanitized, seq_start, slot_idx_sanitized, use_initial_states)


@cuda_causal_conv_prepare_metadata.register_fake
Expand All @@ -88,6 +89,7 @@ def cuda_causal_conv_prepare_metadata_fake(
torch.empty_like(seq_len_sanitized),
torch.empty_like(seq_len_sanitized),
torch.empty(num_seq, dtype=torch.long, device=slot_idx.device),
torch.empty(num_seq, dtype=torch.bool, device=slot_idx.device),
)


Expand All @@ -101,6 +103,7 @@ def _cuda_cached_causal_conv1d(
seq_len: torch.Tensor, # [num_seq]
seq_start: torch.Tensor, # [num_seq]
slot_idx: torch.Tensor, # [num_seq]
use_initial_states: torch.Tensor, # [num_seq]
# CACHES
conv_state_cache: torch.Tensor, # [max_batch_size, c_in, k-1]
# CONSTANTS
Expand Down Expand Up @@ -161,7 +164,7 @@ def _cuda_cached_causal_conv1d(
dim=0,
).contiguous()
cache_indices = slot_idx[:num_prefill].to(torch.int32).contiguous()
has_initial_state = torch.zeros(num_prefill, dtype=torch.bool, device=input.device)
has_initial_state = use_initial_states[:num_prefill].to(torch.bool)

# Run varlen conv; updates conv_state_cache in-place per cache_indices
y_varlen = causal_conv1d_fn(
Expand Down Expand Up @@ -215,6 +218,7 @@ def _cuda_cached_causal_conv1d_fake(
seq_len: torch.Tensor,
seq_start: torch.Tensor,
slot_idx: torch.Tensor,
use_initial_states: torch.Tensor, # [num_seq]
# CACHES
conv_state_cache: torch.Tensor,
# CONSTANTS
Expand Down Expand Up @@ -256,8 +260,8 @@ def get_cached_attention_op(cls) -> MHACallable:

@classmethod
def get_prepare_metadata_op(cls) -> Tuple[PrepareMetadataCallable, int]:
# Returns (seq_len, seq_start, slot_idx)
return torch.ops.auto_deploy.cuda_causal_conv_prepare_metadata, 3
# Returns (seq_len, seq_start, slot_idx, use_initial_states)
return torch.ops.auto_deploy.cuda_causal_conv_prepare_metadata, 4

@classmethod
def get_cache_initializers(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
from torch._ops import OpOverloadPacket
from torch.fx import Node

from ..utils.node_utils import extract_op_args
from .attention_interface import (
from ...utils.node_utils import extract_op_args
from ..attention_interface import (
AttentionDescriptor,
AttentionLayout,
AttentionRegistry,
Expand Down Expand Up @@ -160,8 +160,8 @@ def torch_causal_conv_prepare_metadata(
seq_start[1:] = torch.cumsum(seq_len_sanitized[:-1], 0)

slot_idx_sanitized = slot_idx[:num_seq].clone().to(torch.long)

return (seq_len_sanitized, seq_start, slot_idx_sanitized)
use_initial_states = input_pos > 0
return (seq_len_sanitized, seq_start, slot_idx_sanitized, use_initial_states)


@torch_causal_conv_prepare_metadata.register_fake
Expand All @@ -174,6 +174,7 @@ def torch_causal_conv_prepare_metadata_fake(
torch.empty_like(seq_len_sanitized),
torch.empty_like(seq_len_sanitized),
torch.empty(num_seq, dtype=torch.long, device=slot_idx.device),
torch.empty(num_seq, dtype=torch.bool, device=slot_idx.device),
)


Expand All @@ -187,6 +188,7 @@ def _torch_cached_causal_conv1d(
seq_len: torch.Tensor, # [num_seq]
seq_start: torch.Tensor, # [num_seq]
slot_idx: torch.Tensor, # [num_seq]
use_initial_states: torch.Tensor, # [num_seq]
# CACHES
conv_state_cache: torch.Tensor, # [max_batch_size, c_in, k]
# CONSTANTS
Expand Down Expand Up @@ -275,6 +277,7 @@ def _torch_cached_causal_conv1d_fake(
seq_len: torch.Tensor,
seq_start: torch.Tensor,
slot_idx: torch.Tensor,
use_initial_states: torch.Tensor, # [num_seq]
# CACHES
conv_state_cache: torch.Tensor,
# CONSTANTS
Expand Down Expand Up @@ -317,8 +320,10 @@ def get_cached_attention_op(cls) -> MHACallable:

@classmethod
def get_prepare_metadata_op(cls) -> Tuple[PrepareMetadataCallable, int]:
# TODO(https://github.com/NVIDIA/TensorRT-LLM/issues/8170): update torch
# reference implementation to support chunked prefill.
# Returns (seq_len, seq_start, slot_idx)
return torch.ops.auto_deploy.torch_causal_conv_prepare_metadata, 3
return torch.ops.auto_deploy.torch_causal_conv_prepare_metadata, 4

@classmethod
def get_cache_initializers(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
from torch._ops import OpOverloadPacket
from torch.fx import Node

from ..utils.node_utils import extract_op_args
from .attention_interface import (
from ...utils.node_utils import extract_op_args
from ..attention_interface import (
AttentionDescriptor,
AttentionLayout,
AttentionRegistry,
Expand All @@ -25,10 +25,10 @@
PrepareMetadataCallable,
SequenceInfo,
)
from .torch_mamba import _torch_ssm_transform_prefill
from .torch_mamba import _torch_ssm_prefill


def _torch_cached_ssm_transform_decode(
def _torch_cached_ssm_decode(
hidden_states: torch.Tensor,
A: torch.Tensor,
B: torch.Tensor,
Expand Down Expand Up @@ -135,8 +135,10 @@ def _torch_ssm_prepare_metadata(

# Truncate slot indices to match active sequences
slot_idx_sanitized = slot_idx[:num_seq].clone().to(torch.long)

return (seq_len_sanitized, seq_start, slot_idx_sanitized)
# TODO(https://github.com/NVIDIA/TensorRT-LLM/issues/8170): update torch
# reference implementation to support chunked prefill.
use_initial_states = input_pos > 0
return (seq_len_sanitized, seq_start, slot_idx_sanitized, use_initial_states)


@_torch_ssm_prepare_metadata.register_fake
Expand All @@ -150,11 +152,12 @@ def _torch_ssm_prepare_metadata_fake(
torch.empty_like(seq_len_sanitized),
torch.empty_like(seq_len_sanitized),
torch.empty(num_seq, dtype=torch.long, device=slot_idx.device),
torch.empty(num_seq, dtype=torch.bool, device=slot_idx.device),
)


@torch.library.custom_op("auto_deploy::torch_cached_ssm_transform", mutates_args={})
def _torch_cached_ssm_transform(
@torch.library.custom_op("auto_deploy::torch_cached_ssm", mutates_args={})
def _torch_cached_ssm(
# INPUTS (dense but may be flattened across sequences)
hidden_states: torch.Tensor, # [b, s, num_heads, head_dim]
A: torch.Tensor, # [num_heads]
Expand All @@ -167,6 +170,7 @@ def _torch_cached_ssm_transform(
seq_len: torch.Tensor, # [num_seq]
seq_start: torch.Tensor, # [num_seq]
slot_idx: torch.Tensor, # [num_seq]
use_initial_states: torch.Tensor, # [num_seq]
# CACHES
ssm_state_cache: torch.Tensor, # [max_batch_size, num_heads, head_dim, ssm_state_size]
# CONSTANTS
Expand All @@ -188,7 +192,7 @@ def _torch_cached_ssm_transform(
slot_idx_long = slot_idx.to(torch.long)
ssm_batch = ssm_state_cache.index_select(dim=0, index=slot_idx_long)

y, updated_state = _torch_cached_ssm_transform_decode(
y, updated_state = _torch_cached_ssm_decode(
hidden_states,
A,
B,
Expand Down Expand Up @@ -244,7 +248,7 @@ def _torch_cached_ssm_transform(
dt_seq = dt_flat.index_select(0, idx_i).unsqueeze(0)

# Run prefill and obtain final SSM state for this sequence
y_seq, ssm_state_i = _torch_ssm_transform_prefill(
y_seq, ssm_state_i = _torch_ssm_prefill(
hs_seq, A, B_seq, C_seq, D, dt_seq, dt_bias, time_step_limit, chunk_size
)

Expand All @@ -258,8 +262,8 @@ def _torch_cached_ssm_transform(
return y


@_torch_cached_ssm_transform.register_fake
def _torch_cached_ssm_transform_fake(
@_torch_cached_ssm.register_fake
def _torch_cached_ssm_fake(
# INPUTS
hidden_states: torch.Tensor,
A: torch.Tensor,
Expand All @@ -272,6 +276,7 @@ def _torch_cached_ssm_transform_fake(
seq_len: torch.Tensor,
seq_start: torch.Tensor,
slot_idx: torch.Tensor,
use_initial_states: torch.Tensor,
# CACHES
ssm_state_cache: torch.Tensor,
# CONSTANTS
Expand Down Expand Up @@ -304,16 +309,16 @@ def get_num_qkv_args(cls) -> int:

@classmethod
def get_source_attention_op(cls) -> OpOverloadPacket:
return torch.ops.auto_deploy.torch_ssm_transform
return torch.ops.auto_deploy.torch_ssm

@classmethod
def get_cached_attention_op(cls) -> MHACallable:
return torch.ops.auto_deploy.torch_cached_ssm_transform
return torch.ops.auto_deploy.torch_cached_ssm

@classmethod
def get_prepare_metadata_op(cls) -> Tuple[PrepareMetadataCallable, int]:
# Returns (seq_len, seq_start, slot_idx)
return torch.ops.auto_deploy.torch_ssm_prepare_metadata, 3
return torch.ops.auto_deploy.torch_ssm_prepare_metadata, 4

@classmethod
def get_cache_initializers(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def _segment_sum(input_tensor):
return tensor_segsum


def _torch_ssm_transform_prefill(
def _torch_ssm_prefill(
hidden_states: torch.Tensor,
A: torch.Tensor,
B: torch.Tensor,
Expand Down Expand Up @@ -162,8 +162,8 @@ def _torch_ssm_transform_prefill(
return y, ssm_state


@torch.library.custom_op("auto_deploy::torch_ssm_transform", mutates_args={})
def _torch_ssm_transform(
@torch.library.custom_op("auto_deploy::torch_ssm", mutates_args={})
def _torch_ssm(
hidden_states: torch.Tensor,
A: torch.Tensor,
B: torch.Tensor,
Expand All @@ -176,14 +176,12 @@ def _torch_ssm_transform(
], # NOTE: `torch` custom ops do not like `Tuple` inputs. Using `List` is the suggested WAR.
chunk_size: int,
) -> torch.Tensor:
y, _ = _torch_ssm_transform_prefill(
hidden_states, A, B, C, D, dt, dt_bias, time_step_limit, chunk_size
)
y, _ = _torch_ssm_prefill(hidden_states, A, B, C, D, dt, dt_bias, time_step_limit, chunk_size)
return y


@_torch_ssm_transform.register_fake
def _torch_ssm_transform_meta(
@_torch_ssm.register_fake
def _torch_ssm_meta(
hidden_states: torch.Tensor,
A: torch.Tensor,
B: torch.Tensor,
Expand Down
Loading