diff --git a/tensorrt_llm/_torch/auto_deploy/config/default.yaml b/tensorrt_llm/_torch/auto_deploy/config/default.yaml index 352e93c5ef5..d368acdd1ed 100644 --- a/tensorrt_llm/_torch/auto_deploy/config/default.yaml +++ b/tensorrt_llm/_torch/auto_deploy/config/default.yaml @@ -123,10 +123,10 @@ transforms: attn_backend: MultiHeadLatentAttention insert_cached_ssm_attention: stage: cache_init - attn_backend: torch_ssm + attn_backend: triton_ssm insert_cached_causal_conv: stage: cache_init - attn_backend: torch_causal_conv + attn_backend: cuda_causal_conv initialize_cache: stage: cache_init resize_kv_cache: diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py index 16dd1bf4fe3..150e449ef5c 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py @@ -28,6 +28,7 @@ class CacheConfig: """A dataclass to hold information how to configure the cache.""" + # dtype of the cache dtype: Optional[torch.dtype] = None @@ -522,6 +523,7 @@ def set_example_sequence( # vanilla slot indices slot_idx = list(range(len(input_ids))) + # breakpoint() self.nest_sequences( input_ids, @@ -537,6 +539,9 @@ def set_max_num_tokens_sample(self) -> None: # TODO (lucaslie): understand what this implies for extra arguments seq_len = self.max_num_tokens // self.max_batch_size input_ids = torch.ones(self.max_batch_size, seq_len, dtype=torch.int).tolist() + print( + f"setting max_num_tokens_sample: {self.max_num_tokens=}, {self.max_batch_size=}, {seq_len=}" + ) self.set_example_sequence(input_ids) def set_generate_only_batch(self) -> None: @@ -581,6 +586,10 @@ def _store_arg( # pin the memory on the host tnsr_host = torch.tensor(tnsr_like, dtype=tnsr_device.dtype, pin_memory=True) + if tnsr_device.numel() < tnsr_host.numel(): + print("WARNING: tnsr_device.numel() < tnsr_like.numel()") + print(f"{name=}, {tnsr_device.numel()=}, {tnsr_host.numel()=}") + tnsr_device.resize_(tnsr_host.numel()) # reset/copy to the device in a non-blocking fashion if reset: tnsr_device.zero_() diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/cuda_backend_causal_conv.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/cuda_backend_causal_conv.py new file mode 100644 index 00000000000..cc4397fa76d --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/cuda_backend_causal_conv.py @@ -0,0 +1,365 @@ +"""CUDA-backed cached causal conv1d custom ops and attention descriptor. + +This mirrors `torch_backend_causal_conv.py` but reuses existing TRT-LLM CUDA +operators for performance: +- Prefill uses `torch.ops.trtllm.causal_conv1d_fwd` +- Decode uses `torch.ops.trtllm.causal_conv1d_update` + +The flattened cached op integrates with the auto_deploy attention interface +and updates a slot-indexed convolution state cache internally. 
+""" + +from typing import List, Optional, Tuple + +import torch +from torch._ops import OpOverloadPacket +from torch.fx import Node + +from tensorrt_llm._torch.modules.mamba import PAD_SLOT_ID + +from ..utils.node_utils import extract_op_args +from .attention_interface import ( + AttentionDescriptor, + AttentionLayout, + AttentionRegistry, + BufferInitializerDict, + CacheConfig, + CacheInitializerDict, + Constant, + MHACallable, + PrepareMetadataCallable, + SequenceInfo, +) + + +def _build_conv_state_from_sequence(input_bt_c: torch.Tensor, kernel_size: int) -> torch.Tensor: + """Builds a convolution state of fixed window `kernel_size` from a sequence. + + input_bt_c: [B, T, C] + Returns: [B, C, K] + """ + # [B, T, C] -> [B, C, T] + input_b_c_t = input_bt_c.transpose(1, 2) + seq_len = input_b_c_t.shape[-1] + if seq_len >= kernel_size: + return input_b_c_t[..., -kernel_size:] + pad_amount = kernel_size - seq_len + # F.pad last dim (time) with (pad_left, pad_right) + return torch.nn.functional.pad(input_b_c_t, (pad_amount, 0)) + + +def _cuda_causal_conv1d_prefill( + input: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor], + stride: int, + padding: int, + dilation: int, + groups: int, + padding_mode: str, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Prefill path using TRT-LLM forward kernel; returns (y, conv_state[K-1]).""" + assert padding_mode == "zeros", "padding_mode must be zeros" + # Shapes: convert input to [B, C, T] + x_b_c_t = input.transpose(1, 2).contiguous() + k = weight.shape[-1] + # Weight to [C, K] + w2d = weight.squeeze(1) if weight.ndim == 3 else weight + w2d = w2d.contiguous() + # Initialize state [B, C, K-1] to zeros + conv_state = torch.zeros( + x_b_c_t.shape[0], x_b_c_t.shape[1], k - 1, device=x_b_c_t.device, dtype=x_b_c_t.dtype + ) + # Run TRT forward (in-place on x_b_c_t and conv_state) + torch.ops.trtllm.causal_conv1d_fwd( + x_b_c_t, w2d, bias, conv_state, None, None, None, False, PAD_SLOT_ID + ) + y = x_b_c_t.transpose(1, 2) + return y, conv_state + + +def _cuda_causal_conv1d_decode( + input: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor], + stride: int, + padding: int, + dilation: int, + groups: int, + padding_mode: str, + conv_state_cache: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Decode path using TRT-LLM update kernel for last-step output and cache update. + + Returns (y, updated_conv_state) where y: [B, 1, C_out] and updated state: [B, C_in, K]. 
+ """ + assert padding_mode == "zeros", "padding_mode must be zeros" + # For cached decode we currently support stride=1 and dilation=1 (standard causal conv) + assert stride == 1, "cached causal conv1d currently supports stride == 1 only" + assert dilation == 1, "cached causal conv1d currently supports dilation == 1 only" + + batch_size, seq_len, _ = input.shape + assert seq_len == 1, "decode path expects seq_len == 1" + + kernel_size = weight.shape[-1] + # TRT update expects state len >= K-1 + assert conv_state_cache.shape[-1] >= kernel_size - 1, ( + "conv_state_cache's last dim must be >= kernel_size - 1" + ) + + # TRT-LLM update kernel expects depthwise form: weight [dim, width], groups == dim + in_channels = input.shape[-1] + assert groups == in_channels, ( + "cuda cached causal conv decode currently supports depthwise conv with groups == in_channels" + ) + # Convert weight to [dim, width] + if weight.ndim == 3: + # Expect [C_out, C_in/groups, K] with C_in/groups == 1 + assert weight.shape[-2] == 1 and weight.shape[0] == in_channels, ( + "expected depthwise weight with shape [C, 1, K] matching input channels" + ) + weight_2d = weight.squeeze(-2) + elif weight.ndim == 2: + weight_2d = weight + assert weight_2d.shape[0] == in_channels, ( + "weight rows must match input channels for depthwise conv" + ) + else: + raise AssertionError("unsupported weight rank for causal conv update; expected 2D or 3D") + + # Prepare buffers for TRT-LLM update kernel call. + # TRT-LLM update kernel signature (Python): + # torch.ops.trtllm.causal_conv1d_update(x, conv_state, weight, bias, + # activation_val, cache_seqlens, + # conv_state_indices, pad_slot_id) + # We set activation to None and other optional args to None to get a plain linear conv. + # Convert input to [B, C, T] + x_b_c_t = input.transpose(1, 2).contiguous() + updated_cache = conv_state_cache.clone() + torch.ops.trtllm.causal_conv1d_update( + x_b_c_t, updated_cache, weight_2d, bias, False, None, None, PAD_SLOT_ID + ) + y = x_b_c_t.transpose(1, 2) + return y, updated_cache + + +# --------------------------------------------------------------- +# Metadata + flattened cached op that integrates with the AD i/f +# --------------------------------------------------------------- + + +@torch.library.custom_op("auto_deploy::cuda_causal_conv_prepare_metadata", mutates_args=()) +def cuda_causal_conv_prepare_metadata( + input_ids: torch.Tensor, + position_ids: torch.Tensor, + seq_len: torch.Tensor, + input_pos: torch.Tensor, + cache_loc: torch.Tensor, + pages_per_seq: torch.Tensor, + slot_idx: torch.Tensor, + page_size: int, +) -> List[torch.Tensor]: + """Prepare metadata for cached causal conv (CUDA backend). + + Returns a tuple of (seq_len_sanitized, seq_start, slot_idx_sanitized). 
+ """ + seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(input_ids, seq_len) + num_seq = len(seq_len_sanitized) + + seq_start = torch.zeros_like(seq_len_sanitized) + if num_seq > 1: + seq_start[1:] = torch.cumsum(seq_len_sanitized[:-1], 0) + + slot_idx_sanitized = slot_idx[:num_seq].clone().to(torch.long) + + return (seq_len_sanitized, seq_start, slot_idx_sanitized) + + +@cuda_causal_conv_prepare_metadata.register_fake +def cuda_causal_conv_prepare_metadata_fake( + input_ids, position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size +): + seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(input_ids, seq_len) + num_seq = len(seq_len_sanitized) + return ( + torch.empty_like(seq_len_sanitized), + torch.empty_like(seq_len_sanitized), + torch.empty(num_seq, dtype=torch.long, device=slot_idx.device), + ) + + +@torch.library.custom_op("auto_deploy::cuda_cached_causal_conv1d", mutates_args={}) +def _cuda_cached_causal_conv1d( + # INPUTS (dense but may be flattened across sequences) + input: torch.Tensor, # [b, s, c_in] + weight: torch.Tensor, # [c_out, c_in/groups, k] but we expect depthwise use: [c_in, k] + bias: Optional[torch.Tensor], + # METADATA + seq_len: torch.Tensor, # [num_seq] + seq_start: torch.Tensor, # [num_seq] + slot_idx: torch.Tensor, # [num_seq] + # CACHES + conv_state_cache: torch.Tensor, # [max_batch_size, c_in, k-1] + # CONSTANTS + stride: int, + padding: int, + dilation: int, + groups: int, + padding_mode: str, +) -> torch.Tensor: + """Flattened cached causal conv that respects slot-indexed state caches (CUDA backend). + + Supports two layouts from the attention interface: + - Generate-only: input is [b, 1, c_in]. We'll gather caches using slot_idx[:b]. + - Flattened context/mixed: input is [1, total_s, c_in] and seq_len/seq_start + describe per-sequence segments. We'll process each segment and scatter final states to caches. 
+ """ + b, s = input.shape[:2] + num_seq = seq_len.shape[0] + + if s == 1: + # Generate-only batch + slot_idx_long = slot_idx.to(torch.long) + cache_batch = conv_state_cache.index_select(0, slot_idx_long) + + y, updated_state = _cuda_causal_conv1d_decode( + input, + weight, + bias, + stride, + padding, + dilation, + groups, + padding_mode, + cache_batch, + ) + + conv_state_cache.index_copy_(0, slot_idx_long, updated_state.to(conv_state_cache.dtype)) + # Custom op must not return an alias of any input; return a fresh tensor + return y.to(input.dtype).contiguous().clone() + + # Context/mixed phase (flattened sequences) + bs = b * s + flat_idx = torch.arange(bs, device=input.device, dtype=torch.long) + + inp_flat = input.reshape(bs, *input.shape[2:]) + y = torch.empty(b, s, weight.shape[0], device=input.device, dtype=input.dtype) + y_flat = y.view(bs, *y.shape[2:]) + + for i in range(num_seq): + length_i = seq_len[i] + if length_i.eq(0): + continue + start_i = seq_start[i] + end_i = start_i + length_i + + mask_i = (flat_idx >= start_i.to(torch.long)) & (flat_idx < end_i.to(torch.long)) + idx_i = torch.nonzero(mask_i, as_tuple=False).squeeze(-1) + + inp_seq = inp_flat.index_select(0, idx_i).unsqueeze(0) + + y_seq, conv_state_i = _cuda_causal_conv1d_prefill( + inp_seq, + weight, + bias, + stride, + padding, + dilation, + groups, + padding_mode, + ) + + y_flat.index_copy_(0, idx_i, y_seq[0].to(y_flat.dtype)) + + slot_i = slot_idx[i].to(torch.long).unsqueeze(0) + conv_state_cache.index_copy_(0, slot_i, conv_state_i.to(conv_state_cache.dtype)) + + # Custom op must not return an alias of any input; return a fresh tensor + return y.contiguous().clone() + + +@_cuda_cached_causal_conv1d.register_fake +def _cuda_cached_causal_conv1d_fake( + # INPUTS + input: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor], + # METADATA + seq_len: torch.Tensor, + seq_start: torch.Tensor, + slot_idx: torch.Tensor, + # CACHES + conv_state_cache: torch.Tensor, + # CONSTANTS + stride: int, + padding: int, + dilation: int, + groups: int, + padding_mode: str, +): + return torch.empty( + input.shape[0], input.shape[1], weight.shape[0], device=input.device, dtype=input.dtype + ) + + +@AttentionRegistry.register("cuda_causal_conv") +class CudaBackendCausalConv(AttentionDescriptor): + @classmethod + def is_paged(cls) -> bool: + return True + + @classmethod + def get_attention_layout(cls) -> AttentionLayout: + # Hidden states follow [b, s, c] + return "bsnd" + + @classmethod + def get_num_qkv_args(cls) -> int: + # torch_causal_conv1d signature has 3 relevant tensor arguments + # TODO: bias can be optional!! How to handle None bias here? 
+ return 3 + + @classmethod + def get_source_attention_op(cls) -> OpOverloadPacket: + return torch.ops.auto_deploy.torch_causal_conv1d + + @classmethod + def get_cached_attention_op(cls) -> MHACallable: + return torch.ops.auto_deploy.cuda_cached_causal_conv1d + + @classmethod + def get_prepare_metadata_op(cls) -> Tuple[PrepareMetadataCallable, int]: + # Returns (seq_len, seq_start, slot_idx) + return torch.ops.auto_deploy.cuda_causal_conv_prepare_metadata, 3 + + @classmethod + def get_cache_initializers( + cls, source_attn_node: Node, cache_config: CacheConfig + ) -> CacheInitializerDict: + inp_fake: torch.Tensor = source_attn_node.args[0].meta["val"] + w_fake: torch.Tensor = source_attn_node.args[1].meta["val"] + + in_channels = inp_fake.shape[-1] + kernel_size = w_fake.shape[-1] + + def _get_conv_cache(si: SequenceInfo): + return torch.empty( + si.max_batch_size, + in_channels, + max(1, kernel_size - 1), + device=si.device, + dtype=cache_config.dtype or inp_fake.dtype, + ) + + return {"conv_state_cache": _get_conv_cache} + + @classmethod + def get_global_buffer_initializers(cls, source_attn_node: Node) -> BufferInitializerDict: + return {} + + @classmethod + def get_constants(cls, source_attn_node: Node) -> List[Constant]: + stride, padding, dilation, groups, padding_mode = extract_op_args( + source_attn_node, "stride", "padding", "dilation", "groups", "padding_mode" + ) + return [stride, padding, dilation, groups, padding_mode] diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_mamba.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_mamba.py index ffc0aef06f8..4ac148e815e 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_mamba.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_mamba.py @@ -282,6 +282,7 @@ def _torch_cached_ssm_transform_fake( return torch.empty_like( hidden_states, memory_format=torch.contiguous_format, + dtype=torch.float32, ) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_mamba.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_mamba.py index c346863903f..7deeced93eb 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_mamba.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_mamba.py @@ -194,4 +194,4 @@ def _torch_ssm_transform_meta( time_step_limit: List[float], chunk_size: int, ) -> torch.Tensor: - return torch.empty_like(hidden_states) + return torch.empty_like(hidden_states, dtype=torch.float32) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_backend_mamba.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_backend_mamba.py new file mode 100644 index 00000000000..536ecfc896c --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_backend_mamba.py @@ -0,0 +1,253 @@ +from typing import List, Tuple + +import torch +from torch._ops import OpOverloadPacket +from torch.fx import Node + +# Triton kernels +from tensorrt_llm._torch.modules.mamba.selective_state_update import selective_state_update +from tensorrt_llm._torch.modules.mamba.ssd_combined import mamba_chunk_scan_combined + +from ..utils.node_utils import extract_op_args +from .attention_interface import ( + AttentionDescriptor, + AttentionLayout, + AttentionRegistry, + BufferInitializerDict, + CacheConfig, + CacheInitializerDict, + Constant, + MHACallable, + PrepareMetadataCallable, + SequenceInfo, +) + + +@torch.library.custom_op("auto_deploy::triton_cached_ssm_transform", mutates_args={}) +def _triton_cached_ssm_transform( + # INPUTS (dense but may be flattened across sequences) + 
hidden_states: torch.Tensor, # [b, s, num_heads, head_dim] + A: torch.Tensor, # [num_heads] + B: torch.Tensor, # [b, s, n_groups, ssm_state_size] + C: torch.Tensor, # [b, s, n_groups, ssm_state_size] + D: torch.Tensor, # [num_heads] + dt: torch.Tensor, # [b, s, num_heads] + dt_bias: torch.Tensor, # [num_heads] + # METADATA + seq_len: torch.Tensor, # [num_seq] + seq_start: torch.Tensor, # [num_seq] + slot_idx: torch.Tensor, # [num_seq] + # CACHES + ssm_state_cache: torch.Tensor, # [max_batch_size, num_heads, head_dim, ssm_state_size] + # CONSTANTS + time_step_limit: List[float], + chunk_size: int, +) -> torch.Tensor: + """Flattened cached SSM transform op that respects slot-indexed state caches. + + Implements generate-only (s==1) via selective_state_update and prefill/mixed (s>1) via + mamba_chunk_scan_combined, updating the slot-indexed cache in-op. Returns y only. + """ + b, s = hidden_states.shape[:2] + num_seq = seq_len.shape[0] + + if s == 1: + # Generate-only batch: gather cache slices for slots (already sanitized by metadata) + slot_idx_long = slot_idx.to(torch.long) + ssm_batch = ssm_state_cache.index_select(dim=0, index=slot_idx_long) + + # Shapes + batch_size = b + num_heads = hidden_states.shape[2] + head_dim = hidden_states.shape[3] + n_groups = B.shape[2] + ssm_state_size = B.shape[3] + + # Prepare per-head, per-dim tensors + dt_hp = dt[:, 0, :][:, :, None].expand(batch_size, num_heads, head_dim) + dt_bias_hp = dt_bias[..., None].expand(num_heads, head_dim) + dt_pre = torch.nn.functional.softplus(dt_hp + dt_bias_hp.to(dtype=dt_hp.dtype)) + dt_pre = torch.clamp(dt_pre, time_step_limit[0], time_step_limit[1]) + A_full = A[..., None, None].expand(num_heads, head_dim, ssm_state_size) + D_full = D[..., None].expand(num_heads, head_dim) + B_grouped = B.reshape(batch_size, n_groups, ssm_state_size) + C_grouped = C.reshape(batch_size, n_groups, ssm_state_size) + x = hidden_states.reshape(batch_size, num_heads, head_dim) + + # compute new state; avoid mutating input cache slice + updated_state = ssm_batch.clone() + # Provide a zero dt_bias tensor to satisfy kernel arg expansion; we've already + # applied dt_bias and softplus/clamp into dt_pre above. + dt_bias_zero = torch.zeros_like(dt_bias_hp) + y_hp = selective_state_update( + updated_state, + x, + dt_pre, + A_full, + B_grouped, + C_grouped, + D=D_full, + z=None, + dt_bias=dt_bias_zero, + dt_softplus=False, + ) + y = y_hp.reshape(batch_size, 1, num_heads, head_dim) + + # Scatter updated states back to global cache + ssm_state_cache.index_copy_(0, slot_idx_long, updated_state.to(ssm_state_cache.dtype)) + + return y.to(hidden_states.dtype) + + # Context/mixed phase (flattened sequences). Expect b == 1, but handle general b robustly. 
+ bs = b * s + flat_idx = torch.arange(bs, device=hidden_states.device, dtype=torch.long) + + # NOTE: use reshape to force contiguous format after reshape + hs_flat = hidden_states.reshape(bs, *hidden_states.shape[2:]) + B_flat = B.reshape(bs, *B.shape[2:]) + C_flat = C.reshape(bs, *C.shape[2:]) + dt_flat = dt.reshape(bs, *dt.shape[2:]) + + # NOTE: need contiguous format to process it sequentially + y = torch.empty_like(hidden_states, memory_format=torch.contiguous_format) + y_flat = y.view(bs, *y.shape[2:]) + + for i in range(num_seq): + length_i = seq_len[i] + if length_i.eq(0): + continue + + start_i = seq_start[i] + end_i = start_i + length_i + + mask_i = (flat_idx >= start_i.to(torch.long)) & (flat_idx < end_i.to(torch.long)) + idx_i = torch.nonzero(mask_i, as_tuple=False).squeeze(-1) + + hs_seq = hs_flat.index_select(0, idx_i).unsqueeze(0) + B_seq = B_flat.index_select(0, idx_i).unsqueeze(0) + C_seq = C_flat.index_select(0, idx_i).unsqueeze(0) + dt_seq = dt_flat.index_select(0, idx_i).unsqueeze(0) + + y_seq, ssm_state_i = mamba_chunk_scan_combined( + hs_seq, + dt_seq, + A, + B_seq, + C_seq, + chunk_size=chunk_size, + D=D, + z=None, + dt_bias=dt_bias, + seq_idx=None, + dt_softplus=True, + dt_limit=(time_step_limit[0], time_step_limit[1]), + return_final_states=True, + ) + + y_flat.index_copy_(0, idx_i, y_seq[0].to(y_flat.dtype)) + + slot_i = slot_idx[i].to(torch.long).unsqueeze(0) + ssm_state_cache.index_copy_(0, slot_i, ssm_state_i.to(ssm_state_cache.dtype)) + + return y + + +@_triton_cached_ssm_transform.register_fake +def _triton_cached_ssm_transform_fake( + # INPUTS (dense but may be flattened across sequences) + hidden_states: torch.Tensor, # [b, s, num_heads, head_dim] + A: torch.Tensor, # [num_heads] + B: torch.Tensor, # [b, s, n_groups, ssm_state_size] + C: torch.Tensor, # [b, s, n_groups, ssm_state_size] + D: torch.Tensor, # [num_heads] + dt: torch.Tensor, # [b, s, num_heads] + dt_bias: torch.Tensor, # [num_heads] + # METADATA + seq_len: torch.Tensor, # [num_seq] + seq_start: torch.Tensor, # [num_seq] + slot_idx: torch.Tensor, # [num_seq] + # CACHES + ssm_state_cache: torch.Tensor, # [max_batch_size, num_heads, head_dim, ssm_state_size] + # CONSTANTS + time_step_limit: List[float], + chunk_size: int, +): + # Return a correctly-shaped tensor for tracing with fake tensors + return torch.empty_like( + hidden_states, + memory_format=torch.contiguous_format, + dtype=hidden_states.dtype, + ) + + +## Note: we reuse the existing metadata custom op and its registered fake from torch backend. 
+ + +@AttentionRegistry.register("triton_ssm") +class TritonBackendSSM(AttentionDescriptor): + @classmethod + def is_paged(cls) -> bool: + return True + + @classmethod + def get_attention_layout(cls) -> AttentionLayout: + # Hidden states follow [b, s, n, d] + return "bsnd" + + @classmethod + def get_num_qkv_args(cls) -> int: + # torch_ssm_transform signature has 7 node/state arguments + return 7 + + @classmethod + def get_source_attention_op(cls) -> OpOverloadPacket: + # Keep source op unchanged (used for uncached pre-export) + return torch.ops.auto_deploy.torch_ssm_transform + + @classmethod + def get_cached_attention_op(cls) -> MHACallable: + return torch.ops.auto_deploy.triton_cached_ssm_transform + + @classmethod + def get_prepare_metadata_op(cls) -> Tuple[PrepareMetadataCallable, int]: + # Returns (seq_len, seq_start, slot_idx) + return torch.ops.auto_deploy.torch_ssm_prepare_metadata, 3 + + @classmethod + def get_cache_initializers( + cls, source_attn_node: Node, cache_config: CacheConfig + ) -> CacheInitializerDict: + # Shapes from fake tensors + hs_fake: torch.Tensor = source_attn_node.args[0].meta["val"] + B_fake: torch.Tensor = source_attn_node.args[2].meta["val"] + + num_heads = hs_fake.shape[-2] + head_dim = hs_fake.shape[-1] + + if B_fake.ndim >= 4: + ssm_state_size = B_fake.shape[-1] + else: + ssm_state_size = max(1, B_fake.shape[-1]) + + def _get_ssm_cache(si: SequenceInfo): + return torch.empty( + si.max_batch_size, + num_heads, + head_dim, + ssm_state_size, + device=si.device, + dtype=cache_config.dtype or hs_fake.dtype, + ) + + return {"ssm_state_cache": _get_ssm_cache} + + @classmethod + def get_global_buffer_initializers(cls, source_attn_node: Node) -> BufferInitializerDict: + return {} + + @classmethod + def get_constants(cls, source_attn_node: Node) -> List[Constant]: + time_step_limit, chunk_size = extract_op_args( + source_attn_node, "time_step_limit", "chunk_size" + ) + return [time_step_limit, chunk_size] diff --git a/tensorrt_llm/_torch/auto_deploy/models/patches/bamba.py b/tensorrt_llm/_torch/auto_deploy/models/patches/bamba.py index 22f202b763d..03c322ef59d 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/patches/bamba.py +++ b/tensorrt_llm/_torch/auto_deploy/models/patches/bamba.py @@ -150,6 +150,18 @@ def _bamba_model_update_mamba_mask(self, attention_mask, cache_position): return None +def _bamba_model_update_causal_mask( + self, + attention_mask, + input_tensor, + cache_position, + past_key_values, + output_attentions, +): + # Force attention to use causal mode without explicit masks + return None + + # NOTE: this would need to be applied earlier than other patches, since the `_init_weights` (which # is called by `post_init`) is called before we run `forward`. def _bamba_pretrained_model_init_weights(self, module): @@ -182,17 +194,20 @@ class BambaModelPatch(BaseExportPatch): def _apply_patch(self): self.original_values["BambaMixer.torch_forward"] = BambaMixer.torch_forward self.original_values["BambaModel._update_mamba_mask"] = BambaModel._update_mamba_mask + self.original_values["BambaModel._update_causal_mask"] = BambaModel._update_causal_mask # NOTE: there is `HybridMambaAttentionDynamicCache.__bool__` to save. 
# self.original_values["BambaPreTrainedModel._init_weights"] = BambaPreTrainedModel._init_weights BambaMixer.torch_forward = _bamba_mixer_torch_forward BambaModel._update_mamba_mask = _bamba_model_update_mamba_mask + BambaModel._update_causal_mask = _bamba_model_update_causal_mask HybridMambaAttentionDynamicCache.__bool__ = _cache_bool # BambaPreTrainedModel._init_weights = _bamba_pretrained_model_init_weights def _revert_patch(self): BambaMixer.torch_forward = self.original_values["BambaMixer.torch_forward"] BambaModel._update_mamba_mask = self.original_values["BambaModel._update_mamba_mask"] + BambaModel._update_causal_mask = self.original_values["BambaModel._update_causal_mask"] del HybridMambaAttentionDynamicCache.__bool__ # BambaPreTrainedModel._init_weights = self.original_values[ # "BambaPreTrainedModel._init_weights" diff --git a/tensorrt_llm/_torch/auto_deploy/shim/interface.py b/tensorrt_llm/_torch/auto_deploy/shim/interface.py index 7a10ade3e11..d4bc4d600ad 100644 --- a/tensorrt_llm/_torch/auto_deploy/shim/interface.py +++ b/tensorrt_llm/_torch/auto_deploy/shim/interface.py @@ -56,6 +56,8 @@ def initialize_caches(self) -> int: self._caches = { name: get_cache(self.info) for name, get_cache in self._cache_initializers.items() } + for name, cache in self._caches.items(): + print(f"{name=}, {cache.shape=}, {cache.dtype=}, {cache.device=}") return len(self._caches) def current_cache_size_bytes(self) -> int: @@ -73,7 +75,8 @@ def resize_cache(self, new_num_pages: int): self.info.num_pages = new_num_pages for name, cache in self._caches.items(): # We assume cache is a tensor of shape (max_batch_size, page_size, n_heads, head_dim) - if "cache" in name: + if "k_cache" in name or "v_cache" in name: + print("resizing cache", name) current_shape = cache.shape new_shape = (new_num_pages, *current_shape[1:]) cache.resize_(new_shape) diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/attention.py b/tensorrt_llm/_torch/auto_deploy/transform/library/attention.py index 5be69ad9e8e..1d188e15e79 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/attention.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/attention.py @@ -334,6 +334,82 @@ def _grouped_attn_replacement_5(q, k, v, n_rep, attn_mask): return torch.ops.auto_deploy.torch_attention_grouped_sdpa.default(q, k, v, attn_mask) +def _grouped_attn_pattern_6(q, k, v, attn_mask, dropout_p, scale): + return torch.ops.auto_deploy.torch_attention_sdpa.default( + q, + k, + v, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=False, + scale=scale, + enable_gqa=True, + ) + + +def _grouped_attn_replacement_6(q, k, v, attn_mask, dropout_p, scale): + return torch.ops.auto_deploy.torch_attention_grouped_sdpa.default( + q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=False, scale=scale + ) + + +def _grouped_attn_pattern_7(q, k, v, attn_mask, dropout_p, scale): + return torch.ops.auto_deploy.torch_attention_sdpa.default( + q, + k, + v, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=True, + scale=scale, + enable_gqa=True, + ) + + +def _grouped_attn_replacement_7(q, k, v, attn_mask, dropout_p, scale): + return torch.ops.auto_deploy.torch_attention_grouped_sdpa.default( + q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=True, scale=scale + ) + + +def _grouped_attn_pattern_8(q, k, v, dropout_p, scale): + return torch.ops.auto_deploy.torch_attention_sdpa.default( + q, + k, + v, + attn_mask=None, + dropout_p=dropout_p, + is_causal=False, + scale=scale, + enable_gqa=True, + ) + + +def 
_grouped_attn_replacement_8(q, k, v, dropout_p, scale): + return torch.ops.auto_deploy.torch_attention_grouped_sdpa.default( + q, k, v, attn_mask=None, dropout_p=dropout_p, is_causal=False, scale=scale + ) + + +def _grouped_attn_pattern_9(q, k, v, dropout_p, scale): + return torch.ops.auto_deploy.torch_attention_sdpa.default( + q, + k, + v, + attn_mask=None, + dropout_p=dropout_p, + is_causal=True, + scale=scale, + enable_gqa=True, + ) + + +def _grouped_attn_replacement_9(q, k, v, dropout_p, scale): + return torch.ops.auto_deploy.torch_attention_grouped_sdpa.default( + q, k, v, attn_mask=None, dropout_p=dropout_p, is_causal=True, scale=scale + ) + + @TransformRegistry.register("match_repeat_kv") class MatchRepeatKV(BaseTransform): """ @@ -434,6 +510,7 @@ def register_grouped_attention(patterns: ADPatternMatcherPass): dummy_args_1 = [q, k1, v1, n_rep, attn_mask, dropout, scale] dummy_args_2 = [q, k1, v1, attn_mask, dropout, scale] dummy_args_3 = [q, k1, v1, n_rep, attn_mask] + dummy_args_4 = [q, k1, v1, dropout, scale] register_ad_pattern( search_fn=_grouped_attn_pattern_1, @@ -477,6 +554,35 @@ def register_grouped_attention(patterns: ADPatternMatcherPass): scalar_workaround={"n_rep": n_rep}, ) + register_ad_pattern( + search_fn=_grouped_attn_pattern_6, + replace_fn=_grouped_attn_replacement_6, + patterns=patterns, + dummy_args=dummy_args_2, + scalar_workaround={"scale": scale, "dropout_p": dropout}, + ) + register_ad_pattern( + search_fn=_grouped_attn_pattern_7, + replace_fn=_grouped_attn_replacement_7, + patterns=patterns, + dummy_args=dummy_args_2, + scalar_workaround={"scale": scale, "dropout_p": dropout}, + ) + register_ad_pattern( + search_fn=_grouped_attn_pattern_8, + replace_fn=_grouped_attn_replacement_8, + patterns=patterns, + dummy_args=dummy_args_4, + scalar_workaround={"scale": scale, "dropout_p": dropout}, + ) + register_ad_pattern( + search_fn=_grouped_attn_pattern_9, + replace_fn=_grouped_attn_replacement_9, + patterns=patterns, + dummy_args=dummy_args_4, + scalar_workaround={"scale": scale, "dropout_p": dropout}, + ) + num_grouped_patterns = _apply_pattern(gm, "Grouped Attention", register_grouped_attention) if num_grouped_patterns == 0: ad_logger.warning( @@ -529,7 +635,6 @@ def _apply( # List of SDPA operations to look for sdpa_ops = { - torch.ops.auto_deploy.torch_attention_sdpa, torch.ops.auto_deploy.torch_attention_grouped_sdpa, } diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_cuda_causal_conv_cached_op.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_cuda_causal_conv_cached_op.py new file mode 100644 index 00000000000..b2745e51bfd --- /dev/null +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_cuda_causal_conv_cached_op.py @@ -0,0 +1,200 @@ +"""Unit tests for CUDA-backed cached causal conv1d custom ops. 
+ +Covers: +- Generate-only path with slot-indexed cache mapping +- Context (flattened) path and state write-back per slot +- Metadata preparation +""" + +import pytest +import torch + +import tensorrt_llm._torch.auto_deploy # noqa: F401 + + +def _random_params_depthwise(device, dtype, batch, seq, channels, k): + x = torch.randn(batch, seq, channels, device=device, dtype=dtype) + # Depthwise: out_channels == in_channels, groups == channels, weight [C, 1, K] + weight = torch.randn(channels, 1, k, device=device, dtype=dtype) + bias = torch.randn(channels, device=device, dtype=dtype) + stride = 1 + padding = k - 1 + dilation = 1 + groups = channels + padding_mode = "zeros" + return x, weight, bias, stride, padding, dilation, groups, padding_mode + + +@pytest.fixture +def conv_env(): + device = "cuda" + dtype = torch.float16 + atol = 5e-2 + rtol = 5e-2 + torch.manual_seed(123) + torch.cuda.empty_cache() + return {"device": device, "dtype": dtype, "atol": atol, "rtol": rtol} + + +def test_generate_only_with_slot_mapping_cuda(conv_env): + device = conv_env["device"] + dtype = conv_env["dtype"] + + batch, seq = 1, 1 + c, k = 2, 3 + x, w, b, s, p, d, g, pm = _random_params_depthwise(device, dtype, batch, seq, c, k) + + # Slot mapping with arbitrary order within max_batch_size + max_batch_size = 2 + slot_idx = torch.tensor([0], device=device, dtype=torch.int32) + # Cache holds K-1 entries (TRT update kernel contract) + conv_state_cache = torch.randn( + max_batch_size, + c, + k - 1, + device=device, + dtype=dtype, + ) + + # Metadata (not used in generate-only op entry, but required by the interface) + seq_len = torch.ones(batch, device=device, dtype=torch.int32) + seq_start = torch.zeros(batch, device=device, dtype=torch.int32) + + # Snapshot caches for reference before running op (op mutates caches) + gathered_before = conv_state_cache.clone().index_select(0, slot_idx) + x_ref = x.clone() + # Run CUDA cached op + y = torch.ops.auto_deploy.cuda_cached_causal_conv1d( + # INPUTS + x, + w, + b, + # METADATA + seq_len, + seq_start, + slot_idx, + # CACHES + conv_state_cache, + # CONSTANTS + s, + p, + d, + g, + pm, + ) + + assert y.shape == (batch, seq, c) + assert torch.isfinite(y).all() + + # Reference via torch uncached conv on window [state(K-1) | x] + window_bt_c = torch.cat([gathered_before.transpose(1, 2), x_ref], dim=-2) + y_ref = torch.ops.auto_deploy.torch_causal_conv1d(window_bt_c, w, b, 1, 0, 1, g, pm) + assert torch.allclose(y, y_ref, atol=conv_env["atol"], rtol=conv_env["rtol"]) + after = conv_state_cache.index_select(0, slot_idx) + expected_after = torch.cat([gathered_before[..., 1:], x_ref.transpose(1, 2)[..., :1]], dim=-1) + assert torch.allclose( + after, expected_after.to(after.dtype), atol=conv_env["atol"], rtol=conv_env["rtol"] + ) + + +def test_context_flattened_and_state_writeback_cuda(conv_env): + device = conv_env["device"] + dtype = conv_env["dtype"] + + # Two short sequences with lengths 2 and 1, flattened to [1,3] + lens = [2, 1] + total = sum(lens) + batch, seq = 1, total + c, k = 2, 3 + x, w, b, s, p, d, g, pm = _random_params_depthwise(device, dtype, batch, seq, c, k) + + max_batch_size = 2 + slot_idx = torch.tensor([1, 0], device=device, dtype=torch.int32) + conv_state_cache = torch.randn( + max_batch_size, + c, + k - 1, + device=device, + dtype=dtype, + ) + + seq_len = torch.tensor(lens, device=device, dtype=torch.int32) + seq_start = torch.tensor([0, lens[0]], device=device, dtype=torch.int32) + + y = torch.ops.auto_deploy.cuda_cached_causal_conv1d( + # INPUTS + x, + w, + 
b, + # METADATA + seq_len, + seq_start, + slot_idx, + # CACHES + conv_state_cache, + # CONSTANTS + s, + p, + d, + g, + pm, + ) + + assert y.shape == (batch, seq, c) + assert torch.isfinite(y).all() + + # Reference by per-sequence prefill output and expected conv state (K-1 window) + y_ref = torch.empty_like(y) + for i, ln in enumerate(lens): + st = 0 if i == 0 else lens[0] + x_i = x[:, st : st + ln] + y_i, _ = ( + tensorrt_llm._torch.auto_deploy.custom_ops.torch_backend_causal_conv._torch_causal_conv1d_prefill( # type: ignore # noqa: E501 + x_i, w, b, s, p, d, g, pm + ) + ) + y_ref[:, st : st + ln].copy_(y_i) + # Cache should hold K-1 latest inputs + x_b_c_t = x_i.transpose(1, 2) + if ln >= (k - 1): + expected_state = x_b_c_t[..., -(k - 1) :] + else: + pad = (k - 1) - ln + expected_state = torch.nn.functional.pad(x_b_c_t, (pad, 0)) + assert torch.allclose( + conv_state_cache[slot_idx[i]].to(expected_state.dtype), + expected_state, + atol=conv_env["atol"], + rtol=conv_env["rtol"], + ) + + assert torch.allclose(y, y_ref.to(y.dtype), atol=conv_env["atol"], rtol=conv_env["rtol"]) + + +def test_prepare_metadata_cuda(conv_env): + device = conv_env["device"] + + b, s = 4, 6 + input_ids = torch.randint(0, 1000, (b, s), device=device) + position_ids = torch.arange(s, device=device).expand(b, -1) + seq_len = torch.tensor([2, 1, 0, 0], device=device, dtype=torch.int32) + input_pos = torch.tensor([0, 3, 0, 0], device=device, dtype=torch.int32) + cache_loc = torch.arange(b, device=device, dtype=torch.int32) + pages_per_seq = torch.ones(b, device=device, dtype=torch.int32) + slot_idx = torch.tensor([2, 0, 1, 3], device=device, dtype=torch.int32) + page_size = 128 + + out = torch.ops.auto_deploy.cuda_causal_conv_prepare_metadata( + input_ids, + position_ids, + seq_len, + input_pos, + cache_loc, + pages_per_seq, + slot_idx, + page_size, + ) + assert len(out) == 3 + seq_len_s, seq_start, slot_s = out + assert seq_len_s.numel() == 2 and slot_s.numel() == 2 + assert torch.all(seq_start == torch.tensor([0, 2], device=device, dtype=seq_start.dtype)) diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_triton_mamba_cached_op.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_triton_mamba_cached_op.py new file mode 100644 index 00000000000..97a23572206 --- /dev/null +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_triton_mamba_cached_op.py @@ -0,0 +1,169 @@ +import pytest +import torch + +import tensorrt_llm._torch.auto_deploy # noqa: F401 + + +def _random_params(device, dtype, batch, seq, num_heads, head_dim, n_groups, ssm_state_size): + hidden_states = torch.randn(batch, seq, num_heads, head_dim, device=device, dtype=dtype) + A = torch.randn(num_heads, device=device, dtype=torch.float32) + B = torch.randn(batch, seq, n_groups, ssm_state_size, device=device, dtype=dtype) + C = torch.randn(batch, seq, n_groups, ssm_state_size, device=device, dtype=dtype) + D = torch.randn(num_heads, device=device, dtype=dtype) + dt = torch.randn(batch, seq, num_heads, device=device, dtype=dtype) + dt_bias = torch.randn(num_heads, device=device, dtype=dtype) + time_step_limit = [1e-6, 1.0] + chunk_size = 4 + return hidden_states, A, B, C, D, dt, dt_bias, time_step_limit, chunk_size + + +@pytest.fixture +def mamba_env(): + device = "cuda" + dtype = torch.float16 + atol = 5e-2 + rtol = 5e-2 + torch.manual_seed(123) + torch.cuda.empty_cache() + return {"device": device, "dtype": dtype, "atol": atol, "rtol": rtol} + + +def 
test_triton_generate_only_with_slot_mapping(mamba_env): + device = mamba_env["device"] + dtype = mamba_env["dtype"] + atol = mamba_env["atol"] + rtol = mamba_env["rtol"] + + batch, seq = 3, 1 + num_heads, head_dim = 4, 8 + n_groups, ssm_state_size = 2, 4 + (hidden_states, A, B, C, D, dt, dt_bias, time_step_limit, chunk_size) = _random_params( + device, dtype, batch, seq, num_heads, head_dim, n_groups, ssm_state_size + ) + + max_batch_size = 6 + slot_idx = torch.tensor([4, 1, 3], device=device, dtype=torch.int32) + ssm_state_cache_torch = torch.randn( + max_batch_size, num_heads, head_dim, ssm_state_size, device=device, dtype=dtype + ) + ssm_state_cache_triton = ssm_state_cache_torch.clone() + + seq_len = torch.ones(batch, device=device, dtype=torch.int32) + seq_start = torch.zeros(batch, device=device, dtype=torch.int32) + + # Torch reference + y_torch = torch.ops.auto_deploy.torch_cached_ssm_transform( + hidden_states, + A, + B, + C, + D, + dt, + dt_bias, + seq_len, + seq_start, + slot_idx, + ssm_state_cache_torch, + time_step_limit, + chunk_size, + ) + + # Triton under test + y_triton = torch.ops.auto_deploy.triton_cached_ssm_transform( + hidden_states, + A, + B, + C, + D, + dt, + dt_bias, + seq_len, + seq_start, + slot_idx, + ssm_state_cache_triton, + time_step_limit, + chunk_size, + ) + + assert y_triton.shape == hidden_states.shape + assert torch.isfinite(y_triton).all() + + # Compare outputs + assert torch.allclose(y_triton, y_torch.to(y_triton.dtype), atol=atol, rtol=rtol) + + # Compare cache updates at slots + after_torch = ssm_state_cache_torch.index_select(0, slot_idx) + after_triton = ssm_state_cache_triton.index_select(0, slot_idx) + assert torch.allclose(after_triton.to(after_torch.dtype), after_torch, atol=atol, rtol=rtol) + + +def test_triton_context_flattened_and_state_writeback(mamba_env): + device = mamba_env["device"] + dtype = mamba_env["dtype"] + atol = mamba_env["atol"] + rtol = mamba_env["rtol"] + + lens = [2] + total = sum(lens) + batch, seq = 1, total + num_heads, head_dim = 1, 4 + n_groups, ssm_state_size = 1, 1 + (hidden_states, A, B, C, D, dt, dt_bias, time_step_limit, chunk_size) = _random_params( + device, dtype, batch, seq, num_heads, head_dim, n_groups, ssm_state_size + ) + + max_batch_size = 2 + slot_idx = torch.tensor([1], device=device, dtype=torch.int32) + ssm_state_cache_torch = torch.randn( + max_batch_size, num_heads, head_dim, ssm_state_size, device=device, dtype=dtype + ) + ssm_state_cache_triton = ssm_state_cache_torch.clone() + + seq_len = torch.tensor(lens, device=device, dtype=torch.int32) + seq_start = torch.tensor([0, lens[0]], device=device, dtype=torch.int32) + + # Torch reference + y_torch = torch.ops.auto_deploy.torch_cached_ssm_transform( + hidden_states, + A, + B, + C, + D, + dt, + dt_bias, + seq_len, + seq_start, + slot_idx, + ssm_state_cache_torch, + time_step_limit, + chunk_size, + ) + + # Triton under test + y_triton = torch.ops.auto_deploy.triton_cached_ssm_transform( + hidden_states, + A, + B, + C, + D, + dt, + dt_bias, + seq_len, + seq_start, + slot_idx, + ssm_state_cache_triton, + time_step_limit, + chunk_size, + ) + + assert y_triton.shape == hidden_states.shape + assert torch.isfinite(y_triton).all() + # Compare outputs + assert torch.allclose(y_triton, y_torch.to(y_triton.dtype), atol=1e-1, rtol=1e-1) + + # Cache should hold final state at slots + for i, ln in enumerate(lens): + slot = slot_idx[i] + state_torch = ssm_state_cache_torch[slot] + state_triton = ssm_state_cache_triton[slot] + assert 
torch.allclose(state_triton.to(state_torch.dtype), state_torch, atol=atol, rtol=rtol) diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_bamba.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_bamba.py index 7ce6fe659e9..8b94dd4112c 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_bamba.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_bamba.py @@ -52,6 +52,10 @@ def test_bamba_patches( }, ) + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + factory = llm_args.create_factory() model = factory.build_model("meta") tokenizer = factory.init_tokenizer()
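A few reference sketches for the pieces this patch adds (all illustrative, written against plain PyTorch rather than the TRT-LLM/Triton kernels the ops actually call).

First, the decode-path contract of `cuda_cached_causal_conv1d` (and of the generate-only unit test above): each cache slot holds the last K-1 inputs per channel, the single-step output is a depthwise causal conv over the window `[state | x]`, and the state window shifts by one. A minimal sketch, assuming the depthwise layout the op supports (`groups == channels`, weight `[C, 1, K]`, zero padding); the function name is made up for illustration.

import torch
import torch.nn.functional as F


def depthwise_conv_decode_reference(x, weight, bias, conv_state):
    """One decode step of a cached depthwise causal conv1d (illustrative).

    x:          [B, 1, C]    activations for the new token
    weight:     [C, 1, K]    depthwise kernel (groups == C)
    bias:       [C] or None
    conv_state: [B, C, K-1]  last K-1 inputs per channel, gathered by slot_idx

    Returns (y, new_state) with y: [B, 1, C] and new_state: [B, C, K-1].
    """
    channels = x.shape[-1]
    # Window of the K most recent inputs per channel: [B, C, K]
    window = torch.cat([conv_state, x.transpose(1, 2)], dim=-1)
    # Depthwise conv with no padding -> exactly one output position
    y = F.conv1d(window, weight, bias, stride=1, padding=0, groups=channels)
    # Drop the oldest column so the window is ready for the next token
    new_state = window[..., 1:]
    return y.transpose(1, 2), new_state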
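The generate-only path of the new `triton_cached_ssm_transform` fuses the usual Mamba-2 per-token recurrence via `selective_state_update`; below is a rough pure-PyTorch reference of that recurrence, assuming a scalar decay `A` per head, grouped `B`/`C`, and `dt` already bias-added, softplus-ed, and clamped (as the op does before the kernel call). This is a semantics sketch, not how the kernel is invoked.

import torch


def ssm_decode_reference(state, x, dt, A, B, C, D):
    """Single-token SSM recurrence (Mamba-2 style), per head (illustrative).

    state: [B, H, P, N]  slot-gathered SSM state
    x:     [B, H, P]     token activations per head/dim
    dt:    [B, H, P]     step sizes, already softplus-ed and clamped
    A:     [H]           per-head decay (typically negative)
    B, C:  [B, G, N]     grouped input/output projections, H % G == 0
    D:     [H, P]        skip connection

    Returns (y, new_state) with y: [B, H, P].
    """
    heads_per_group = state.shape[1] // B.shape[1]
    # Broadcast grouped B/C to one tensor per head: [B, H, N]
    B_h = B.repeat_interleave(heads_per_group, dim=1)
    C_h = C.repeat_interleave(heads_per_group, dim=1)
    # Discretize: per-element decay and input contribution
    dA = torch.exp(dt * A[None, :, None])                    # [B, H, P]
    dBx = dt[..., None] * B_h[:, :, None, :] * x[..., None]  # [B, H, P, N]
    new_state = state * dA[..., None] + dBx
    # Read out with C and add the skip connection
    y = torch.einsum("bhpn,bhn->bhp", new_state, C_h) + D * x
    return y, new_state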
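Both cached ops treat the context/mixed phase the same way: activations arrive flattened as `[1, total_tokens, ...]`, and the `(seq_len, seq_start, slot_idx)` triple from the prepare_metadata op says where each sequence lives in the flat token dimension and which cache slot receives its final state. A sketch of that bookkeeping; `run_prefill` is a hypothetical stand-in for the TRT-LLM conv forward or the Triton chunk scan, and the sketch assumes the output has the same trailing shape as the input (true for both ops here).

import torch


def scatter_prefill_states(x_flat, seq_len, seq_start, slot_idx, state_cache, run_prefill):
    """Flattened context-phase bookkeeping shared by both cached ops (illustrative).

    x_flat:      [1, total_tokens, ...] flattened activations
    seq_len:     [num_seq] tokens per sequence
    seq_start:   [num_seq] offset of each sequence in the flat token dim
    slot_idx:    [num_seq] cache slot owned by each sequence
    state_cache: [max_batch_size, ...] slot-indexed state cache
    run_prefill: hypothetical callable, run_prefill(x_seq) -> (y_seq, final_state)
                 with x_seq/y_seq: [1, s_i, ...] and final_state: [1, ...].
    """
    y = torch.empty_like(x_flat)
    for i in range(seq_len.shape[0]):
        s_i = int(seq_len[i])
        if s_i == 0:
            continue
        start = int(seq_start[i])
        y_seq, final_state = run_prefill(x_flat[:, start : start + s_i])
        y[:, start : start + s_i] = y_seq
        # Each sequence writes its final state into its own cache slot.
        state_cache[int(slot_idx[i])] = final_state[0].to(state_cache.dtype)
    return y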
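Finally, the new grouped-attention patterns 6-9 canonicalize `torch_attention_sdpa` calls exported with `enable_gqa=True` into `torch_attention_grouped_sdpa`, which is why `torch_attention_sdpa` can be dropped from `sdpa_ops`. The rewrite is sound because GQA is equivalent to repeating the K/V heads up to the query head count; a short check against stock PyTorch illustrates the equivalence (requires a PyTorch version whose `scaled_dot_product_attention` accepts `enable_gqa`; shapes below are arbitrary).

import torch
import torch.nn.functional as F

# GQA with enable_gqa=True matches repeating the K/V heads up to the number of
# query heads; that equivalence is what makes rewriting
# sdpa(..., enable_gqa=True) into the grouped sdpa op safe.
b, n_q, n_kv, s, d = 2, 8, 2, 16, 64
q = torch.randn(b, n_q, s, d)
k = torch.randn(b, n_kv, s, d)
v = torch.randn(b, n_kv, s, d)

out_gqa = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)
out_rep = F.scaled_dot_product_attention(
    q,
    k.repeat_interleave(n_q // n_kv, dim=1),
    v.repeat_interleave(n_q // n_kv, dim=1),
    is_causal=True,
)
assert torch.allclose(out_gqa, out_rep, atol=1e-4)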