diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/__init__.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/__init__.py index ccf53d4a487..8098a619875 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/__init__.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/__init__.py @@ -8,3 +8,9 @@ for _, module_name, is_pkg in pkgutil.iter_modules(__path__): __all__.append(module_name) importlib.import_module(f"{__name__}.{module_name}") + +# Recursively import subpackages and modules so their side-effects (e.g., +# op registrations) are applied even when nested in subdirectories. +for _, full_name, _ in pkgutil.walk_packages(__path__, prefix=f"{__name__}."): + __all__.append(full_name) + importlib.import_module(full_name) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/__init__.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/cuda_backend_causal_conv.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py similarity index 93% rename from tensorrt_llm/_torch/auto_deploy/custom_ops/cuda_backend_causal_conv.py rename to tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py index b8e134be19f..014f8cc7e6b 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/cuda_backend_causal_conv.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py @@ -18,8 +18,8 @@ from tensorrt_llm._torch.modules.mamba import PAD_SLOT_ID from tensorrt_llm._torch.modules.mamba.causal_conv1d import causal_conv1d_fn, causal_conv1d_update -from ..utils.node_utils import extract_op_args -from .attention_interface import ( +from ...utils.node_utils import extract_op_args +from ..attention_interface import ( AttentionDescriptor, AttentionLayout, AttentionRegistry, @@ -74,8 +74,9 @@ def cuda_causal_conv_prepare_metadata( seq_start[1:] = torch.cumsum(seq_len_sanitized[:-1], 0) slot_idx_sanitized = slot_idx[:num_seq].clone().to(torch.long) - - return (seq_len_sanitized, seq_start, slot_idx_sanitized) + # This is only used during prefill to determine if we should use the initial states from the cache. + use_initial_states = input_pos > 0 + return (seq_len_sanitized, seq_start, slot_idx_sanitized, use_initial_states) @cuda_causal_conv_prepare_metadata.register_fake @@ -88,6 +89,7 @@ def cuda_causal_conv_prepare_metadata_fake( torch.empty_like(seq_len_sanitized), torch.empty_like(seq_len_sanitized), torch.empty(num_seq, dtype=torch.long, device=slot_idx.device), + torch.empty(num_seq, dtype=torch.bool, device=slot_idx.device), ) @@ -101,6 +103,7 @@ def _cuda_cached_causal_conv1d( seq_len: torch.Tensor, # [num_seq] seq_start: torch.Tensor, # [num_seq] slot_idx: torch.Tensor, # [num_seq] + use_initial_states: torch.Tensor, # [num_seq] # CACHES conv_state_cache: torch.Tensor, # [max_batch_size, c_in, k-1] # CONSTANTS @@ -161,7 +164,7 @@ def _cuda_cached_causal_conv1d( dim=0, ).contiguous() cache_indices = slot_idx[:num_prefill].to(torch.int32).contiguous() - has_initial_state = torch.zeros(num_prefill, dtype=torch.bool, device=input.device) + has_initial_state = use_initial_states[:num_prefill].to(torch.bool) # Run varlen conv; updates conv_state_cache in-place per cache_indices y_varlen = causal_conv1d_fn( @@ -215,6 +218,7 @@ def _cuda_cached_causal_conv1d_fake( seq_len: torch.Tensor, seq_start: torch.Tensor, slot_idx: torch.Tensor, + use_initial_states: torch.Tensor, # [num_seq] # CACHES conv_state_cache: torch.Tensor, # CONSTANTS @@ -256,8 +260,8 @@ def get_cached_attention_op(cls) -> MHACallable: @classmethod def get_prepare_metadata_op(cls) -> Tuple[PrepareMetadataCallable, int]: - # Returns (seq_len, seq_start, slot_idx) - return torch.ops.auto_deploy.cuda_causal_conv_prepare_metadata, 3 + # Returns (seq_len, seq_start, slot_idx, use_initial_states) + return torch.ops.auto_deploy.cuda_causal_conv_prepare_metadata, 4 @classmethod def get_cache_initializers( diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_causal_conv.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py similarity index 95% rename from tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_causal_conv.py rename to tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py index 6aaf5ecb405..a204c559f00 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_causal_conv.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py @@ -16,8 +16,8 @@ from torch._ops import OpOverloadPacket from torch.fx import Node -from ..utils.node_utils import extract_op_args -from .attention_interface import ( +from ...utils.node_utils import extract_op_args +from ..attention_interface import ( AttentionDescriptor, AttentionLayout, AttentionRegistry, @@ -160,8 +160,8 @@ def torch_causal_conv_prepare_metadata( seq_start[1:] = torch.cumsum(seq_len_sanitized[:-1], 0) slot_idx_sanitized = slot_idx[:num_seq].clone().to(torch.long) - - return (seq_len_sanitized, seq_start, slot_idx_sanitized) + use_initial_states = input_pos > 0 + return (seq_len_sanitized, seq_start, slot_idx_sanitized, use_initial_states) @torch_causal_conv_prepare_metadata.register_fake @@ -174,6 +174,7 @@ def torch_causal_conv_prepare_metadata_fake( torch.empty_like(seq_len_sanitized), torch.empty_like(seq_len_sanitized), torch.empty(num_seq, dtype=torch.long, device=slot_idx.device), + torch.empty(num_seq, dtype=torch.bool, device=slot_idx.device), ) @@ -187,6 +188,7 @@ def _torch_cached_causal_conv1d( seq_len: torch.Tensor, # [num_seq] seq_start: torch.Tensor, # [num_seq] slot_idx: torch.Tensor, # [num_seq] + use_initial_states: torch.Tensor, # [num_seq] # CACHES conv_state_cache: torch.Tensor, # [max_batch_size, c_in, k] # CONSTANTS @@ -275,6 +277,7 @@ def _torch_cached_causal_conv1d_fake( seq_len: torch.Tensor, seq_start: torch.Tensor, slot_idx: torch.Tensor, + use_initial_states: torch.Tensor, # [num_seq] # CACHES conv_state_cache: torch.Tensor, # CONSTANTS @@ -317,8 +320,10 @@ def get_cached_attention_op(cls) -> MHACallable: @classmethod def get_prepare_metadata_op(cls) -> Tuple[PrepareMetadataCallable, int]: + # TODO(https://github.com/NVIDIA/TensorRT-LLM/issues/8170): update torch + # reference implementation to support chunked prefill. # Returns (seq_len, seq_start, slot_idx) - return torch.ops.auto_deploy.torch_causal_conv_prepare_metadata, 3 + return torch.ops.auto_deploy.torch_causal_conv_prepare_metadata, 4 @classmethod def get_cache_initializers( diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_mamba.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_mamba.py similarity index 90% rename from tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_mamba.py rename to tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_mamba.py index 6bf7eb84d14..ccd24e7ec00 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_mamba.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_mamba.py @@ -12,8 +12,8 @@ from torch._ops import OpOverloadPacket from torch.fx import Node -from ..utils.node_utils import extract_op_args -from .attention_interface import ( +from ...utils.node_utils import extract_op_args +from ..attention_interface import ( AttentionDescriptor, AttentionLayout, AttentionRegistry, @@ -25,10 +25,10 @@ PrepareMetadataCallable, SequenceInfo, ) -from .torch_mamba import _torch_ssm_transform_prefill +from .torch_mamba import _torch_ssm_prefill -def _torch_cached_ssm_transform_decode( +def _torch_cached_ssm_decode( hidden_states: torch.Tensor, A: torch.Tensor, B: torch.Tensor, @@ -135,8 +135,10 @@ def _torch_ssm_prepare_metadata( # Truncate slot indices to match active sequences slot_idx_sanitized = slot_idx[:num_seq].clone().to(torch.long) - - return (seq_len_sanitized, seq_start, slot_idx_sanitized) + # TODO(https://github.com/NVIDIA/TensorRT-LLM/issues/8170): update torch + # reference implementation to support chunked prefill. + use_initial_states = input_pos > 0 + return (seq_len_sanitized, seq_start, slot_idx_sanitized, use_initial_states) @_torch_ssm_prepare_metadata.register_fake @@ -150,11 +152,12 @@ def _torch_ssm_prepare_metadata_fake( torch.empty_like(seq_len_sanitized), torch.empty_like(seq_len_sanitized), torch.empty(num_seq, dtype=torch.long, device=slot_idx.device), + torch.empty(num_seq, dtype=torch.bool, device=slot_idx.device), ) -@torch.library.custom_op("auto_deploy::torch_cached_ssm_transform", mutates_args={}) -def _torch_cached_ssm_transform( +@torch.library.custom_op("auto_deploy::torch_cached_ssm", mutates_args={}) +def _torch_cached_ssm( # INPUTS (dense but may be flattened across sequences) hidden_states: torch.Tensor, # [b, s, num_heads, head_dim] A: torch.Tensor, # [num_heads] @@ -167,6 +170,7 @@ def _torch_cached_ssm_transform( seq_len: torch.Tensor, # [num_seq] seq_start: torch.Tensor, # [num_seq] slot_idx: torch.Tensor, # [num_seq] + use_initial_states: torch.Tensor, # [num_seq] # CACHES ssm_state_cache: torch.Tensor, # [max_batch_size, num_heads, head_dim, ssm_state_size] # CONSTANTS @@ -188,7 +192,7 @@ def _torch_cached_ssm_transform( slot_idx_long = slot_idx.to(torch.long) ssm_batch = ssm_state_cache.index_select(dim=0, index=slot_idx_long) - y, updated_state = _torch_cached_ssm_transform_decode( + y, updated_state = _torch_cached_ssm_decode( hidden_states, A, B, @@ -207,6 +211,14 @@ def _torch_cached_ssm_transform( # return in the same dtype as the input return y.to(hidden_states.dtype) + # Prefill + if any(use_initial_states): + # TODO(https://github.com/NVIDIA/TensorRT-LLM/issues/8170): update torch + # reference implementation to support chunked prefill. + raise ValueError( + "torch mamba backend does not yet support chunked prefill " + "and can not correctly handle initial states." + ) # Context/mixed phase (flattened sequences). Expect b == 1, but handle general b robustly. # We'll iterate over sequences defined by (seq_len, seq_start) and update state per slot. # Process across the flattened second dimension. @@ -244,7 +256,7 @@ def _torch_cached_ssm_transform( dt_seq = dt_flat.index_select(0, idx_i).unsqueeze(0) # Run prefill and obtain final SSM state for this sequence - y_seq, ssm_state_i = _torch_ssm_transform_prefill( + y_seq, ssm_state_i = _torch_ssm_prefill( hs_seq, A, B_seq, C_seq, D, dt_seq, dt_bias, time_step_limit, chunk_size ) @@ -258,8 +270,8 @@ def _torch_cached_ssm_transform( return y -@_torch_cached_ssm_transform.register_fake -def _torch_cached_ssm_transform_fake( +@_torch_cached_ssm.register_fake +def _torch_cached_ssm_fake( # INPUTS hidden_states: torch.Tensor, A: torch.Tensor, @@ -272,6 +284,7 @@ def _torch_cached_ssm_transform_fake( seq_len: torch.Tensor, seq_start: torch.Tensor, slot_idx: torch.Tensor, + use_initial_states: torch.Tensor, # CACHES ssm_state_cache: torch.Tensor, # CONSTANTS @@ -304,16 +317,16 @@ def get_num_qkv_args(cls) -> int: @classmethod def get_source_attention_op(cls) -> OpOverloadPacket: - return torch.ops.auto_deploy.torch_ssm_transform + return torch.ops.auto_deploy.torch_ssm @classmethod def get_cached_attention_op(cls) -> MHACallable: - return torch.ops.auto_deploy.torch_cached_ssm_transform + return torch.ops.auto_deploy.torch_cached_ssm @classmethod def get_prepare_metadata_op(cls) -> Tuple[PrepareMetadataCallable, int]: # Returns (seq_len, seq_start, slot_idx) - return torch.ops.auto_deploy.torch_ssm_prepare_metadata, 3 + return torch.ops.auto_deploy.torch_ssm_prepare_metadata, 4 @classmethod def get_cache_initializers( diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_causal_conv.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_causal_conv.py similarity index 100% rename from tensorrt_llm/_torch/auto_deploy/custom_ops/torch_causal_conv.py rename to tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_causal_conv.py diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_mamba.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_mamba.py similarity index 95% rename from tensorrt_llm/_torch/auto_deploy/custom_ops/torch_mamba.py rename to tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_mamba.py index 7deeced93eb..752520a74ae 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_mamba.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_mamba.py @@ -67,7 +67,7 @@ def _segment_sum(input_tensor): return tensor_segsum -def _torch_ssm_transform_prefill( +def _torch_ssm_prefill( hidden_states: torch.Tensor, A: torch.Tensor, B: torch.Tensor, @@ -162,8 +162,8 @@ def _torch_ssm_transform_prefill( return y, ssm_state -@torch.library.custom_op("auto_deploy::torch_ssm_transform", mutates_args={}) -def _torch_ssm_transform( +@torch.library.custom_op("auto_deploy::torch_ssm", mutates_args={}) +def _torch_ssm( hidden_states: torch.Tensor, A: torch.Tensor, B: torch.Tensor, @@ -176,14 +176,12 @@ def _torch_ssm_transform( ], # NOTE: `torch` custom ops do not like `Tuple` inputs. Using `List` is the suggested WAR. chunk_size: int, ) -> torch.Tensor: - y, _ = _torch_ssm_transform_prefill( - hidden_states, A, B, C, D, dt, dt_bias, time_step_limit, chunk_size - ) + y, _ = _torch_ssm_prefill(hidden_states, A, B, C, D, dt, dt_bias, time_step_limit, chunk_size) return y -@_torch_ssm_transform.register_fake -def _torch_ssm_transform_meta( +@_torch_ssm.register_fake +def _torch_ssm_meta( hidden_states: torch.Tensor, A: torch.Tensor, B: torch.Tensor, diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_backend_mamba.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py similarity index 87% rename from tensorrt_llm/_torch/auto_deploy/custom_ops/triton_backend_mamba.py rename to tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py index 9cf141ce24d..9edf1ce6836 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_backend_mamba.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py @@ -5,11 +5,12 @@ from torch.fx import Node # Triton kernels +from tensorrt_llm._torch.modules.mamba.mamba2_metadata import cu_seqlens_to_chunk_indices_offsets from tensorrt_llm._torch.modules.mamba.selective_state_update import selective_state_update from tensorrt_llm._torch.modules.mamba.ssd_combined import mamba_chunk_scan_combined -from ..utils.node_utils import extract_op_args -from .attention_interface import ( +from ...utils.node_utils import extract_op_args +from ..attention_interface import ( AttentionDescriptor, AttentionLayout, AttentionRegistry, @@ -23,8 +24,8 @@ ) -@torch.library.custom_op("auto_deploy::triton_cached_ssm_transform", mutates_args={}) -def _triton_cached_ssm_transform( +@torch.library.custom_op("auto_deploy::triton_cached_ssm", mutates_args={}) +def _triton_cached_ssm( # INPUTS (dense but may be flattened across sequences) hidden_states: torch.Tensor, # [b, s, num_heads, head_dim] A: torch.Tensor, # [num_heads] @@ -37,6 +38,7 @@ def _triton_cached_ssm_transform( seq_len: torch.Tensor, # [num_seq] seq_start: torch.Tensor, # [num_seq] slot_idx: torch.Tensor, # [num_seq] + use_initial_states: torch.Tensor, # [num_seq] # CACHES ssm_state_cache: torch.Tensor, # [max_batch_size, num_heads, head_dim, ssm_state_size] # CONSTANTS @@ -51,7 +53,6 @@ def _triton_cached_ssm_transform( """ b, s = hidden_states.shape[:2] num_seq = seq_len.shape[0] - # Flatten tokens for indexing/scatter bs = b * s device = hidden_states.device @@ -96,6 +97,16 @@ def _triton_cached_ssm_transform( seq_ids = torch.arange(num_prefill, device=device, dtype=torch.int32) seq_idx_prefill = torch.repeat_interleave(seq_ids, seq_len_prefill).view(1, -1) + initial_states = chunk_indices = chunk_offsets = None + if torch.any(use_initial_states[:num_prefill]): + initial_states = torch.where( + use_initial_states[:num_prefill, None, None, None], + ssm_state_cache[slot_idx[:num_prefill]], + 0, + ) + chunk_indices, chunk_offsets = cu_seqlens_to_chunk_indices_offsets( + cu_seqlens, chunk_size + ) y_prefill, varlen_states = mamba_chunk_scan_combined( hs_prefill, dt_prefill, @@ -106,10 +117,10 @@ def _triton_cached_ssm_transform( D=D, z=None, dt_bias=dt_bias, - initial_states=None, + initial_states=initial_states, seq_idx=seq_idx_prefill, - chunk_indices=None, - chunk_offsets=None, + chunk_indices=chunk_indices, + chunk_offsets=chunk_offsets, cu_seqlens=cu_seqlens, dt_softplus=True, dt_limit=(time_step_limit[0], time_step_limit[1]), @@ -159,8 +170,8 @@ def _triton_cached_ssm_transform( return y -@_triton_cached_ssm_transform.register_fake -def _triton_cached_ssm_transform_fake( +@_triton_cached_ssm.register_fake +def _triton_cached_ssm_fake( # INPUTS (dense but may be flattened across sequences) hidden_states: torch.Tensor, # [b, s, num_heads, head_dim] A: torch.Tensor, # [num_heads] @@ -173,6 +184,7 @@ def _triton_cached_ssm_transform_fake( seq_len: torch.Tensor, # [num_seq] seq_start: torch.Tensor, # [num_seq] slot_idx: torch.Tensor, # [num_seq] + use_initial_states: torch.Tensor, # [num_seq] # CACHES ssm_state_cache: torch.Tensor, # [max_batch_size, num_heads, head_dim, ssm_state_size] # CONSTANTS @@ -209,16 +221,16 @@ def get_num_qkv_args(cls) -> int: @classmethod def get_source_attention_op(cls) -> OpOverloadPacket: # Keep source op unchanged (used for uncached pre-export) - return torch.ops.auto_deploy.torch_ssm_transform + return torch.ops.auto_deploy.torch_ssm @classmethod def get_cached_attention_op(cls) -> MHACallable: - return torch.ops.auto_deploy.triton_cached_ssm_transform + return torch.ops.auto_deploy.triton_cached_ssm @classmethod def get_prepare_metadata_op(cls) -> Tuple[PrepareMetadataCallable, int]: - # Returns (seq_len, seq_start, slot_idx) - return torch.ops.auto_deploy.torch_ssm_prepare_metadata, 3 + # Returns (seq_len, seq_start, slot_idx, use_initial_states) + return torch.ops.auto_deploy.torch_ssm_prepare_metadata, 4 @classmethod def get_cache_initializers( diff --git a/tensorrt_llm/_torch/auto_deploy/distributed/common.py b/tensorrt_llm/_torch/auto_deploy/distributed/common.py index dec42d8386b..1585d1b0da5 100644 --- a/tensorrt_llm/_torch/auto_deploy/distributed/common.py +++ b/tensorrt_llm/_torch/auto_deploy/distributed/common.py @@ -101,6 +101,10 @@ def is_torchelastic(): return "TORCHELASTIC_RUN_ID" in os.environ +def is_initialized(): + return dist.is_initialized() + + def cleanup(): """Destroy process group when the program exits.""" if dist.is_initialized(): diff --git a/tensorrt_llm/_torch/auto_deploy/models/patches/bamba.py b/tensorrt_llm/_torch/auto_deploy/models/patches/bamba.py index 03c322ef59d..85e997c615b 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/patches/bamba.py +++ b/tensorrt_llm/_torch/auto_deploy/models/patches/bamba.py @@ -48,7 +48,7 @@ def _bamba_mixer_torch_forward( 0, batch_size * seq_len, seq_len, device=input_states.device, dtype=torch.int ) slot_idx_t = torch.arange(batch_size, device=input_states.device, dtype=torch.long) - + use_initial_states_t = torch.zeros(batch_size, device=input_states.device, dtype=torch.bool) if use_caching: hidden_states_B_C = self.act( torch.ops.auto_deploy.torch_cached_causal_conv1d( @@ -60,6 +60,7 @@ def _bamba_mixer_torch_forward( seq_len_t, seq_start_t, slot_idx_t, + use_initial_states_t, # CACHES cache_params.conv_states[self.layer_idx], # CONSTANTS @@ -100,7 +101,7 @@ def _bamba_mixer_torch_forward( if use_caching: # Use new flattened cached op for both cache updates and outputs - y = torch.ops.auto_deploy.torch_cached_ssm_transform( + y = torch.ops.auto_deploy.torch_cached_ssm( # INPUTS hidden_states=hidden_states.view(batch_size, seq_len, -1, self.head_dim), A=A, @@ -113,6 +114,7 @@ def _bamba_mixer_torch_forward( seq_len=seq_len_t, seq_start=seq_start_t, slot_idx=slot_idx_t, + use_initial_states=use_initial_states_t, # CACHES ssm_state_cache=cache_params.ssm_states[self.layer_idx], # CONSTANTS @@ -120,7 +122,7 @@ def _bamba_mixer_torch_forward( chunk_size=self.chunk_size, ) else: - y = torch.ops.auto_deploy.torch_ssm_transform( + y = torch.ops.auto_deploy.torch_ssm( hidden_states=hidden_states.view(batch_size, seq_len, -1, self.head_dim), A=A, B=B.view(batch_size, seq_len, -1, self.ssm_state_size), diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py b/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py index 6116a56c791..8aea96a4c4d 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py @@ -10,6 +10,7 @@ from ...custom_ops.attention_interface import AttentionDescriptor, AttentionRegistry, Constant from ...distributed.common import all_gather_object, get_world_size +from ...distributed.common import is_initialized as is_distributed_initialized from ...models.factory import ModelFactory from ...shim.interface import CachedSequenceInterface from ...utils._graph import add_graph_input @@ -283,12 +284,15 @@ def _get_mem_info_in_mb(): new_cache_size = free_mem_post * 1024 * 1024 * free_mem_ratio + current_cache_size new_num_pages = int(new_cache_size // (current_cache_size // current_num_pages)) - # Need to sync all the GPUs - gathered_num_pages = [None] * get_world_size() - all_gather_object(gathered_num_pages, new_num_pages) - new_num_pages = min(gathered_num_pages) - self._log_info(f"After all_gather - new_num_pages: {new_num_pages}") + # Need to sync all the GPUs if distributed group is initialized + log_msg = f"Using local new_num_pages: {new_num_pages}" + if is_distributed_initialized(): + gathered_num_pages = [None] * get_world_size() + all_gather_object(gathered_num_pages, new_num_pages) + new_num_pages = min(gathered_num_pages) + log_msg = f"After all_gather - new_num_pages: {new_num_pages}" + self._log_info(log_msg) cm.resize_cache(new_num_pages) # Log the final cache size for performance measurement, do not remove this log. @@ -326,5 +330,4 @@ def _apply_to_full_model( info = TransformInfo( skipped=False, num_matches=num_caches, is_clean=True, has_valid_shapes=True ) - return mod, info diff --git a/tests/integration/defs/accuracy/test_llm_api_autodeploy.py b/tests/integration/defs/accuracy/test_llm_api_autodeploy.py index 338c04608c7..20862a946fc 100644 --- a/tests/integration/defs/accuracy/test_llm_api_autodeploy.py +++ b/tests/integration/defs/accuracy/test_llm_api_autodeploy.py @@ -118,7 +118,7 @@ def get_default_kwargs(self, enable_chunked_prefill=False): "free_mem_ratio": 0.7 }, "compile_model": { - "backend": "torch-opt", + "backend": "torch-cudagraph", "cuda_graph_batch_sizes": [1, 2, 4, 8, 16, 32, 64, 128], }, }, @@ -140,9 +140,6 @@ def get_default_sampling_params(self): @pytest.mark.skip_less_device_memory(32000) @pytest.mark.parametrize("enable_chunked_prefill", [False, True]) def test_auto_dtype(self, enable_chunked_prefill): - if enable_chunked_prefill: - pytest.skip( - "see https://github.com/NVIDIA/TensorRT-LLM/issues/8272") kwargs = self.get_default_kwargs(enable_chunked_prefill) sampling_params = self.get_default_sampling_params() with AutoDeployLLM(model=self.MODEL_PATH, diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_cuda_causal_conv_cached_op.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_cuda_causal_conv_cached_op.py index 7ffb1709cb6..2c9e4a70720 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_cuda_causal_conv_cached_op.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_cuda_causal_conv_cached_op.py @@ -59,7 +59,7 @@ def test_generate_only_with_slot_mapping_cuda(conv_env): # Metadata (not used in generate-only op entry, but required by the interface) seq_len = torch.ones(batch, device=device, dtype=torch.int32) seq_start = torch.zeros(batch, device=device, dtype=torch.int32) - + use_initial_states = torch.zeros(batch, device=device, dtype=torch.bool) # Snapshot caches for reference before running op (op mutates caches) gathered_before = conv_state_cache.clone().index_select(0, slot_idx) x_ref = x.clone() @@ -73,6 +73,7 @@ def test_generate_only_with_slot_mapping_cuda(conv_env): seq_len, seq_start, slot_idx, + use_initial_states, # CACHES conv_state_cache, # CONSTANTS @@ -194,7 +195,7 @@ def test_prepare_metadata_cuda(conv_env): slot_idx, page_size, ) - assert len(out) == 3 - seq_len_s, seq_start, slot_s = out + assert len(out) == 4 + seq_len_s, seq_start, slot_s, use_initial_states = out assert seq_len_s.numel() == 2 and slot_s.numel() == 2 assert torch.all(seq_start == torch.tensor([0, 2], device=device, dtype=seq_start.dtype)) diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_causal_conv_cached_op.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_causal_conv_cached_op.py index 4090821e252..3988595346b 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_causal_conv_cached_op.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_causal_conv_cached_op.py @@ -56,10 +56,9 @@ def test_generate_only_with_slot_mapping(conv_env): # Metadata (not used in generate-only op entry, but required by the interface) seq_len = torch.ones(batch, device=device, dtype=torch.int32) seq_start = torch.zeros(batch, device=device, dtype=torch.int32) - # Snapshot caches for reference before running op (op mutates caches) gathered_before = conv_state_cache.clone().index_select(0, slot_idx) - + use_initial_states = torch.zeros(batch, device=device, dtype=torch.bool) # Run cached op y = torch.ops.auto_deploy.torch_cached_causal_conv1d( # INPUTS @@ -70,6 +69,7 @@ def test_generate_only_with_slot_mapping(conv_env): seq_len, seq_start, slot_idx, + use_initial_states, # CACHES conv_state_cache, # CONSTANTS @@ -85,7 +85,7 @@ def test_generate_only_with_slot_mapping(conv_env): # Reference: use pre-op gathered states, run decode helper directly, compare y_ref, updated = ( - tensorrt_llm._torch.auto_deploy.custom_ops.torch_backend_causal_conv._torch_causal_conv1d_decode( # type: ignore # noqa: E501 + tensorrt_llm._torch.auto_deploy.custom_ops.mamba.torch_backend_causal_conv._torch_causal_conv1d_decode( # type: ignore # noqa: E501 x, w, b, s, p, d, g, pm, gathered_before ) ) @@ -119,7 +119,7 @@ def test_context_flattened_and_state_writeback(conv_env): seq_len = torch.tensor(lens, device=device, dtype=torch.int32) seq_start = torch.tensor([0, lens[0]], device=device, dtype=torch.int32) - + use_initial_states = torch.zeros(batch, device=device, dtype=torch.bool) y = torch.ops.auto_deploy.torch_cached_causal_conv1d( # INPUTS x, @@ -129,6 +129,7 @@ def test_context_flattened_and_state_writeback(conv_env): seq_len, seq_start, slot_idx, + use_initial_states, # CACHES conv_state_cache, # CONSTANTS @@ -148,7 +149,7 @@ def test_context_flattened_and_state_writeback(conv_env): st = 0 if i == 0 else lens[0] x_i = x[:, st : st + ln] y_i, s_i = ( - tensorrt_llm._torch.auto_deploy.custom_ops.torch_backend_causal_conv._torch_causal_conv1d_prefill( # type: ignore # noqa: E501 + tensorrt_llm._torch.auto_deploy.custom_ops.mamba.torch_backend_causal_conv._torch_causal_conv1d_prefill( # type: ignore # noqa: E501 x_i, w, b, s, p, d, g, pm ) ) @@ -186,7 +187,7 @@ def test_prepare_metadata(conv_env): slot_idx, page_size, ) - assert len(out) == 3 - seq_len_s, seq_start, slot_s = out + assert len(out) == 4 + seq_len_s, seq_start, slot_s, use_initial_states = out assert seq_len_s.numel() == 2 and slot_s.numel() == 2 assert torch.all(seq_start == torch.tensor([0, 2], device=device, dtype=seq_start.dtype)) diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_mamba_cached_op.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_mamba_cached_op.py index 3000880d435..c49d6cbd35b 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_mamba_cached_op.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_mamba_cached_op.py @@ -65,12 +65,12 @@ def test_generate_only_with_slot_mapping(mamba_env): # Metadata seq_len = torch.ones(batch, device=device, dtype=torch.int32) seq_start = torch.zeros(batch, device=device, dtype=torch.int32) - + use_initial_states = torch.zeros(batch, device=device, dtype=torch.bool) # Snapshot caches for reference before running op (op mutates caches) gathered_before = ssm_state_cache.clone().index_select(0, slot_idx) # Run cached op - y = torch.ops.auto_deploy.torch_cached_ssm_transform( + y = torch.ops.auto_deploy.torch_cached_ssm( # INPUTS hidden_states, A, @@ -83,6 +83,7 @@ def test_generate_only_with_slot_mapping(mamba_env): seq_len, seq_start, slot_idx, + use_initial_states, # CACHES ssm_state_cache, # CONSTANTS @@ -95,7 +96,7 @@ def test_generate_only_with_slot_mapping(mamba_env): # Reference: use pre-op gathered states, run decode helper directly, compare y_ref, updated = ( - tensorrt_llm._torch.auto_deploy.custom_ops.torch_backend_mamba._torch_cached_ssm_transform_decode( # type: ignore # noqa: E501 + tensorrt_llm._torch.auto_deploy.custom_ops.mamba.torch_backend_mamba._torch_cached_ssm_decode( # type: ignore # noqa: E501 hidden_states, A, B, C, D, dt, dt_bias, time_step_limit, chunk_size, gathered_before ) ) @@ -135,8 +136,8 @@ def test_context_flattened_and_state_writeback(mamba_env): seq_len = torch.tensor(lens, device=device, dtype=torch.int32) seq_start = torch.tensor([0, lens[0]], device=device, dtype=torch.int32) - - y = torch.ops.auto_deploy.torch_cached_ssm_transform( + use_initial_states = torch.zeros(batch, device=device, dtype=torch.bool) + y = torch.ops.auto_deploy.torch_cached_ssm( # INPUTS hidden_states, A, @@ -149,6 +150,7 @@ def test_context_flattened_and_state_writeback(mamba_env): seq_len, seq_start, slot_idx, + use_initial_states, # CACHES ssm_state_cache, # CONSTANTS @@ -167,10 +169,8 @@ def test_context_flattened_and_state_writeback(mamba_env): Bb = B[:, st : st + ln] Cb = C[:, st : st + ln] dtb = dt[:, st : st + ln] - y_i, s_i = ( - tensorrt_llm._torch.auto_deploy.custom_ops.torch_mamba._torch_ssm_transform_prefill( # type: ignore # noqa: E501 - hs, A, Bb, Cb, D, dtb, dt_bias, time_step_limit, chunk_size - ) + y_i, s_i = tensorrt_llm._torch.auto_deploy.custom_ops.mamba.torch_mamba._torch_ssm_prefill( # type: ignore # noqa: E501 + hs, A, Bb, Cb, D, dtb, dt_bias, time_step_limit, chunk_size ) y_ref[:, st : st + ln].copy_(y_i) # Cache should hold final state at slot @@ -202,7 +202,7 @@ def test_prepare_metadata(mamba_env): page_size, ) # Returns a list of tensors from custom op API - assert len(out) == 3 - seq_len_s, seq_start, slot_s = out + assert len(out) == 4 + seq_len_s, seq_start, slot_s, use_initial_states = out assert seq_len_s.numel() == 2 and slot_s.numel() == 2 assert torch.all(seq_start == torch.tensor([0, 2], device=device, dtype=seq_start.dtype)) diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_triton_mamba_cached_op.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_triton_mamba_cached_op.py index 6cce60f1684..938f33d85c4 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_triton_mamba_cached_op.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_triton_mamba_cached_op.py @@ -53,7 +53,7 @@ def test_triton_generate_only_with_slot_mapping(mamba_env): seq_start = torch.zeros(batch, device=device, dtype=torch.int32) # Torch reference - y_torch = torch.ops.auto_deploy.torch_cached_ssm_transform( + y_torch = torch.ops.auto_deploy.torch_cached_ssm( hidden_states, A, B, @@ -70,7 +70,7 @@ def test_triton_generate_only_with_slot_mapping(mamba_env): ) # Triton under test - y_triton = torch.ops.auto_deploy.triton_cached_ssm_transform( + y_triton = torch.ops.auto_deploy.triton_cached_ssm( hidden_states, A, B, @@ -122,9 +122,9 @@ def test_triton_context_flattened_and_state_writeback(mamba_env): seq_len = torch.tensor(lens, device=device, dtype=torch.int32) seq_start = torch.tensor([0, lens[0]], device=device, dtype=torch.int32) - + use_initial_states = torch.tensor([0] * batch, device=device).to(torch.bool) # Torch reference - y_torch = torch.ops.auto_deploy.torch_cached_ssm_transform( + y_torch = torch.ops.auto_deploy.torch_cached_ssm( hidden_states, A, B, @@ -135,13 +135,14 @@ def test_triton_context_flattened_and_state_writeback(mamba_env): seq_len, seq_start, slot_idx, + use_initial_states, ssm_state_cache_torch, time_step_limit, chunk_size, ) # Triton under test - y_triton = torch.ops.auto_deploy.triton_cached_ssm_transform( + y_triton = torch.ops.auto_deploy.triton_cached_ssm( hidden_states, A, B, @@ -152,6 +153,7 @@ def test_triton_context_flattened_and_state_writeback(mamba_env): seq_len, seq_start, slot_idx, + use_initial_states, ssm_state_cache_triton, time_step_limit, chunk_size, diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py index 416fe584785..241e159c770 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py @@ -190,8 +190,6 @@ def _check_ad_config(experiment_config: ExperimentConfig, llm_args: LlmArgs): ], ) def test_build_ad(model_hub_id: str, llm_extra_args: dict): - if model_hub_id == "meta-llama/Meta-Llama-3.1-8B-Instruct": - pytest.skip("https://nvbugs/5595652") experiment_config = get_small_model_config(model_hub_id, **llm_extra_args) experiment_config["args"]["runtime"] = "demollm" # Default runtime set to demollm experiment_config["args"]["world_size"] = 0 # Default world_size set to 0