diff --git a/jenkins/L0_MergeRequest.groovy b/jenkins/L0_MergeRequest.groovy
index 3e81b22a099..7820c7267c1 100644
--- a/jenkins/L0_MergeRequest.groovy
+++ b/jenkins/L0_MergeRequest.groovy
@@ -719,6 +719,7 @@ def getMultiGpuFileChanged(pipeline, testFilter, globalVars)
         "tensorrt_llm/_torch/pyexecutor/_util.py",
         "tensorrt_llm/_torch/pyexecutor/model_engine.py",
         "tensorrt_llm/_torch/pyexecutor/py_executor.py",
+        "tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py",
         "tensorrt_llm/evaluate/json_mode_eval.py",
         "tensorrt_llm/evaluate/mmlu.py",
         "tensorrt_llm/executor/",
@@ -740,6 +741,7 @@ def getMultiGpuFileChanged(pipeline, testFilter, globalVars)
         "tests/integration/defs/accuracy/test_disaggregated_serving.py",
         "tests/unittest/_torch/ray_orchestrator/multi_gpu/",
         "tests/integration/defs/examples/test_ray.py",
+        "tests/integration/defs/accuracy/test_llm_api_autodeploy.py",
         "tests/unittest/llmapi/test_async_llm.py",
     ]

diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
index 03eb4a07030..987b85277d1 100644
--- a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
+++ b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
@@ -267,7 +267,7 @@ class WeightShardingInfo(ShardingTransformInfo):
     min_local_shape: int = 1
     layer_type: LayerType = LayerType.MLP
     # used for TP sharding of fused weights
-    fused_weight_dims: Optional[list] = None
+    fused_weight_dims: Optional[tuple] = None

     def quantization_cb(
         self,
@@ -1316,7 +1316,7 @@ def _shard_parameter_node(
     config: ShardingTransformConfig,
     add_dist: bool = False,
     min_local_shape: int = 1,
-    fused_weight_dims: Optional[list] = None,
+    fused_weight_dims: Optional[tuple] = None,
     quantization_cb: Optional[
         Callable[[GraphModule, nn.Module, Node, str, torch.Size, int, int, int], None]
     ] = None,
@@ -1835,7 +1835,7 @@ def _process_ssm_sharding(
                 config=config,
                 dist_op=None,
                 min_local_shape=1,
-                fused_weight_dims=fused_weight_dims["in_proj"],
+                fused_weight_dims=tuple(fused_weight_dims["in_proj"]),
                 layer_type=LayerType.SSM,
             )
         ):
@@ -1904,7 +1904,7 @@ def _process_ssm_sharding(
             fused_dims = None
             for k, v in fused_weight_dims.items():
                 if k in weight_key:
-                    fused_dims = v
+                    fused_dims = tuple(v)
                     break

             # Shard the weight tensor (also updates the parameter in the module)
@@ -2089,7 +2089,7 @@ def _determine_fused_weight_dims(
             ad_logger.warning(
                 f"Fused weight dims {fused_weight_dims} do not sum to weight dim {weight_dim}. Skipping."
             )
-            return
+            return None
     chunk_nodes = list(filtered_nodes(linear_node.users, ops=torch.ops.aten.chunk))
     if len(chunk_nodes) > 0:
         assert len(linear_nodes) == 1
@@ -2098,6 +2098,8 @@ def _determine_fused_weight_dims(
         num_chunks = chunk_nodes[0].args[1]
         weight_dim = shape(linear_node)[2]
         fused_weight_dims = [weight_dim // num_chunks] * num_chunks
+    if fused_weight_dims is not None:
+        fused_weight_dims = tuple(fused_weight_dims)
     return fused_weight_dims
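For context on the tuple-typed `fused_weight_dims` above, here is a minimal standalone sketch of how per-rank shards of a fused weight can be derived from those dims. The helper name `shard_fused_weight` and the even-divisibility assumption are illustrative only, not part of the library API; the real transform additionally handles quantization callbacks and dist ops.

from typing import Optional, Tuple

import torch


def shard_fused_weight(
    weight: torch.Tensor,
    fused_weight_dims: Optional[Tuple[int, ...]],
    rank: int,
    world_size: int,
) -> torch.Tensor:
    """Illustrative only: slice each fused sub-weight along dim 0 and re-concatenate."""
    if fused_weight_dims is None:
        # Plain (non-fused) weight: a single contiguous column-parallel slice.
        local = weight.shape[0] // world_size
        return weight[rank * local:(rank + 1) * local]
    assert sum(fused_weight_dims) == weight.shape[0]
    shards, offset = [], 0
    for dim in fused_weight_dims:
        local = dim // world_size  # assumes each fused dim divides evenly across ranks
        shards.append(weight[offset + rank * local:offset + (rank + 1) * local])
        offset += dim
    return torch.cat(shards, dim=0)


# Example with dims shaped like the fused Mamba2 in_proj case exercised in the tests below.
w = torch.randn(32 + 32 + 16 + 16 + 4, 32)
assert shard_fused_weight(w, (32, 32, 16, 16, 4), rank=0, world_size=2).shape == (50, 32)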
diff --git a/tests/integration/defs/accuracy/test_llm_api_autodeploy.py b/tests/integration/defs/accuracy/test_llm_api_autodeploy.py
index c8adaa96849..fad7023e912 100644
--- a/tests/integration/defs/accuracy/test_llm_api_autodeploy.py
+++ b/tests/integration/defs/accuracy/test_llm_api_autodeploy.py
@@ -218,11 +218,13 @@ def test_bf16(self):
             task.evaluate(llm)

     @pytest.mark.skip_less_device_memory(32000)
-    def test_fp8(self):
+    @pytest.mark.parametrize("world_size", [1, 4])
+    def test_fp8(self, world_size):
         kwargs = self.get_default_kwargs()
         sampling_params = self.get_default_sampling_params()
         with AutoDeployLLM(model=self.MODEL_PATH_FP8,
                            tokenizer=self.MODEL_PATH_FP8,
+                           world_size=world_size,
                            **kwargs) as llm:
             # Manually set quant_config for FP8 model to get the accuracy threshold
             llm.args.quant_config.quant_algo = QuantAlgo.FP8
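A minimal sketch, not the production test, of why the test-db entries updated below reference test_fp8[1] and test_fp8[4]: pytest derives the bracketed test IDs from the parametrize values, and the real test forwards each value to AutoDeployLLM.

import pytest


@pytest.mark.parametrize("world_size", [1, 4])
def test_fp8(world_size):
    # Collected as test_fp8[1] and test_fp8[4]; the real test passes
    # world_size through to AutoDeployLLM(..., world_size=world_size, ...).
    assert world_size in (1, 4)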
diff --git a/tests/integration/test_lists/test-db/l0_b200.yml b/tests/integration/test_lists/test-db/l0_b200.yml
index 98860efd0b8..872c6f0ae59 100644
--- a/tests/integration/test_lists/test-db/l0_b200.yml
+++ b/tests/integration/test_lists/test-db/l0_b200.yml
@@ -95,7 +95,8 @@ l0_b200:
   - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_fp8_blockwise_deepgemm[enable_configurable_moe-dtype1-72-256-2560-DefaultMoeRoutingMethod]
   # ------------- AutoDeploy tests ---------------
   - accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype[False-1]
-  - accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_fp8
+  - accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_fp8[1]
+  - accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_fp8[4]
   - unittest/_torch/auto_deploy/unit/singlegpu
 - condition:
     ranges:
diff --git a/tests/integration/test_lists/test-db/l0_h100.yml b/tests/integration/test_lists/test-db/l0_h100.yml
index f29c120bd2b..efa4e381291 100644
--- a/tests/integration/test_lists/test-db/l0_h100.yml
+++ b/tests/integration/test_lists/test-db/l0_h100.yml
@@ -117,7 +117,8 @@ l0_h100:
   - accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype[True-1]
   - accuracy/test_llm_api_autodeploy.py::TestNemotronH::test_auto_dtype[False]
   - accuracy/test_llm_api_autodeploy.py::TestNemotronH::test_auto_dtype[True]
-  - accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_fp8
+  - accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_fp8[1]
+  - accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_fp8[4]
   - accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_bf16
   - examples/test_ad_speculative_decoding.py::test_autodeploy_spec_dec_output[draft_target]
   - examples/test_ad_speculative_decoding.py::test_autodeploy_spec_dec_output[eagle3]
diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py
index b4f82edcfae..df4a07e34ca 100644
--- a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py
+++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py
@@ -1,6 +1,7 @@
 """Tests for basic graph sharding."""

 from functools import partial
+from types import SimpleNamespace
 from typing import Type

 import pytest
@@ -13,6 +14,7 @@

 import tensorrt_llm._torch.auto_deploy.distributed.common as dist_common
 from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
+from tensorrt_llm._torch.auto_deploy.models.custom.modeling_nemotron_h import NemotronHMamba2Mixer
 from tensorrt_llm._torch.auto_deploy.transform.library.sharding import (
     FP8WeightShardingInfo,
     LayerType,
@@ -35,6 +37,14 @@
     "linear1": "colwise",
     "linear2": "rowwise",
     "linear": "gather",
+    # Mamba2 specific projections
+    "in_proj": "mamba",
+    "out_proj": "rowwise",
+    # MLA specific projections
+    "q_a_proj": "gather",
+    "q_b_proj": "colwise",
+    "kv_a_proj_with_mqa": "gather",
+    "kv_b_proj": "colwise",
     # "input_layernorm.weight": "sequence_parallel",
     # "post_attention_layernorm.weight": "sequence_parallel",
     # "norm.weight": "sequence_parallel",
@@ -50,7 +60,6 @@
 }

 predefined_config = {
-    "head_dim": 8,
     "tp_plan": base_model_tp_plan,
 }

@@ -125,6 +134,75 @@ def forward(self, x):
         return self.linear2(y)


+class MLA_Block(nn.Module):
+    """Multi-head Latent Attention (MLA) block, simplified standalone version.
+
+    Based on the DeepSeek MLA architecture with KV compression.
+    This is a minimal, self-contained implementation for testing sharding patterns.
+    """
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        q_lora_rank: int,
+        kv_lora_rank: int,
+        qk_nope_head_dim: int,
+        qk_rope_head_dim: int,
+        v_head_dim: int,
+        bias: bool = False,
+    ):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.num_heads = num_heads
+        self.kv_lora_rank = kv_lora_rank
+        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
+        self.qk_nope_head_dim = qk_nope_head_dim
+        self.qk_rope_head_dim = qk_rope_head_dim
+        self.v_head_dim = v_head_dim
+
+        # KV compression path (not sharded - gather)
+        self.kv_a_proj_with_mqa = nn.Linear(hidden_size, kv_lora_rank + qk_rope_head_dim, bias=bias)
+
+        # KV decompression (sharded column-wise)
+        self.kv_b_proj = nn.Linear(
+            kv_lora_rank, num_heads * (qk_nope_head_dim + v_head_dim), bias=False
+        )
+
+        # Query path (sharded column-wise)
+        self.q_a_proj = nn.Linear(hidden_size, q_lora_rank, bias=bias)
+        self.q_b_proj = nn.Linear(q_lora_rank, num_heads * self.qk_head_dim, bias=bias)
+        self.q_a_layernorm = nn.LayerNorm(q_lora_rank)
+        # Output projection (sharded row-wise)
+        self.o_proj = nn.Linear(num_heads * v_head_dim, hidden_size, bias=bias)
+
+    @torch.no_grad()
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        b, s, _ = x.shape
+
+        # Compress KV to latent
+        compressed_kv_rope = self.kv_a_proj_with_mqa(x)  # (b, s, kv_lora_rank + rope_dim)
+        compressed_kv = compressed_kv_rope[:, :, : self.kv_lora_rank]  # (b, s, kv_lora_rank)
+
+        # Decompress to full K and V
+        kv = self.kv_b_proj(compressed_kv)  # (b, s, num_heads * (qk_nope + v))
+        k_nope_v = kv.view(b, s, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
+        k_nope = k_nope_v[:, :, :, : self.qk_nope_head_dim]
+        v = k_nope_v[:, :, :, self.qk_nope_head_dim :]
+
+        # Query projection
+        # q = q_b_proj @ (layernorm(q_a_proj @ x))
+        q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(x)))  # (b, s, num_heads * qk_head_dim)
+        q = q.view(b, s, self.num_heads, self.qk_head_dim)
+        q_nope = q[:, :, :, : self.qk_nope_head_dim]
+
+        attn_out = torch.ops.auto_deploy.torch_attention(q_nope, k_nope, v, is_causal=True)
+        attn_out = attn_out.contiguous().view(b, s, -1)
+        # Output projection
+        output = self.o_proj(attn_out)
+        return output
+
+
 def _run_sharding_execution_job(
     model_cls: nn.Module,
     dist_op_expected: str,
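As a quick sanity check on the projection sizes above, the arithmetic below mirrors the MLA configuration used later in this diff for the execution job; the sharding labels in the comments restate the test's tp_plan and comments, not library behavior.

# Illustrative arithmetic only; dims taken from the execution-job MLA config in this diff.
num_heads, qk_nope_head_dim, qk_rope_head_dim, v_head_dim, kv_lora_rank = 16, 64, 32, 64, 256
qk_head_dim = qk_nope_head_dim + qk_rope_head_dim

assert kv_lora_rank + qk_rope_head_dim == 288               # kv_a_proj_with_mqa out ("gather")
assert num_heads * (qk_nope_head_dim + v_head_dim) == 2048  # kv_b_proj out ("colwise")
assert num_heads * qk_head_dim == 1536                      # q_b_proj out ("colwise")
assert num_heads * v_head_dim == 1024                       # o_proj in ("rowwise")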
@@ -137,6 +215,7 @@ def _run_sharding_execution_job(
     batch_size = 4
     sequence_len = 8
     num_features = 32
+    skip_output_assert = False

     # GQA specific parameters
     num_heads = 4
@@ -150,6 +229,54 @@
         ).to(device="cuda", dtype=torch.float16)
     elif model_cls == FP8MLP:
         model = model_cls(num_features, num_features, bias=bias).to("cuda")
+    elif model_cls == NemotronHMamba2Mixer:
+        # Create config for Mamba2 based on Nemotron models
+        # Scaled down from typical values: hidden_size=5120, ssm_state_size=128
+        mamba_config = SimpleNamespace(
+            hidden_size=num_features,
+            ssm_state_size=16,  # Scaled from 128
+            mamba_num_heads=num_heads,
+            mamba_head_dim=num_features // num_heads,  # 8
+            n_groups=1,  # Typical value
+            chunk_size=256,
+            conv_kernel=4,
+            use_conv_bias=bias,
+            use_bias=bias,
+            mamba_hidden_act="silu",
+            layer_norm_epsilon=1e-5,
+            time_step_limit=(0.0, float("inf")),
+            time_step_min=0.001,
+            time_step_max=0.1,
+            time_step_floor=1e-4,
+            initializer_range=0.02,
+            rescale_prenorm_residual=False,
+            residual_in_fp32=False,
+            num_hidden_layers=1,
+        )
+        model = model_cls(mamba_config, layer_idx=0).to(device="cuda", dtype=torch.float16)
+    elif model_cls == MLA_Block:
+        # Scaled-down configuration following the DeepSeek-V3/R1 MLA architecture.
+        # For reference, the production config uses:
+        #   kv_lora_rank=512, q_lora_rank=1536
+        #   qk_rope_head_dim=64, v_head_dim=128, qk_nope_head_dim=128
+        # Everything is shrunk here so the test model stays small.
+        num_heads_mla = 16
+        qk_nope_head_dim = 64
+        qk_rope_head_dim = 32
+        v_head_dim = 64
+        kv_lora_rank = 256
+        skip_output_assert = True
+
+        model = model_cls(
+            hidden_size=num_features,
+            num_heads=num_heads_mla,
+            q_lora_rank=kv_lora_rank,
+            kv_lora_rank=kv_lora_rank,
+            qk_nope_head_dim=qk_nope_head_dim,
+            qk_rope_head_dim=qk_rope_head_dim,
+            v_head_dim=v_head_dim,
+            bias=bias,
+        ).to(device="cuda", dtype=torch.float16)
     else:
         model = model_cls(num_features, num_features, bias=bias).to(
             device="cuda", dtype=torch.float16
@@ -178,6 +305,11 @@ def _get_expected_num_params(num_p_og: int) -> int:
             num_params = W_q_local_size + W_k_local_size + W_v_local_size + W_o_local_size
         else:
             num_params = num_p_og // world_size + num_update
+        if model_cls == MLA_Block:
+            # q_a_proj is simple-sharded and followed by q_a_layernorm, so the layernorm params
+            # (weight and bias) are NOT sharded but replicated on every rank. Add back the share
+            # that the generic num_p_og // world_size estimate assumed would be sharded away.
+            num_params += 2 * kv_lora_rank * (world_size - 1) // world_size
         return num_params

     def verify_local_weight_sizes(gm) -> bool:
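A quick numerical check of the replicated-layernorm correction above, using the execution-job values (kv_lora_rank equals q_lora_rank, 256) and an assumed world_size of 2; the variable names are illustrative only.

# Assumed world_size of 2 for illustration; kv_lora_rank matches the MLA execution-job config.
kv_lora_rank, world_size = 256, 2
layernorm_params = 2 * kv_lora_rank                  # q_a_layernorm weight + bias, replicated
naive_local = layernorm_params // world_size         # what num_p_og // world_size accounts for
correction = 2 * kv_lora_rank * (world_size - 1) // world_size
assert naive_local + correction == layernorm_params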
@@ -223,6 +355,7 @@ def combined_graph_check(gm) -> bool:
         gm_transformed,
         check_transformed_graph=combined_graph_check,
         _get_expected_num_params=_get_expected_num_params,
+        skip_output_assert=skip_output_assert,
     )


@@ -248,6 +381,47 @@ def _run_pattern_detection_job(
             hidden_size=num_features,
             num_key_value_heads=num_key_value_heads,
         ).to(device="cuda", dtype=torch.float16)
+    elif model_cls == NemotronHMamba2Mixer:
+        # Create config for Mamba2
+        mamba_config = SimpleNamespace(
+            hidden_size=num_features,
+            ssm_state_size=16,
+            mamba_num_heads=num_heads,
+            mamba_head_dim=num_features // num_heads,
+            n_groups=1,
+            chunk_size=256,
+            conv_kernel=4,
+            use_conv_bias=bias,
+            use_bias=bias,
+            mamba_hidden_act="silu",
+            layer_norm_epsilon=1e-5,
+            time_step_limit=(0.0, float("inf")),
+            time_step_min=0.001,
+            time_step_max=0.1,
+            time_step_floor=1e-4,
+            initializer_range=0.02,
+            rescale_prenorm_residual=False,
+            residual_in_fp32=False,
+            num_hidden_layers=1,
+        )
+        model = model_cls(mamba_config, layer_idx=0).to(device="cuda", dtype=torch.float16)
+    elif model_cls == MLA_Block:
+        # Create simplified MLA based on DeepSeek-V3 architecture
+        qk_nope_head_dim = 2
+        qk_rope_head_dim = 1
+        v_head_dim = 2
+        kv_lora_rank = 8
+
+        model = model_cls(
+            hidden_size=num_features,
+            num_heads=num_heads,
+            q_lora_rank=kv_lora_rank,
+            kv_lora_rank=kv_lora_rank,
+            qk_nope_head_dim=qk_nope_head_dim,
+            qk_rope_head_dim=qk_rope_head_dim,
+            v_head_dim=v_head_dim,
+            bias=bias,
+        ).to(device="cuda", dtype=torch.float16)
     else:
         model = model_cls(num_features, num_features, bias=bias).to(
             device="cuda", dtype=torch.float16
@@ -344,6 +518,96 @@ def _run_pattern_detection_job(
                         min_local_shape=1,
                     )
                 )
+    elif model_cls == NemotronHMamba2Mixer:
+        for node in gm.graph.nodes:
+            if is_linear_op(node):
+                # in_proj should be sharded column-wise
+                # out_proj should be sharded row-wise with all_reduce
+                if "out_proj" in node.args[1].name:
+                    dim = SplitDimension.ROW
+                    dist_op = "all_reduce"
+                    fused_weight_dims = None
+                else:
+                    dim = SplitDimension.COLUMN
+                    dist_op = None
+                    fused_weight_dims = (num_features, num_features, 16, 16, num_heads)
+                expected_transformations.append(
+                    WeightShardingInfo(
+                        target_node=node.name,
+                        split_dim=dim,
+                        config=config,
+                        dist_op=dist_op,
+                        min_local_shape=1,
+                        layer_type=LayerType.SSM,
+                        fused_weight_dims=fused_weight_dims,
+                    )
+                )
+            if is_op(node, torch.ops.auto_deploy.torch_causal_conv1d):
+                expected_transformations.append(
+                    WeightShardingInfo(
+                        target_node=node.name,
+                        split_dim=SplitDimension.COLUMN,
+                        config=config,
+                        dist_op=None,
+                        min_local_shape=1,
+                        layer_type=LayerType.SSM,
+                        fused_weight_dims=(num_features, 16, 16),
+                    )
+                )
+            if is_op(node, torch.ops.auto_deploy.torch_ssm):
+                expected_transformations.append(
+                    WeightShardingInfo(
+                        target_node=node.name,
+                        split_dim=SplitDimension.COLUMN,
+                        config=config,
+                        dist_op=None,
+                        min_local_shape=1,
+                        layer_type=LayerType.SSM,
+                        fused_weight_dims=None,
+                    )
+                )
+            if len(node.args) > 1 and "norm_weight" in node.args[0].name:
+                expected_transformations.append(
+                    WeightShardingInfo(
+                        target_node=node.name,
+                        split_dim=SplitDimension.COLUMN,
+                        config=config,
+                        dist_op=None,
+                        min_local_shape=1,
+                        layer_type=LayerType.SSM,
+                        fused_weight_dims=None,
+                    )
+                )
+    elif model_cls == MLA_Block:
+        for node in gm.graph.nodes:
+            if is_linear_op(node):
+                # kv_a_proj_with_mqa / q_a_proj: simple shard (column split + all_gather)
+                # q_b_proj / kv_b_proj: column-wise
+                # o_proj: row-wise with all_reduce
+                min_local_shape = 2
+                if "o_proj" in node.args[1].name:
+                    dim = SplitDimension.ROW
+                    dist_op = "all_reduce"
+                elif (
+                    "kv_a_proj_with_mqa" in node.args[1].name or "q_a_proj" in node.args[1].name
+                ):
+                    # This is simple-shard gather
+                    dim = SplitDimension.COLUMN
+                    dist_op = "all_gather"
+                    min_local_shape = 1
+                else:
+                    dim = SplitDimension.COLUMN
+                    dist_op = None
+                expected_transformations.append(
+                    WeightShardingInfo(
+                        target_node=node.name,
+                        split_dim=dim,
+                        config=config,
+                        dist_op=dist_op,
+                        min_local_shape=min_local_shape,
+                        layer_type=LayerType.MLA,
+                    )
+                )

     # get detected transformations
     optimizer = InferenceOptimizer(
@@ -378,6 +642,8 @@
         (FP8MLP, "torch_dist_all_reduce"),
         (nn.Linear, "torch_dist_all_gather"),
         (GQA_Block, "torch_dist_all_reduce"),
+        (NemotronHMamba2Mixer, "torch_dist_all_reduce"),
+        (MLA_Block, "torch_dist_all_reduce"),
     ),
 )
 def test_sharding(
@@ -403,6 +669,8 @@
         (FP8MLP, "torch_dist_all_reduce"),
         (nn.Linear, "torch_dist_all_gather"),
         (GQA_Block, "torch_dist_all_reduce"),
+        (NemotronHMamba2Mixer, "torch_dist_all_reduce"),
+        (MLA_Block, "torch_dist_all_reduce"),
     ),
 )
 def test_sharding_pattern_detection(
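Both new parametrizations expect torch_dist_all_reduce because their output projections (out_proj for the Mamba2 mixer, o_proj for MLA) are sharded row-wise. A self-contained sketch of why a row-wise split needs an all-reduce; the two-rank split is simulated locally and none of this uses the library API.

import torch
import torch.nn as nn

torch.manual_seed(0)
world_size = 2
layer = nn.Linear(8, 4, bias=False)
x = torch.randn(3, 8)

# Row-wise TP: split the weight along the input dimension; each "rank" sees its slice of x.
w_shards = layer.weight.chunk(world_size, dim=1)  # two (4, 4) slices
x_shards = x.chunk(world_size, dim=1)             # two (3, 4) slices
partials = [xi @ wi.t() for xi, wi in zip(x_shards, w_shards)]

# Each partial output is incomplete; summing them (what all_reduce does across ranks)
# recovers the unsharded result.
assert torch.allclose(sum(partials), layer(x), atol=1e-5)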