diff --git a/examples/auto_deploy/.gitignore b/examples/auto_deploy/.gitignore
index 9836a37fc88..36cf5a6dd49 100644
--- a/examples/auto_deploy/.gitignore
+++ b/examples/auto_deploy/.gitignore
@@ -4,3 +4,4 @@ benchmark_results.json
 *.png
 # ignore config files that users might put here for debugging
 *.yaml
+!nano_v3.yaml
diff --git a/examples/auto_deploy/nano_v3.yaml b/examples/auto_deploy/nano_v3.yaml
new file mode 100644
index 00000000000..411037cc175
--- /dev/null
+++ b/examples/auto_deploy/nano_v3.yaml
@@ -0,0 +1,23 @@
+runtime: trtllm
+compile_backend: torch-cudagraph
+max_batch_size: 384
+max_seq_len: 65536 # tunable
+enable_chunked_prefill: true
+attn_backend: flashinfer
+model_factory: AutoModelForCausalLM
+skip_loading_weights: false
+free_mem_ratio: 0.9
+cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 24, 32, 64, 128, 256, 320, 384]
+kv_cache_config:
+  # disable kv_cache reuse since not supported for hybrid/ssm models
+  enable_block_reuse: false
+transforms:
+  detect_sharding:
+    sharding_source: ['factory', 'heuristic']
+    sharding_dims: ['ep', 'bmm']
+  # tunable mamba cache dtype
+  # --> use float32 for accuracy and default (null) for speed
+  insert_cached_ssm_attention:
+    cache_config:
+      # mamba_dtype: float32
+      mamba_dtype: null
diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py
index 02f7001cff0..0add719af09 100644
--- a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py
+++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py
@@ -10,10 +10,10 @@
 """
 
 from abc import ABC, abstractmethod
-from dataclasses import dataclass
 from typing import Dict, List, Literal, Optional, Protocol, Sequence, Set, Tuple, Type, Union
 
 import torch
+from pydantic import BaseModel, ConfigDict, Field, field_validator
 from torch._ops import OpOverloadPacket
 from torch.fx import Node
 from torch.types import Number
@@ -24,11 +24,39 @@
 Constant = Union[int, float, str, None]
 
 
-@dataclass
-class CacheConfig:
-    """A dataclass to hold information how to configure the cache."""
+class CacheConfig(BaseModel):
+    """Cache configuration for attention-related dtypes."""
 
-    dtype: Optional[torch.dtype] = None
+    model_config = ConfigDict(
+        arbitrary_types_allowed=True,
+        extra="forbid",
+    )
+
+    dtype: Optional[torch.dtype] = Field(default=None, description="KV cache dtype.")
+    mamba_dtype: Optional[torch.dtype] = Field(default=None, description="Mamba cache dtype.")
+
+    @field_validator("dtype", "mamba_dtype", mode="before")
+    @classmethod
+    def _coerce_dtype(cls, value):
+        if value is None or isinstance(value, torch.dtype):
+            return value
+        if isinstance(value, str):
+            dtype = getattr(torch, value, None)
+            assert isinstance(dtype, torch.dtype), f"Invalid {dtype=}"
+            return dtype
+        return value
+
+    def __or__(self, other: "CacheConfig") -> "CacheConfig":
+        """Combine two CacheConfig objects field-wise using Python's `or` semantics.
+
+        For each field, selects the first non-None value between `self` and `other`.
+ """ + if not isinstance(other, CacheConfig): + raise NotImplementedError(f"Cannot combine CacheConfig with {type(other)}") + merged_kwargs = {} + for field_name in type(self).model_fields.keys(): + merged_kwargs[field_name] = getattr(self, field_name) or getattr(other, field_name) + return CacheConfig(**merged_kwargs) class SequenceInfo: diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_mamba.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_mamba.py index ccd24e7ec00..8be12569ae3 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_mamba.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_mamba.py @@ -347,6 +347,9 @@ def get_cache_initializers( # Fallback: assume last dim is n_groups * state_size and choose a minimal positive size ssm_state_size = max(1, B_fake.shape[-1]) + # extract ssm_state_dtype from cache_config or hs_fake + ssm_state_dtype = cache_config.mamba_dtype or hs_fake.dtype + def _get_ssm_cache(si: SequenceInfo): return torch.empty( si.max_batch_size, @@ -354,7 +357,7 @@ def _get_ssm_cache(si: SequenceInfo): head_dim, ssm_state_size, device=si.device, - dtype=cache_config.dtype or hs_fake.dtype, + dtype=ssm_state_dtype, ) return {"ssm_state_cache": _get_ssm_cache} diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py index 64b62419162..1271244ac72 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py @@ -125,6 +125,7 @@ def _triton_cached_ssm( dt_limit=(time_step_limit[0], time_step_limit[1]), return_final_states=False, return_varlen_states=True, + mamba_ssm_cache_dtype=ssm_state_cache.dtype, ) y_flat[:total_prefill_tokens] = y_prefill[0].to(y_flat.dtype) @@ -198,9 +199,7 @@ def _triton_cached_ssm_fake( ) -## Note: we reuse the existing metadata custom op and its registered fake from torch backend. 
-
-
+# TODO: consider inheriting from TorchBackendSSM instead of redefining everything
 @AttentionRegistry.register("triton_ssm")
 class TritonBackendSSM(AttentionDescriptor):
     @classmethod
@@ -247,6 +246,9 @@ def get_cache_initializers(
         else:
             ssm_state_size = max(1, B_fake.shape[-1])
 
+        # extract ssm_state_dtype from cache_config or hs_fake
+        ssm_state_dtype = cache_config.mamba_dtype or hs_fake.dtype
+
         def _get_ssm_cache(si: SequenceInfo):
             return torch.empty(
                 si.max_batch_size,
@@ -254,7 +256,7 @@ def _get_ssm_cache(si: SequenceInfo):
                 head_dim,
                 ssm_state_size,
                 device=si.device,
-                dtype=cache_config.dtype or hs_fake.dtype,
+                dtype=ssm_state_dtype,
             )
 
         return {"ssm_state_cache": _get_ssm_cache}
diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py b/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py
index 5a2b8485d6b..ecf42d0b238 100644
--- a/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py
+++ b/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py
@@ -8,7 +8,12 @@
 from pydantic import Field
 from torch.fx import GraphModule, Node
 
-from ...custom_ops.attention_interface import AttentionDescriptor, AttentionRegistry, Constant
+from ...custom_ops.attention_interface import (
+    AttentionDescriptor,
+    AttentionRegistry,
+    CacheConfig,
+    Constant,
+)
 from ...distributed.common import all_gather_object, get_world_size
 from ...distributed.common import is_initialized as is_distributed_initialized
 from ...models.factory import ModelFactory
@@ -66,6 +71,9 @@ class InsertCachedAttentionConfig(TransformConfig):
     """Configuration for the insert cached attention transform."""
 
     backend: Optional[str] = Field(default=None, description="The attention backend to use.")
+    cache_config: CacheConfig = Field(
+        default_factory=CacheConfig, description="The custom cache configuration to use."
+    )
 
 
 @TransformRegistry.register("insert_cached_attention")
@@ -137,7 +145,9 @@ def _apply(
         """Replace uncached source attention node with corresponding cached attn node."""
         attn_descriptor = self.attn_descriptor
 
-        cache_config = factory.get_cache_config()
+        # run field-wise or to combine the cache config from the transform and the factory
+        # the transform config takes precedence over the factory config
+        cache_config = self.config.cache_config | factory.get_cache_config()
 
         # Get all attention nodes and their info objects
         source_op = attn_descriptor.get_source_attention_op()
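Illustrative sketch (not part of the patch) of the new CacheConfig semantics from attention_interface.py: the field validator coerces dtype strings such as the YAML value "float32" into torch dtypes, and `|` performs the field-wise merge used in kvcache.py, where the transform's cache config takes precedence over the factory's. The concrete dtype values below are made up for illustration.

import torch

from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import CacheConfig

# Transform-level config, e.g. parsed from nano_v3.yaml; the validator coerces the string.
transform_cfg = CacheConfig(mamba_dtype="float32")
assert transform_cfg.mamba_dtype is torch.float32

# Factory-level config, e.g. derived from the checkpoint dtype (illustrative values).
factory_cfg = CacheConfig(dtype=torch.bfloat16, mamba_dtype=torch.bfloat16)

# Field-wise merge: for each field, the left-hand (transform) value wins when it is not None.
merged = transform_cfg | factory_cfg
assert merged.dtype is torch.bfloat16        # only the factory set it
assert merged.mamba_dtype is torch.float32   # transform overrides the factory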