From 0986426c267bdcf0aa312f1990c5b3819273916f Mon Sep 17 00:00:00 2001 From: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> Date: Wed, 31 Dec 2025 04:44:22 -0800 Subject: [PATCH 01/18] Added Mamba and MLA layers to the sharding tests Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> --- jenkins/L0_MergeRequest.groovy | 2 + .../library/test_tp_sharding.py | 233 +++++++++++++++++- 2 files changed, 234 insertions(+), 1 deletion(-) diff --git a/jenkins/L0_MergeRequest.groovy b/jenkins/L0_MergeRequest.groovy index 3e81b22a099..7820c7267c1 100644 --- a/jenkins/L0_MergeRequest.groovy +++ b/jenkins/L0_MergeRequest.groovy @@ -719,6 +719,7 @@ def getMultiGpuFileChanged(pipeline, testFilter, globalVars) "tensorrt_llm/_torch/pyexecutor/_util.py", "tensorrt_llm/_torch/pyexecutor/model_engine.py", "tensorrt_llm/_torch/pyexecutor/py_executor.py", + "tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py", "tensorrt_llm/evaluate/json_mode_eval.py", "tensorrt_llm/evaluate/mmlu.py", "tensorrt_llm/executor/", @@ -740,6 +741,7 @@ def getMultiGpuFileChanged(pipeline, testFilter, globalVars) "tests/integration/defs/accuracy/test_disaggregated_serving.py", "tests/unittest/_torch/ray_orchestrator/multi_gpu/", "tests/integration/defs/examples/test_ray.py", + "tests/integration/defs/accuracy/test_llm_api_autodeploy.py", "tests/unittest/llmapi/test_async_llm.py", ] diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py index b4f82edcfae..3225a873237 100644 --- a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py +++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py @@ -1,6 +1,7 @@ """Tests for basic graph sharding.""" from functools import partial +from types import SimpleNamespace from typing import Type import pytest @@ -13,6 +14,7 @@ import tensorrt_llm._torch.auto_deploy.distributed.common as dist_common from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm +from tensorrt_llm._torch.auto_deploy.models.custom.modeling_nemotron_h import NemotronHMamba2Mixer from tensorrt_llm._torch.auto_deploy.transform.library.sharding import ( FP8WeightShardingInfo, LayerType, @@ -35,6 +37,14 @@ "linear1": "colwise", "linear2": "rowwise", "linear": "gather", + # Mamba2 specific projections + "in_proj": "mamba", + "out_proj": "rowwise", + # MLA specific projections + "q_a_proj": "gather", + "q_b_proj": "colwise", + "kv_a_proj_with_mqa": "gather", + "kv_b_proj": "colwise", # "input_layernorm.weight": "sequence_parallel", # "post_attention_layernorm.weight": "sequence_parallel", # "norm.weight": "sequence_parallel", @@ -50,7 +60,6 @@ } predefined_config = { - "head_dim": 8, "tp_plan": base_model_tp_plan, } @@ -125,6 +134,85 @@ def forward(self, x): return self.linear2(y) +class MLA_Block(nn.Module): + """Multi-Latent Attention block - simplified standalone version. + + Based on DeepSeek MLA architecture with KV compression. + This is a minimal, self-contained implementation for testing sharding patterns. + """ + + def __init__( + self, + hidden_size: int, + num_heads: int, + q_lora_rank: int, + kv_lora_rank: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + v_head_dim: int, + bias: bool = False, + ): + super().__init__() + self.hidden_size = hidden_size + self.num_heads = num_heads + self.kv_lora_rank = kv_lora_rank + self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.v_head_dim = v_head_dim + + # KV compression path (not sharded - gather) + self.kv_a_proj_with_mqa = nn.Linear(hidden_size, kv_lora_rank + qk_rope_head_dim, bias=bias) + + # KV decompression (sharded column-wise) + self.kv_b_proj = nn.Linear( + kv_lora_rank, num_heads * (qk_nope_head_dim + v_head_dim), bias=False + ) + + # Query path (sharded column-wise) + self.q_a_proj = nn.Linear(hidden_size, q_lora_rank, bias=bias) + self.q_b_proj = nn.Linear(q_lora_rank, num_heads * self.qk_head_dim, bias=bias) + self.q_a_layernorm = nn.LayerNorm(q_lora_rank) + # Output projection (sharded row-wise) + self.o_proj = nn.Linear(num_heads * v_head_dim, hidden_size, bias=bias) + + @torch.no_grad() + def forward(self, x: torch.Tensor) -> torch.Tensor: + b, s, _ = x.shape + + # Compress KV to latent + compressed_kv_rope = self.kv_a_proj_with_mqa(x) # (b, s, kv_lora_rank + rope_dim) + compressed_kv = compressed_kv_rope[:, :, : self.kv_lora_rank] # (b, s, kv_lora_rank) + + # Decompress to full K and V + kv = self.kv_b_proj(compressed_kv) # (b, s, num_heads * (qk_nope + v)) + k_nope_v = kv.view(b, s, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + k_nope = k_nope_v[:, :, :, : self.qk_nope_head_dim] + v = k_nope_v[:, :, :, self.qk_nope_head_dim :] + + # Query projection + # q = q_b_proj @ (layernorm(q_a_proj @ x)) + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(x))) # (b, s, num_heads * qk_head_dim) + q = q.view(b, s, self.num_heads, self.qk_head_dim) + q_nope = q[:, :, :, : self.qk_nope_head_dim] + + # Simplified attention: just use nope parts + # In real MLA, this would include rope and proper attention + # attn_out = torch.matmul( + # q_nope, k_nope.transpose(-2, -1) + # ) @ v # Simplified attention pattern + + # attn_out = attn_out.contiguous().view(b, s, self.num_heads * self.v_head_dim) + + attn_out = torch.ops.auto_deploy.torch_attention( + q_nope, k_nope, v, is_causal=True, layout="bsnd" + ) + attn_out = attn_out.contiguous().view(b, s, -1) + # Output projection + output = self.o_proj(attn_out) + return output + + def _run_sharding_execution_job( model_cls: nn.Module, dist_op_expected: str, @@ -150,6 +238,53 @@ def _run_sharding_execution_job( ).to(device="cuda", dtype=torch.float16) elif model_cls == FP8MLP: model = model_cls(num_features, num_features, bias=bias).to("cuda") + elif model_cls == NemotronHMamba2Mixer: + # Create config for Mamba2 based on Nemotron models + # Scaled down from typical values: hidden_size=5120, ssm_state_size=128 + mamba_config = SimpleNamespace( + hidden_size=num_features, + ssm_state_size=16, # Scaled from 128 + mamba_num_heads=num_heads, + mamba_head_dim=num_features // num_heads, # 8 + n_groups=1, # Typical value + chunk_size=256, + conv_kernel=4, + use_conv_bias=bias, + use_bias=bias, + mamba_hidden_act="silu", + layer_norm_epsilon=1e-5, + time_step_limit=(0.0, float("inf")), + time_step_min=0.001, + time_step_max=0.1, + time_step_floor=1e-4, + initializer_range=0.02, + rescale_prenorm_residual=False, + residual_in_fp32=False, + num_hidden_layers=1, + ) + model = model_cls(mamba_config, layer_idx=0).to(device="cuda", dtype=torch.float16) + elif model_cls == MLA_Block: + # Use actual DeepSeek-V3/R1 production values + # From HuggingFace config (HunYuanPretrainedConfig defaults): + # hidden_size=4096, num_attention_heads=32 + # kv_lora_rank=512, q_lora_rank=1536 + # qk_rope_head_dim=64, v_head_dim=128, qk_nope_head_dim=128 + num_heads_mla = 16 + qk_nope_head_dim = 64 + qk_rope_head_dim = 32 + v_head_dim = 64 + kv_lora_rank = 256 + + model = model_cls( + hidden_size=num_features, + num_heads=num_heads_mla, + q_lora_rank=kv_lora_rank, + kv_lora_rank=kv_lora_rank, + qk_nope_head_dim=qk_nope_head_dim, + qk_rope_head_dim=qk_rope_head_dim, + v_head_dim=v_head_dim, + bias=bias, + ).to(device="cuda", dtype=torch.float16) else: model = model_cls(num_features, num_features, bias=bias).to( device="cuda", dtype=torch.float16 @@ -178,6 +313,11 @@ def _get_expected_num_params(num_p_og: int) -> int: num_params = W_q_local_size + W_k_local_size + W_v_local_size + W_o_local_size else: num_params = num_p_og // world_size + num_update + if model_cls == MLA_Block: + # since q_a_proj is simple-sharded and followed by q_a_layernorm, the layernorm params + # are NOT sharded - they have to be replicated. To account for this, we need to add the + # number of parameters of the layernorm (weight and bias)to the number of parameters of the model. + num_params += 2 * kv_lora_rank * (world_size - 1) // world_size return num_params def verify_local_weight_sizes(gm) -> bool: @@ -248,6 +388,47 @@ def _run_pattern_detection_job( hidden_size=num_features, num_key_value_heads=num_key_value_heads, ).to(device="cuda", dtype=torch.float16) + elif model_cls == NemotronHMamba2Mixer: + # Create config for Mamba2 + mamba_config = SimpleNamespace( + hidden_size=num_features, + ssm_state_size=16, + mamba_num_heads=num_heads, + mamba_head_dim=num_features // num_heads, + n_groups=1, + chunk_size=256, + conv_kernel=4, + use_conv_bias=bias, + use_bias=bias, + mamba_hidden_act="silu", + layer_norm_epsilon=1e-5, + time_step_limit=(0.0, float("inf")), + time_step_min=0.001, + time_step_max=0.1, + time_step_floor=1e-4, + initializer_range=0.02, + rescale_prenorm_residual=False, + residual_in_fp32=False, + num_hidden_layers=1, + ) + model = model_cls(mamba_config, layer_idx=0).to(device="cuda", dtype=torch.float16) + elif model_cls == MLA_Block: + # Create simplified MLA based on DeepSeek-V3 architecture + qk_nope_head_dim = 2 + qk_rope_head_dim = 1 + v_head_dim = 2 + kv_lora_rank = 8 + + model = model_cls( + hidden_size=num_features, + num_heads=num_heads, + q_lora_rank=kv_lora_rank, + kv_lora_rank=kv_lora_rank, + qk_nope_head_dim=qk_nope_head_dim, + qk_rope_head_dim=qk_rope_head_dim, + v_head_dim=v_head_dim, + bias=bias, + ).to(device="cuda", dtype=torch.float16) else: model = model_cls(num_features, num_features, bias=bias).to( device="cuda", dtype=torch.float16 @@ -344,6 +525,52 @@ def _run_pattern_detection_job( min_local_shape=1, ) ) + elif model_cls == NemotronHMamba2Mixer: + for node in gm.graph.nodes: + if is_linear_op(node): + # in_proj should be sharded column-wise + # out_proj should be sharded row-wise with all_reduce + if "out_proj" in node.args[1].name: + dim = SplitDimension.ROW + dist_op = "all_reduce" + else: + dim = SplitDimension.COLUMN + dist_op = None + expected_transformations.append( + WeightShardingInfo( + target_node=node.name, + split_dim=dim, + config=config, + dist_op=dist_op, + min_local_shape=1, + layer_type=LayerType.MLP, + ) + ) + elif model_cls == MLA_Block: + for node in gm.graph.nodes: + if is_linear_op(node): + # kv_a_proj_with_mqa: gather (no sharding) + # q_b_proj/kv_b_proj: column-wise + # o_proj: row-wise with all_reduce + if "o_proj" in node.args[1].name: + dim = SplitDimension.ROW + dist_op = "all_reduce" + elif "kv_a_proj_with_mqa" in node.args[1].name: + # This is gather, skip sharding + continue + else: + dim = SplitDimension.COLUMN + dist_op = None + expected_transformations.append( + WeightShardingInfo( + target_node=node.name, + split_dim=dim, + config=config, + dist_op=dist_op, + min_local_shape=1, + layer_type=LayerType.ATTENTION, + ) + ) # get detected transformations optimizer = InferenceOptimizer( @@ -378,6 +605,8 @@ def _run_pattern_detection_job( (FP8MLP, "torch_dist_all_reduce"), (nn.Linear, "torch_dist_all_gather"), (GQA_Block, "torch_dist_all_reduce"), + (NemotronHMamba2Mixer, "torch_dist_all_reduce"), + (MLA_Block, "torch_dist_all_reduce"), ), ) def test_sharding( @@ -403,6 +632,8 @@ def test_sharding( (FP8MLP, "torch_dist_all_reduce"), (nn.Linear, "torch_dist_all_gather"), (GQA_Block, "torch_dist_all_reduce"), + (NemotronHMamba2Mixer, "torch_dist_all_reduce"), + (MLA_Block, "torch_dist_all_reduce"), ), ) def test_sharding_pattern_detection( From 683e77c8608208c2adbf108530bd2e659754ff72 Mon Sep 17 00:00:00 2001 From: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> Date: Wed, 31 Dec 2025 06:21:28 -0800 Subject: [PATCH 02/18] added integration tests Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> --- .../auto_deploy/transform/library/sharding.py | 12 ++-- .../defs/accuracy/test_llm_api_autodeploy.py | 4 +- .../test_lists/test-db/l0_h100.yml | 3 +- .../library/test_tp_sharding.py | 60 ++++++++++++++++--- 4 files changed, 62 insertions(+), 17 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py index 03eb4a07030..f1c0caf38f6 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py @@ -267,7 +267,7 @@ class WeightShardingInfo(ShardingTransformInfo): min_local_shape: int = 1 layer_type: LayerType = LayerType.MLP # used for TP sharding of fused weights - fused_weight_dims: Optional[list] = None + fused_weight_dims: Optional[tuple] = None def quantization_cb( self, @@ -1316,7 +1316,7 @@ def _shard_parameter_node( config: ShardingTransformConfig, add_dist: bool = False, min_local_shape: int = 1, - fused_weight_dims: Optional[list] = None, + fused_weight_dims: Optional[tuple] = None, quantization_cb: Optional[ Callable[[GraphModule, nn.Module, Node, str, torch.Size, int, int, int], None] ] = None, @@ -1835,7 +1835,7 @@ def _process_ssm_sharding( config=config, dist_op=None, min_local_shape=1, - fused_weight_dims=fused_weight_dims["in_proj"], + fused_weight_dims=tuple(fused_weight_dims["in_proj"]), layer_type=LayerType.SSM, ) ): @@ -1904,7 +1904,7 @@ def _process_ssm_sharding( fused_dims = None for k, v in fused_weight_dims.items(): if k in weight_key: - fused_dims = v + fused_dims = tuple(v) break # Shard the weight tensor (also updates the parameter in the module) @@ -2089,7 +2089,7 @@ def _determine_fused_weight_dims( ad_logger.warning( f"Fused weight dims {fused_weight_dims} do not sum to weight dim {weight_dim}. Skipping." ) - return + return None chunk_nodes = list(filtered_nodes(linear_node.users, ops=torch.ops.aten.chunk)) if len(chunk_nodes) > 0: assert len(linear_nodes) == 1 @@ -2098,7 +2098,7 @@ def _determine_fused_weight_dims( num_chunks = chunk_nodes[0].args[1] weight_dim = shape(linear_node)[2] fused_weight_dims = [weight_dim // num_chunks] * num_chunks - return fused_weight_dims + return tuple(fused_weight_dims) def _process_column_sharding( diff --git a/tests/integration/defs/accuracy/test_llm_api_autodeploy.py b/tests/integration/defs/accuracy/test_llm_api_autodeploy.py index c8adaa96849..fad7023e912 100644 --- a/tests/integration/defs/accuracy/test_llm_api_autodeploy.py +++ b/tests/integration/defs/accuracy/test_llm_api_autodeploy.py @@ -218,11 +218,13 @@ def test_bf16(self): task.evaluate(llm) @pytest.mark.skip_less_device_memory(32000) - def test_fp8(self): + @pytest.mark.parametrize("world_size", [1, 4]) + def test_fp8(self, world_size): kwargs = self.get_default_kwargs() sampling_params = self.get_default_sampling_params() with AutoDeployLLM(model=self.MODEL_PATH_FP8, tokenizer=self.MODEL_PATH_FP8, + world_size=world_size, **kwargs) as llm: # Manually set quant_config for FP8 model to get the accuracy threshold llm.args.quant_config.quant_algo = QuantAlgo.FP8 diff --git a/tests/integration/test_lists/test-db/l0_h100.yml b/tests/integration/test_lists/test-db/l0_h100.yml index f29c120bd2b..efa4e381291 100644 --- a/tests/integration/test_lists/test-db/l0_h100.yml +++ b/tests/integration/test_lists/test-db/l0_h100.yml @@ -117,7 +117,8 @@ l0_h100: - accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype[True-1] - accuracy/test_llm_api_autodeploy.py::TestNemotronH::test_auto_dtype[False] - accuracy/test_llm_api_autodeploy.py::TestNemotronH::test_auto_dtype[True] - - accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_fp8 + - accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_fp8[1] + - accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_fp8[4] - accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_bf16 - examples/test_ad_speculative_decoding.py::test_autodeploy_spec_dec_output[draft_target] - examples/test_ad_speculative_decoding.py::test_autodeploy_spec_dec_output[eagle3] diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py index 3225a873237..a790c4b1d9d 100644 --- a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py +++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py @@ -204,9 +204,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # attn_out = attn_out.contiguous().view(b, s, self.num_heads * self.v_head_dim) - attn_out = torch.ops.auto_deploy.torch_attention( - q_nope, k_nope, v, is_causal=True, layout="bsnd" - ) + attn_out = torch.ops.auto_deploy.torch_attention(q_nope, k_nope, v, is_causal=True) attn_out = attn_out.contiguous().view(b, s, -1) # Output projection output = self.o_proj(attn_out) @@ -533,9 +531,11 @@ def _run_pattern_detection_job( if "out_proj" in node.args[1].name: dim = SplitDimension.ROW dist_op = "all_reduce" + fused_weight_dims = None else: dim = SplitDimension.COLUMN dist_op = None + fused_weight_dims = (num_features, num_features, 16, 16, num_heads) expected_transformations.append( WeightShardingInfo( target_node=node.name, @@ -543,7 +543,44 @@ def _run_pattern_detection_job( config=config, dist_op=dist_op, min_local_shape=1, - layer_type=LayerType.MLP, + layer_type=LayerType.SSM, + fused_weight_dims=fused_weight_dims, + ) + ) + if is_op(node, torch.ops.auto_deploy.torch_causal_conv1d): + expected_transformations.append( + WeightShardingInfo( + target_node=node.name, + split_dim=SplitDimension.COLUMN, + config=config, + dist_op=None, + min_local_shape=1, + layer_type=LayerType.SSM, + fused_weight_dims=(num_features, 16, 16), + ) + ) + if is_op(node, torch.ops.auto_deploy.torch_ssm): + expected_transformations.append( + WeightShardingInfo( + target_node=node.name, + split_dim=SplitDimension.COLUMN, + config=config, + dist_op=None, + min_local_shape=1, + layer_type=LayerType.SSM, + fused_weight_dims=None, + ) + ) + if len(node.args) > 1 and "norm_weight" in node.args[0].name: + expected_transformations.append( + WeightShardingInfo( + target_node=node.name, + split_dim=SplitDimension.COLUMN, + config=config, + dist_op=None, + min_local_shape=1, + layer_type=LayerType.SSM, + fused_weight_dims=None, ) ) elif model_cls == MLA_Block: @@ -552,12 +589,17 @@ def _run_pattern_detection_job( # kv_a_proj_with_mqa: gather (no sharding) # q_b_proj/kv_b_proj: column-wise # o_proj: row-wise with all_reduce + min_local_shape = 2 if "o_proj" in node.args[1].name: dim = SplitDimension.ROW dist_op = "all_reduce" - elif "kv_a_proj_with_mqa" in node.args[1].name: - # This is gather, skip sharding - continue + elif ( + "kv_a_proj_with_mqa" in node.args[1].name or "q_a_proj" in node.args[1].name + ): + # This is simple-shard gather + dim = SplitDimension.COLUMN + dist_op = "all_gather" + min_local_shape = 1 else: dim = SplitDimension.COLUMN dist_op = None @@ -567,8 +609,8 @@ def _run_pattern_detection_job( split_dim=dim, config=config, dist_op=dist_op, - min_local_shape=1, - layer_type=LayerType.ATTENTION, + min_local_shape=min_local_shape, + layer_type=LayerType.MLA, ) ) From 9a34541ddf3946c7d8e101194602524dfb4946da Mon Sep 17 00:00:00 2001 From: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> Date: Wed, 31 Dec 2025 06:26:57 -0800 Subject: [PATCH 03/18] Fixed fused weights Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> --- tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py index f1c0caf38f6..987b85277d1 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py @@ -2098,7 +2098,9 @@ def _determine_fused_weight_dims( num_chunks = chunk_nodes[0].args[1] weight_dim = shape(linear_node)[2] fused_weight_dims = [weight_dim // num_chunks] * num_chunks - return tuple(fused_weight_dims) + if fused_weight_dims is not None: + fused_weight_dims = tuple(fused_weight_dims) + return fused_weight_dims def _process_column_sharding( From 7f5edfbe1c42c93184ba6bc5dd3d4e8f13387ef8 Mon Sep 17 00:00:00 2001 From: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> Date: Wed, 31 Dec 2025 06:53:47 -0800 Subject: [PATCH 04/18] code cleanup Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> --- .../transformations/library/test_tp_sharding.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py index a790c4b1d9d..df4a07e34ca 100644 --- a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py +++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py @@ -196,14 +196,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: q = q.view(b, s, self.num_heads, self.qk_head_dim) q_nope = q[:, :, :, : self.qk_nope_head_dim] - # Simplified attention: just use nope parts - # In real MLA, this would include rope and proper attention - # attn_out = torch.matmul( - # q_nope, k_nope.transpose(-2, -1) - # ) @ v # Simplified attention pattern - - # attn_out = attn_out.contiguous().view(b, s, self.num_heads * self.v_head_dim) - attn_out = torch.ops.auto_deploy.torch_attention(q_nope, k_nope, v, is_causal=True) attn_out = attn_out.contiguous().view(b, s, -1) # Output projection @@ -223,6 +215,7 @@ def _run_sharding_execution_job( batch_size = 4 sequence_len = 8 num_features = 32 + skip_output_assert = False # GQA specific parameters num_heads = 4 @@ -272,6 +265,7 @@ def _run_sharding_execution_job( qk_rope_head_dim = 32 v_head_dim = 64 kv_lora_rank = 256 + skip_output_assert = True model = model_cls( hidden_size=num_features, @@ -361,6 +355,7 @@ def combined_graph_check(gm) -> bool: gm_transformed, check_transformed_graph=combined_graph_check, _get_expected_num_params=_get_expected_num_params, + skip_output_assert=skip_output_assert, ) From 9ad64d348f65bb839e1c5dd1795c764e9aecb386 Mon Sep 17 00:00:00 2001 From: Lucas <11156568+lucaslie@users.noreply.github.com> Date: Tue, 30 Dec 2025 10:16:03 -0500 Subject: [PATCH 05/18] proposal for A_minus as parameter Signed-off-by: Lucas <11156568+lucaslie@users.noreply.github.com> --- .../models/custom/modeling_nemotron_h.py | 26 ++++++++++--------- .../models/test_modeling_nemotron_h.py | 2 +- 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_nemotron_h.py b/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_nemotron_h.py index 15178b00f1c..23c79c55e44 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_nemotron_h.py +++ b/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_nemotron_h.py @@ -111,13 +111,12 @@ def __init__(self, config, layer_idx: int): # S4D real initialization. These are not discretized! # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded A = torch.arange(1, self.num_heads + 1) - self.A_log = nn.Parameter(torch.log(A)) - self.A_log._no_weight_decay = True - # Instead of recomputing `torch.exp(self.A_log.float())` on every forward pass, we will register a hook + # self.A_log = nn.Parameter(torch.log(A)) + # self.A_log._no_weight_decay = True + # Instead of recomputing `-torch.exp(self.A_log.float())` on every forward pass, we will register a hook # that sets this appropriately when loading weights. - # NOTE: we explicitly register this as a non-persistent buffer so that it does not appear in the state dict of - # this module, or an equivalent graph module trace from it, but still gets included in e.g. `to()` calls. - self.register_buffer("_minus_A", -A.float(), persistent=False) + self.A_minus = nn.Parameter(-A.float()) + self.A_minus._no_weight_decay = True self.norm = MambaRMSNormGated( self.intermediate_size, eps=self.layer_norm_epsilon, @@ -129,7 +128,7 @@ def __init__(self, config, layer_idx: int): self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias) self.use_bias = config.use_bias - self.register_load_state_dict_post_hook(self._load_state_dict_post_hook) + self.register_load_state_dict_pre_hook(self._load_state_dict_pre_hook) def torch_forward(self, input_states): batch_size, seq_len, _ = input_states.shape @@ -166,10 +165,9 @@ def torch_forward(self, input_states): ) # 3. SSM transformation - A = self._minus_A y = torch.ops.auto_deploy.torch_ssm( hidden_states=hidden_states.view(batch_size, seq_len, -1, self.head_dim), - A=A, + A=self.A_minus, B=B.view(batch_size, seq_len, -1, self.ssm_state_size), C=C.view(batch_size, seq_len, -1, self.ssm_state_size), D=self.D, @@ -194,8 +192,12 @@ def forward(self, hidden_states): return self.torch_forward(hidden_states) @staticmethod - def _load_state_dict_post_hook(module, incompatible_keys) -> None: - module._minus_A.data = -torch.exp(module.A_log.float()) + def _load_state_dict_pre_hook(module, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs) -> None: + A_log_key = prefix + "A_log" + A_minus_key = prefix + "A_minus" + if A_log_key in state_dict: + state_dict[A_minus_key] = -torch.exp(state_dict.pop(A_log_key).float()) class NemotronHRMSNorm(nn.Module): @@ -466,7 +468,7 @@ class NemotronHPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights.""" if isinstance(module, NemotronHMamba2Mixer): - module.A_log._no_weight_decay = True + module.A_minus._no_weight_decay = True module.D._no_weight_decay = True dt = torch.exp( diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_modeling_nemotron_h.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_modeling_nemotron_h.py index 94b22ed14fc..a93537ebc12 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_modeling_nemotron_h.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_modeling_nemotron_h.py @@ -205,7 +205,7 @@ def _run_torch_export_to_gm(): factory.load_or_random_init(gm, device="cuda") move_to_device(gm, "cuda") factory._to_maybe_random(model, "cuda") - # In order to ensure the `_minus_A` (non-persistent buffer) is correct, we need to run the + # In order to ensure the `A_minus` (non-persistent buffer) is correct, we need to run the # model's load state pre/post hooks by loading the state dicts after initialization. # NOTE: this is done under the hood by `torch_export_to_gm`, so we only need this in this # `if` clause. From f117b8e3a124479a8965105bd0658c4f6146fd9b Mon Sep 17 00:00:00 2001 From: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> Date: Tue, 23 Dec 2025 05:19:46 -0800 Subject: [PATCH 06/18] wip Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> --- .../auto_deploy/transform/library/sharding.py | 3 +++ tensorrt_llm/_torch/auto_deploy/utils/node_utils.py | 13 ++++++++++++- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py index 987b85277d1..64e25ab57f1 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py @@ -1344,6 +1344,7 @@ def _shard_parameter_node( # Shard weight using the unified function (also updates the parameter) original_weight = gm.get_parameter(weight_key) + _, weight_new_shape = shard_weight_tensor( gm=gm, weight_tensor=original_weight, @@ -1892,6 +1893,8 @@ def _process_ssm_sharding( if "out_proj" not in str(n) ] for weight_node in weight_nodes: + # if is_any_ssm_op(list(weight_node.users)[0]): + # continue weight_key = weight_node.target # Get the weight parameter try: diff --git a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py index 286650c77fd..da39d59f26e 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py @@ -131,6 +131,8 @@ def get_quantization_params_from_linear_node(linear_op: torch.fx.node.Node): def extract_weight_node(node: Node) -> int: """Extracts the weight node from the given parametrized node""" + gm = node.graph.owning_module + param_names = {name for name, _ in gm.named_parameters()} def find_get_attr_node(weight_node: Node) -> Node: """Recursively traverse inputs of allowed nodes to find a node with 'get_attr' op.""" @@ -141,7 +143,7 @@ def find_get_attr_node(weight_node: Node) -> Node: torch.ops.aten.view.default, } - if weight_node.op == "get_attr": + if weight_node.op == "get_attr" and weight_node.target in param_names: return weight_node # If node is not in the list of allowable ops then return None @@ -325,6 +327,15 @@ def is_any_ssm_op(node: Node) -> bool: ) +def is_any_conv_op(node: Node) -> bool: + return is_op( + node, + ops=[ + torch.ops.auto_deploy.torch_causal_conv1d, + ], + ) + + def is_any_attention_op(node: Node) -> bool: return is_op( node, From 7fa4454b32c50b2b88ed3adb8bad39f89e565291 Mon Sep 17 00:00:00 2001 From: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> Date: Fri, 26 Dec 2025 03:25:33 -0800 Subject: [PATCH 07/18] wip Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> --- .../transform/library/quantization.py | 2 +- .../auto_deploy/transform/library/sharding.py | 118 +++++++++--------- .../_torch/auto_deploy/utils/node_utils.py | 91 +++++++++----- .../auto_deploy/utils/quantization_utils.py | 4 +- 4 files changed, 119 insertions(+), 96 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py b/tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py index 21d1ccd2348..142ae02b019 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py @@ -140,7 +140,7 @@ def _insert_quantized_linear( The state_dict is also updated to contain the sharded weights. """ param_name, _ = extract_param_names_from_node(node) - original_weight = gm.get_parameter(param_name) + original_weight = gm.get_parameter(param_name[0]) new_param = nn.Parameter(self.quantize_weight(original_weight), requires_grad=False) modname, _, attrname = param_name.rpartition(".") diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py index 64e25ab57f1..074eaebaaef 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py @@ -38,8 +38,7 @@ LayerSubgraph, LayerType, bfs, - extract_param_names_from_node, - extract_weight_node, + extract_weight_nodes, filtered_nodes, get_all_layer_subgraphs, get_layer_after_linear_node, @@ -48,7 +47,6 @@ is_any_moe_op, is_any_ssm_op, is_op, - num_users_of_weight_node, shape, subgraph, ) @@ -1330,68 +1328,70 @@ def _shard_parameter_node( rank, world_size = config.rank, config.world_size allreduce_strategy = config.allreduce_strategy.name - num_users = num_users_of_weight_node(node) - if num_users > 1 or num_users == 0: - ad_logger.warning( - f"Weight node {node} has {num_users} users. This is not supported for sharding. Skipping." - ) - return - # get weight and bias key - weight_key, bias_key = extract_param_names_from_node(node) - - modname = weight_key.rpartition(".")[0] - submod = gm.get_submodule(modname) - - # Shard weight using the unified function (also updates the parameter) - original_weight = gm.get_parameter(weight_key) - - _, weight_new_shape = shard_weight_tensor( - gm=gm, - weight_tensor=original_weight, - param_key=weight_key, - dim=dim, - rank=rank, - world_size=world_size, - min_local_shape=min_local_shape, - fused_weight_dims=fused_weight_dims, - ) - - if bias_key is not None and dim == 0: - # update bias for dim 0 --> we can handle it like the weight - original_bias = gm.get_parameter(bias_key) - shard_weight_tensor( + # num_users = num_users_of_weight_node(node) + # if num_users > 1 or num_users == 0: + # ad_logger.warning( + # f"Weight node {node} has {num_users} users. This is not supported for sharding. Skipping." + # ) + # return + # # get weight and bias key + # weight_key, bias_key = extract_param_names_from_node(node) + + # modname = weight_key.rpartition(".")[0] + # submod = gm.get_submodule(modname) + + # # Shard weight using the unified function (also updates the parameter) + # original_weight = gm.get_parameter(weight_key) + weight_nodes = extract_weight_nodes(node) + for weight_node, bias_node in weight_nodes: + _, weight_new_shape = shard_weight_tensor( gm=gm, - weight_tensor=original_bias, - param_key=bias_key, + weight_tensor=weight_node.node, + param_key=weight_node.node_key, dim=dim, rank=rank, world_size=world_size, min_local_shape=min_local_shape, fused_weight_dims=fused_weight_dims, ) - elif bias_key is not None and rank != world_size - 1: - # update the bias for dim 1 --> in this case only the last rank gets the bias to avoid - # double counting it. For all other we will delete the bias. - args = list(node.args) - node_bias = args[2] - args[2] = None - node.args = tuple(args) - gm.graph.erase_node(node_bias) - bias_param_name = bias_key.rpartition(".")[-1] - setattr(submod, bias_param_name, None) - gm._register_load_state_dict_pre_hook(partial(_load_hook_remove, param_key=bias_key)) - - if quantization_cb is not None: - quantization_cb( - gm=gm, - submod=submod, - node=node, - weight_key=weight_key, - weight_new_shape=weight_new_shape, - dim=dim, - rank=rank, - world_size=world_size, - ) + + if bias_node is not None and dim == 0: + # update bias for dim 0 --> we can handle it like the weight + shard_weight_tensor( + gm=gm, + weight_tensor=bias_node.node, + param_key=bias_node.node_key, + dim=dim, + rank=rank, + world_size=world_size, + min_local_shape=min_local_shape, + fused_weight_dims=fused_weight_dims, + ) + elif bias_node is not None and rank != world_size - 1: + # update the bias for dim 1 --> in this case only the last rank gets the bias to avoid + # double counting it. For all other we will delete the bias. + args = list(node.args) + node_bias = args[2] + args[2] = None + node.args = tuple(args) + gm.graph.erase_node(node_bias) + bias_param_name = bias_node.node_key.rpartition(".")[-1] + setattr(bias_node.submod, bias_param_name, None) + gm._register_load_state_dict_pre_hook( + partial(_load_hook_remove, param_key=bias_node.node_key) + ) + + if quantization_cb is not None: + quantization_cb( + gm=gm, + submod=weight_node.submod, + node=node, + weight_key=weight_node.node_key, + weight_new_shape=weight_new_shape, + dim=dim, + rank=rank, + world_size=world_size, + ) # # # column shard with no gather: the output is sharded if not add_dist: @@ -2253,7 +2253,7 @@ def detect_sharding_from_config( for lin_node in linear_nodes: # use node's weight name to get the module name - module_name = extract_weight_node(lin_node).target + module_name = extract_weight_nodes(lin_node)[0].target if any(attn_name in module_name for attn_name in attn_names): # find the next attention node and infer the head_dim diff --git a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py index da39d59f26e..08f7f946c39 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py @@ -8,6 +8,7 @@ import torch from pydantic import BaseModel, ConfigDict +from torch import nn from torch._ops import OpOverload, OpOverloadPacket from torch.fx import GraphModule, Node @@ -51,6 +52,13 @@ class LayerSubgraph(BaseModel): min_local_shape: int = 1 +class WeightNode(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + node: Node + node_key: str + submod: nn.Module + + @dataclass class modelopt_quant_params: input_node: torch.fx.node.Node = None @@ -129,10 +137,12 @@ def get_quantization_params_from_linear_node(linear_op: torch.fx.node.Node): return input_params, weight_params, output_params -def extract_weight_node(node: Node) -> int: - """Extracts the weight node from the given parametrized node""" +def extract_weight_nodes(node: Node) -> Tuple[List[WeightNode], List[WeightNode]]: + """Extracts the list of weight node and optional bias node from the given parametrized node""" gm = node.graph.owning_module - param_names = {name for name, _ in gm.named_parameters()} + param_names = {name for name, _ in gm.named_parameters()}.union( + {name for name, _ in gm.named_buffers()} + ) def find_get_attr_node(weight_node: Node) -> Node: """Recursively traverse inputs of allowed nodes to find a node with 'get_attr' op.""" @@ -157,55 +167,68 @@ def find_get_attr_node(weight_node: Node) -> Node: return None if is_op(node, torch.ops.aten.bmm): - weight_node = node.args[1] + # no bias for bmm + return [WeightNode(node=node.args[1], node_key=node.args[1].target)], [] # for other parametrized nodes, we need to find the weight node else: - weight_nodes = [ + all_weight_nodes = [ n for n in node.args if isinstance(n, Node) and find_get_attr_node(n) is not None ] - # can be two weights (if bias weight is present) - weight_node = None - if weight_nodes: - weight_node = weight_nodes[0] - # for modelopt quantized graph, there will be a quantize_op - _, weight_params, _ = get_quantization_params_from_linear_node(node) - weight_node = weight_params.input_node if weight_params else weight_node - assert weight_node is not None, "Expected at least one weight node in the parametrized node" - return find_get_attr_node(weight_node) + # separate weight nodes and bias nodes + weight_nodes = [n for n in all_weight_nodes if n.target.endswith("weight")] + bias_nodes = [n for n in all_weight_nodes if n.target.endswith("bias")] + weight_nodes = [ + WeightNode( + node=n, node_key=n.target, submod=gm.get_submodule(n.target.rpartition(".")[0]) + ) + for n in weight_nodes + ] + bias_nodes = [ + WeightNode( + node=n, node_key=n.target, submod=gm.get_submodule(n.target.rpartition(".")[0]) + ) + for n in bias_nodes + ] + return weight_nodes, bias_nodes def num_users_of_weight_node(node: Node) -> int: """Returns the number of users of the weight node of the given parametrized node.""" - weight_node = extract_weight_node(node) + weight_node = extract_weight_nodes(node)[0] return len(weight_node.users) if weight_node is not None else 0 -def extract_param_names_from_node(node: Node) -> Tuple[str, Optional[str]]: +def extract_param_names_from_node(node: Node) -> Tuple[List[str], Optional[List[str]]]: """Extracts the name of the parameter associated with the given parametrized node. Args: node: node with weight parameters in the graph. """ - weight_node = extract_weight_node(node) + # try: - assert weight_node, "Cannot identify weight parameter of linear node." + # except: + # a = 1 - # Map arg to named parameter - weight_name = weight_node.target + # assert weight_node, "Cannot identify weight parameter of linear node." - # check for bias - if is_op(node, torch.ops.aten.bmm): - bias_node = node.args[2] if len(node.args) > 2 else None - else: - weight_nodes = [n for n in node.args if isinstance(n, Node) and n.op == "get_attr"] - if len(weight_nodes) > 1: - bias_node = weight_nodes[1] - else: - bias_node = None - assert bias_node is None or bias_node.op == "get_attr" - bias_name = bias_node.target if bias_node is not None else None + # # Map arg to named parameter + # weight_name = weight_node.target + + # # check for bias + # if is_op(node, torch.ops.aten.bmm): + # bias_node = node.args[2] if len(node.args) > 2 else None + # else: + # weight_nodes = [n for n in node.args if isinstance(n, Node) and n.op == "get_attr"] + # if len(weight_nodes) > 1: + # bias_node = weight_nodes[1] + # else: + # bias_node = None + # assert bias_node is None or bias_node.op == "get_attr" + # bias_name = bias_node.target if bias_node is not None else None - return weight_name, bias_name + # return weight_name, bias_name + weight_nodes, bias_nodes = extract_weight_nodes(node) + return [n.node_key for n in weight_nodes], [n.node_key for n in bias_nodes] def get_op_overload_packet(node: Union[OpOverloadPacket, OpOverload]) -> OpOverloadPacket: @@ -751,9 +774,9 @@ def get_weight_shape( if not is_any_lin_op(node): return None if dim is None: - return shape(extract_weight_node(node)) + return shape(extract_weight_nodes(node)[0]) else: - return shape(extract_weight_node(node))[dim] + return shape(extract_weight_nodes(node)[0])[dim] def get_layer_after_linear_node( diff --git a/tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py index aee98c37713..1db0c2ced2b 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py @@ -117,8 +117,8 @@ def should_skip_quantization( else: if not (is_linear_op(node_or_name) or is_bmm_op(node_or_name)): return True - param_name, _ = extract_param_names_from_node(node_or_name) - modname, _, _ = param_name.rpartition(".") + param_names, _ = extract_param_names_from_node(node_or_name) + modname, _, _ = param_names[0].rpartition(".") return any(fnmatch(modname, pattern) for pattern in excluded_patterns) From 549a6f8bc4deed0f3ddae118c64f6053bba5477b Mon Sep 17 00:00:00 2001 From: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> Date: Sun, 28 Dec 2025 10:18:48 -0800 Subject: [PATCH 08/18] working SSM sharding Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> --- .../transform/library/quantization.py | 21 +-- .../auto_deploy/transform/library/sharding.py | 146 +++--------------- .../_torch/auto_deploy/utils/node_utils.py | 122 ++++++++++----- .../auto_deploy/utils/quantization_utils.py | 7 +- 4 files changed, 123 insertions(+), 173 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py b/tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py index 142ae02b019..292de523394 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py @@ -17,7 +17,7 @@ from ...models.factory import ModelFactory from ...shim.interface import CachedSequenceInterface from ...utils.node_utils import ( - extract_param_names_from_node, + extract_weight_nodes, get_quantization_params_from_linear_node, is_bmm_op, is_linear_op, @@ -139,13 +139,12 @@ def _insert_quantized_linear( The state_dict is also updated to contain the sharded weights. """ - param_name, _ = extract_param_names_from_node(node) - original_weight = gm.get_parameter(param_name[0]) - new_param = nn.Parameter(self.quantize_weight(original_weight), requires_grad=False) - modname, _, attrname = param_name.rpartition(".") + weight_nodes = extract_weight_nodes(node) + lin_weight = weight_nodes.weights[0] + new_param = nn.Parameter(self.quantize_weight(lin_weight.tensor), requires_grad=False) + modname, _, attrname = lin_weight.node_key.rpartition(".") - submod = gm.get_submodule(modname) - setattr(submod, attrname, new_param) + setattr(lin_weight.submod, attrname, new_param) # check modelopt quantizers from graph if is_quantized_graph: @@ -171,10 +170,12 @@ def _insert_quantized_linear( ) # Note: canonicalize_graph() will remove input/weight/output quantizer - for scale_name, scale in self.default_scales(original_weight.shape).items(): - submod.register_buffer(scale_name, scale) + for scale_name, scale in self.default_scales(lin_weight.tensor.shape).items(): + lin_weight.submod.register_buffer(scale_name, scale) - gm._register_load_state_dict_pre_hook(partial(self.load_hook, weight_name=param_name)) + gm._register_load_state_dict_pre_hook( + partial(self.load_hook, weight_name=lin_weight.node_key) + ) with gm.graph.inserting_before(node): scales = {} diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py index 074eaebaaef..bfc286e2904 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py @@ -38,6 +38,7 @@ LayerSubgraph, LayerType, bfs, + extract_weight_name, extract_weight_nodes, filtered_nodes, get_all_layer_subgraphs, @@ -1272,10 +1273,6 @@ def split_fused_tensor( fused_dims: list = fused_weight_dims, d: int = dim, ) -> torch.Tensor: - # dim_d = t.shape[d] - # num_parts = 1 - # part_size = dim_d // num_parts - # fused_dims = [part_size] * num_parts return torch.cat( [split_tensor(w) for w in torch.split(t, fused_dims, dim=d)], dim=d, @@ -1343,10 +1340,10 @@ def _shard_parameter_node( # # Shard weight using the unified function (also updates the parameter) # original_weight = gm.get_parameter(weight_key) weight_nodes = extract_weight_nodes(node) - for weight_node, bias_node in weight_nodes: + for weight_node in weight_nodes.weights: _, weight_new_shape = shard_weight_tensor( gm=gm, - weight_tensor=weight_node.node, + weight_tensor=weight_node.tensor, param_key=weight_node.node_key, dim=dim, rank=rank, @@ -1354,12 +1351,24 @@ def _shard_parameter_node( min_local_shape=min_local_shape, fused_weight_dims=fused_weight_dims, ) + if quantization_cb is not None: + quantization_cb( + gm=gm, + submod=weight_node.submod, + node=node, + weight_key=weight_node.node_key, + weight_new_shape=weight_new_shape, + dim=dim, + rank=rank, + world_size=world_size, + ) - if bias_node is not None and dim == 0: + for bias_node in weight_nodes.biases: + if dim == 0: # update bias for dim 0 --> we can handle it like the weight shard_weight_tensor( gm=gm, - weight_tensor=bias_node.node, + weight_tensor=bias_node.tensor, param_key=bias_node.node_key, dim=dim, rank=rank, @@ -1381,18 +1390,6 @@ def _shard_parameter_node( partial(_load_hook_remove, param_key=bias_node.node_key) ) - if quantization_cb is not None: - quantization_cb( - gm=gm, - submod=weight_node.submod, - node=node, - weight_key=weight_node.node_key, - weight_new_shape=weight_new_shape, - dim=dim, - rank=rank, - world_size=world_size, - ) - # # # column shard with no gather: the output is sharded if not add_dist: return @@ -1423,107 +1420,6 @@ def _update_node_args(node: Node, args: tuple) -> None: ) -def _insert_sharded_moe_stacked( - gm: GraphModule, - node: Node, - rank: int, - world_size: int, - allreduce_strategy: AllReduceStrategy, - scale_names: Sequence[str] = (), -): - """Update the torch_moe node with sliced stacked weight tensors, - sharded `selected_experts` and `final_scales(router_logics)`. - Add an all_reduce node after the moe node. - - For torch_moe with stacked tensor format (single-element lists containing 3D tensors). - - NOTE: allreduce_strategy is MANDATORY and must be explicitly provided. - """ - if allreduce_strategy is None: - raise ValueError(f"allreduce_strategy must be set for MoE sharding on node {node.name}") - - # Extract the stacked tensors from single-element lists - # args[3] = w1_weight (Node representing list with one 3D tensor, or direct list) - # args[4] = w2_weight (Node representing list with one 3D tensor, or direct list) - - # Helper to extract tensor node from list (handles both Node and direct list) - def extract_tensor_from_list_arg(list_arg): - if isinstance(list_arg, Node) and list_arg.target is list: - # It's a list() call node - extract from its args - return list_arg.args[0][0] # args[0] is the list content, [0] is first element - elif isinstance(list_arg, (list, tuple)): - # Direct list - return list_arg[0] - else: - raise ValueError(f"Unexpected list format: {type(list_arg)}") - - w3_w1_tensor_node = extract_tensor_from_list_arg(node.args[3]) - w2_tensor_node = extract_tensor_from_list_arg(node.args[4]) - num_experts = _get_dim0_from_arg(gm, w3_w1_tensor_node) - - args = list(node.args) - - # -- Handle selected_experts and final_scales sharding -- - selected_experts = args[1] - final_scales = args[2] - - experts_per_rank = num_experts // world_size - - with gm.graph.inserting_before(node): - lower = experts_per_rank * rank - # selected_experts_local = selected_experts - low - selected_experts_local = gm.graph.create_node( - "call_function", operator.sub, args=(selected_experts, lower), kwargs={} - ) - - # For num_experts % world_size != 0 case, - # assign the last (num_experts % world_size) experts to the last rank - div_node = gm.graph.create_node( - "call_function", operator.floordiv, args=(selected_experts, experts_per_rank), kwargs={} - ) - - comp_op = torch.ge if rank == world_size - 1 else torch.eq - rank_mask = gm.graph.create_node("call_function", comp_op, args=(div_node, rank), kwargs={}) - - # final_scales_local = final_scales * rank_mask - final_scales_local = gm.graph.create_node( - "call_function", operator.mul, args=(final_scales, rank_mask), kwargs={} - ) - - # -- Transform expert weight parameters -- - local_lo, local_hi = _split_range_last_remainder(num_experts, world_size, rank) - - # Transform w3_w1_stacked: slice experts, swap [W1,W3]->[W3,W1], transpose (E,H,2I)->(E,2I,H) - if isinstance(w3_w1_tensor_node, Node): - _transform_bmm_moe_weight_param( - gm, w3_w1_tensor_node, local_lo, local_hi, swap_gate_up=True - ) - - # Transform w2_stacked: slice experts, transpose (E,I,H)->(E,H,I) - if isinstance(w2_tensor_node, Node): - _transform_bmm_moe_weight_param(gm, w2_tensor_node, local_lo, local_hi, swap_gate_up=False) - - # -- Update args (keep same lists/nodes, just with transformed parameters) -- - args[1] = selected_experts_local - args[2] = final_scales_local - # args[3] and args[4] stay the same - we modified the parameters in-place - - ad_logger.debug( - f"Updated node {node}: replaced original arguments {node.args} with sharded arguments {args}." - ) - - node.args = tuple(args) - - # -- add an all_reduce node -- - with gm.graph.inserting_after(node): - dist_node = gm.graph.call_function( - torch.ops.auto_deploy.torch_dist_all_reduce.default, - args=(node, allreduce_strategy), - ) - node.replace_all_uses_with(dist_node) - dist_node.replace_input_with(dist_node, node) - - def _insert_sharded_moe( gm: GraphModule, node: Node, @@ -2253,9 +2149,9 @@ def detect_sharding_from_config( for lin_node in linear_nodes: # use node's weight name to get the module name - module_name = extract_weight_nodes(lin_node)[0].target + weight_name = extract_weight_name(lin_node) - if any(attn_name in module_name for attn_name in attn_names): + if any(attn_name in weight_name for attn_name in attn_names): # find the next attention node and infer the head_dim next_attention_node, _ = bfs( lin_node, is_any_attention_op, attr_next="users", include_root=False @@ -2279,7 +2175,7 @@ def detect_sharding_from_config( # Then we escape dots, and finally we replace @ with .* pattern_string = pattern_string.replace("*", "@") pattern_regex = re.escape(pattern_string).replace("@", ".*") - if re.match(pattern_regex, module_name): + if re.match(pattern_regex, weight_name): # we have a match. Get the config for this layer config = tp_plan[key] @@ -2318,7 +2214,7 @@ def detect_sharding_from_config( elif "local" in config: # Check if this applies to shared experts in EP parallelism. # If yes, apply the TP col-row shard. - if "shared" in module_name: + if "shared" in weight_name: col_row_action = config.replace("local_", "") if col_row_action == "colwise": transform_container.add( diff --git a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py index 08f7f946c39..d9c0d6f28b0 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py @@ -55,10 +55,56 @@ class LayerSubgraph(BaseModel): class WeightNode(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) node: Node + tensor: torch.Tensor node_key: str submod: nn.Module +class WeightNodes(BaseModel): + weights: list[WeightNode] + biases: list[WeightNode] + + +class ModuleParams: + """Static class for caching module parameters and buffers to avoid repeated lookups.""" + + _parameters: dict = {} + _buffers: dict = {} + + @classmethod + def get_all_params(cls, gm: GraphModule) -> set: + """Get all parameter and buffer names for a GraphModule.""" + if gm not in cls._parameters: + cls._set_params(gm) + return cls._parameters[gm].union(cls._buffers[gm]) + + @classmethod + def get_buffers(cls, gm: GraphModule) -> set: + """Get all buffer names for a GraphModule.""" + if gm not in cls._buffers: + cls._set_params(gm) + return cls._buffers[gm] + + @classmethod + def get_parameters(cls, gm: GraphModule) -> set: + """Get all parameter names for a GraphModule.""" + if gm not in cls._parameters: + cls._set_params(gm) + return cls._parameters[gm] + + @classmethod + def _set_params(cls, gm: GraphModule): + """Cache parameter and buffer names for a GraphModule.""" + cls._parameters[gm] = {name for name, _ in gm.named_parameters()} + cls._buffers[gm] = {name for name, _ in gm.named_buffers()} + + @classmethod + def clear_cache(cls): + """Clear the cached parameters and buffers.""" + cls._parameters.clear() + cls._buffers.clear() + + @dataclass class modelopt_quant_params: input_node: torch.fx.node.Node = None @@ -137,12 +183,24 @@ def get_quantization_params_from_linear_node(linear_op: torch.fx.node.Node): return input_params, weight_params, output_params -def extract_weight_nodes(node: Node) -> Tuple[List[WeightNode], List[WeightNode]]: +def extract_weight_name(node: Node) -> str: + weight_nodes = extract_weight_nodes(node) + return weight_nodes.weights[0].node_key + + +def get_const_tensor(tensor_name: str, gm: GraphModule) -> torch.Tensor: + if tensor_name in ModuleParams.get_parameters(gm): + return gm.get_parameter(tensor_name) + elif tensor_name in ModuleParams.get_buffers(gm): + return gm.get_buffer(tensor_name) + else: + raise ValueError(f"Tensor {tensor_name} not found in the graph") + + +def extract_weight_nodes(node: Node) -> WeightNodes: """Extracts the list of weight node and optional bias node from the given parametrized node""" gm = node.graph.owning_module - param_names = {name for name, _ in gm.named_parameters()}.union( - {name for name, _ in gm.named_buffers()} - ) + param_names = ModuleParams.get_all_params(gm) def find_get_attr_node(weight_node: Node) -> Node: """Recursively traverse inputs of allowed nodes to find a node with 'get_attr' op.""" @@ -168,28 +226,45 @@ def find_get_attr_node(weight_node: Node) -> Node: if is_op(node, torch.ops.aten.bmm): # no bias for bmm - return [WeightNode(node=node.args[1], node_key=node.args[1].target)], [] + return WeightNodes( + [ + WeightNode( + node=node.args[1], + node_key=node.args[1].target, + tensor=gm.get_parameter(node.args[1].target), + ) + ], + [], + ) # for other parametrized nodes, we need to find the weight node else: all_weight_nodes = [ - n for n in node.args if isinstance(n, Node) and find_get_attr_node(n) is not None + find_get_attr_node(n) + for n in node.args + if isinstance(n, Node) and find_get_attr_node(n) is not None ] # separate weight nodes and bias nodes - weight_nodes = [n for n in all_weight_nodes if n.target.endswith("weight")] bias_nodes = [n for n in all_weight_nodes if n.target.endswith("bias")] + weight_nodes = [n for n in all_weight_nodes if n not in bias_nodes] weight_nodes = [ WeightNode( - node=n, node_key=n.target, submod=gm.get_submodule(n.target.rpartition(".")[0]) + node=n, + node_key=n.target, + submod=gm.get_submodule(n.target.rpartition(".")[0]), + tensor=get_const_tensor(n.target, gm), ) for n in weight_nodes ] bias_nodes = [ WeightNode( - node=n, node_key=n.target, submod=gm.get_submodule(n.target.rpartition(".")[0]) + node=n, + node_key=n.target, + submod=gm.get_submodule(n.target.rpartition(".")[0]), + tensor=get_const_tensor(n.target, gm), ) for n in bias_nodes ] - return weight_nodes, bias_nodes + return WeightNodes(weights=weight_nodes, biases=bias_nodes) def num_users_of_weight_node(node: Node) -> int: @@ -204,29 +279,6 @@ def extract_param_names_from_node(node: Node) -> Tuple[List[str], Optional[List[ Args: node: node with weight parameters in the graph. """ - # try: - - # except: - # a = 1 - - # assert weight_node, "Cannot identify weight parameter of linear node." - - # # Map arg to named parameter - # weight_name = weight_node.target - - # # check for bias - # if is_op(node, torch.ops.aten.bmm): - # bias_node = node.args[2] if len(node.args) > 2 else None - # else: - # weight_nodes = [n for n in node.args if isinstance(n, Node) and n.op == "get_attr"] - # if len(weight_nodes) > 1: - # bias_node = weight_nodes[1] - # else: - # bias_node = None - # assert bias_node is None or bias_node.op == "get_attr" - # bias_name = bias_node.target if bias_node is not None else None - - # return weight_name, bias_name weight_nodes, bias_nodes = extract_weight_nodes(node) return [n.node_key for n in weight_nodes], [n.node_key for n in bias_nodes] @@ -774,9 +826,9 @@ def get_weight_shape( if not is_any_lin_op(node): return None if dim is None: - return shape(extract_weight_nodes(node)[0]) + return shape(extract_weight_nodes(node).weights[0].node) else: - return shape(extract_weight_nodes(node)[0])[dim] + return shape(extract_weight_nodes(node).weights[0].node)[dim] def get_layer_after_linear_node( diff --git a/tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py index 1db0c2ced2b..ed8c4cdfaa6 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py @@ -8,7 +8,7 @@ from ..custom_ops.quant import FP4_GLOBAL_SCALE_MAX, FP8_MAX from .logger import ad_logger from .node_utils import ( - extract_param_names_from_node, + extract_weight_name, get_quantization_params_from_linear_node, is_bmm_op, is_linear_op, @@ -117,8 +117,9 @@ def should_skip_quantization( else: if not (is_linear_op(node_or_name) or is_bmm_op(node_or_name)): return True - param_names, _ = extract_param_names_from_node(node_or_name) - modname, _, _ = param_names[0].rpartition(".") + # param_names, _ = extract_param_names_from_node(node_or_name) + weight_name = extract_weight_name(node_or_name) + modname = weight_name.rpartition(".")[0] return any(fnmatch(modname, pattern) for pattern in excluded_patterns) From 0063516a64535f53bbd392b79de71857f9edbf07 Mon Sep 17 00:00:00 2001 From: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> Date: Sun, 28 Dec 2025 13:51:05 -0800 Subject: [PATCH 09/18] cleanup in progress Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> --- tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py | 6 +++--- .../_torch/auto_deploy/transform/library/sharding.py | 3 +-- tensorrt_llm/_torch/auto_deploy/utils/node_utils.py | 6 ++++-- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py b/tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py index 73191ede538..ac129f2d9f1 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py @@ -13,7 +13,7 @@ from ...shim.interface import CachedSequenceInterface from ...utils.cuda_mem_tracker import cuda_memory_tracker from ...utils.logger import ad_logger -from ...utils.node_utils import extract_param_names_from_node, is_linear_op, is_op +from ...utils.node_utils import extract_weight_name, is_linear_op, is_op from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry @@ -36,7 +36,7 @@ def _insert_fused_gemm(gm: GraphModule, idx: int, parent_node: Node, linear_node y2 = y[:, out1:out1+out2] """ # some info we need - keys_unfused = [extract_param_names_from_node(n)[0] for n in linear_nodes] + keys_unfused = [extract_weight_name(n) for n in linear_nodes] params_unfused = [gm.get_parameter(k) for k in keys_unfused] sizes_unfused = [p.size(0) for p in params_unfused] key_fused = f"fused_weight_{idx}" @@ -128,7 +128,7 @@ def build_custom_args_for_linear(self, scale_getattrs: Dict[str, Node]) -> Tuple def _insert_fused_quant_gemm( self, gm: GraphModule, idx: int, parent_node: Node, linear_nodes: List[Node] ): - keys_unfused = [extract_param_names_from_node(n)[0] for n in linear_nodes] + keys_unfused = [extract_weight_name(n) for n in linear_nodes] params_unfused = [gm.get_parameter(k) for k in keys_unfused] sizes_unfused = [p.size(0) for p in params_unfused] key_fused = f"fused_weight_{idx}" diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py index bfc286e2904..123b291c06f 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py @@ -2338,7 +2338,6 @@ def detect_column_row_shard( min_local_shape is the minimum size of the local tensor shard, to prevent TP parallelism splitting, e.g., the individual heads into smaller shards. """ - # test_moe_variants() ad_logger.debug("Before sharding graph: " + str(gm)) config = transform_container.config world_size = config.world_size @@ -2443,7 +2442,7 @@ def detect_column_row_shard( # simple shard remaining linear nodes if config.shard_all_unprocessed: num_simple_shards += _process_simple_shard(unprocessed_linear_nodes, transform_container) - num_column_row_shards += num_ssm_shards + num_column_row_shards += num_ssm_shards + num_mla_shards num_shards = num_simple_shards + num_column_row_shards ad_logger.info( f"Heuristics found {num_shards} TP shards. Simple: {num_simple_shards}, " diff --git a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py index d9c0d6f28b0..f2278c1ac25 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py @@ -575,8 +575,10 @@ def get_all_layer_subgraphs(gm: GraphModule) -> List[List[Node]]: # closing is the last linear node in the layer layer_subgraph = get_layer_after_linear_node(linear_nodes, terminating_indices) if layer_subgraph.opening_nodes is not None and len(layer_subgraph.opening_nodes) > 0: - unprocessed_linear_nodes -= set(layer_subgraph.opening_nodes) | set( - [layer_subgraph.terminating_node] + unprocessed_linear_nodes -= ( + set(layer_subgraph.opening_nodes) + | set([layer_subgraph.terminating_node]) + | set(layer_subgraph.subgraph_nodes) ) layer_subgraphs.append(layer_subgraph) last_lin_index = terminating_indices[-1] + 1 From 6a9ceffbd6f73ba2b0b53e0142367c3d09a77c6b Mon Sep 17 00:00:00 2001 From: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> Date: Sun, 28 Dec 2025 13:58:02 -0800 Subject: [PATCH 10/18] code cleanup Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> --- tensorrt_llm/_torch/auto_deploy/utils/_graph.py | 4 ++-- .../_torch/auto_deploy/utils/node_utils.py | 16 +++------------- 2 files changed, 5 insertions(+), 15 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/utils/_graph.py b/tensorrt_llm/_torch/auto_deploy/utils/_graph.py index cd61bd52f1e..d432a5a78f9 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/_graph.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/_graph.py @@ -354,7 +354,7 @@ def get_input_embeddings(model: nn.Module) -> torch.Tensor: op="call_function", target=torch.ops.aten.embedding.default ) for node in found_nodes: - embedding_weights.append(get_weight_tensor(gm, node)) + embedding_weights.append(get_weight_tensor(node)) if hasattr(model, "get_input_embeddings"): embedding_weights.append(model.get_input_embeddings()) @@ -400,4 +400,4 @@ def get_lm_head_node(gm: GraphModule, output_node: Optional[Node] = None) -> Nod def get_lm_head_weights(model: nn.Module) -> torch.Tensor: gm, output_node = get_output_node(model) lm_head_node = get_lm_head_node(gm, output_node) - return get_weight_tensor(gm, lm_head_node) + return get_weight_tensor(lm_head_node) diff --git a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py index f2278c1ac25..9e0031987ad 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py @@ -273,16 +273,6 @@ def num_users_of_weight_node(node: Node) -> int: return len(weight_node.users) if weight_node is not None else 0 -def extract_param_names_from_node(node: Node) -> Tuple[List[str], Optional[List[str]]]: - """Extracts the name of the parameter associated with the given parametrized node. - - Args: - node: node with weight parameters in the graph. - """ - weight_nodes, bias_nodes = extract_weight_nodes(node) - return [n.node_key for n in weight_nodes], [n.node_key for n in bias_nodes] - - def get_op_overload_packet(node: Union[OpOverloadPacket, OpOverload]) -> OpOverloadPacket: """Get the overload packet from the op overload.""" if isinstance(node, OpOverloadPacket): @@ -1011,10 +1001,10 @@ def shape(node: Node) -> Tuple[int, ...]: return node.meta["val"].shape -def get_weight_tensor(gm: GraphModule, node: Node) -> "torch.Tensor": +def get_weight_tensor(node: Node) -> torch.Tensor: """Extract the weight tensor from a node within a GraphModule.""" - weight_name = extract_param_names_from_node(node)[0] - return gm.get_parameter(weight_name) + weight_nodes = extract_weight_nodes(node) + return weight_nodes.weights[0].tensor def draw_graph(gm: GraphModule, filename: str): From 0e76f901c72409f7367df31ad32995438020ae4b Mon Sep 17 00:00:00 2001 From: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> Date: Sun, 28 Dec 2025 14:00:38 -0800 Subject: [PATCH 11/18] code cleanup Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> --- .../_torch/auto_deploy/transform/library/sharding.py | 11 ----------- tensorrt_llm/_torch/auto_deploy/utils/node_utils.py | 2 +- .../_torch/auto_deploy/utils/quantization_utils.py | 1 - 3 files changed, 1 insertion(+), 13 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py index 123b291c06f..2358368e533 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py @@ -1325,17 +1325,6 @@ def _shard_parameter_node( rank, world_size = config.rank, config.world_size allreduce_strategy = config.allreduce_strategy.name - # num_users = num_users_of_weight_node(node) - # if num_users > 1 or num_users == 0: - # ad_logger.warning( - # f"Weight node {node} has {num_users} users. This is not supported for sharding. Skipping." - # ) - # return - # # get weight and bias key - # weight_key, bias_key = extract_param_names_from_node(node) - - # modname = weight_key.rpartition(".")[0] - # submod = gm.get_submodule(modname) # # Shard weight using the unified function (also updates the parameter) # original_weight = gm.get_parameter(weight_key) diff --git a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py index 9e0031987ad..08e6e85b6c9 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py @@ -269,7 +269,7 @@ def find_get_attr_node(weight_node: Node) -> Node: def num_users_of_weight_node(node: Node) -> int: """Returns the number of users of the weight node of the given parametrized node.""" - weight_node = extract_weight_nodes(node)[0] + weight_node = extract_weight_nodes(node).weights[0].node return len(weight_node.users) if weight_node is not None else 0 diff --git a/tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py index ed8c4cdfaa6..889b06edb06 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py @@ -117,7 +117,6 @@ def should_skip_quantization( else: if not (is_linear_op(node_or_name) or is_bmm_op(node_or_name)): return True - # param_names, _ = extract_param_names_from_node(node_or_name) weight_name = extract_weight_name(node_or_name) modname = weight_name.rpartition(".")[0] From ae50fdf6b947da62187495e1309f749e746e06b3 Mon Sep 17 00:00:00 2001 From: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> Date: Sun, 28 Dec 2025 14:04:49 -0800 Subject: [PATCH 12/18] code cleanup Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> --- .../_torch/auto_deploy/transform/library/quantization.py | 1 + tensorrt_llm/_torch/auto_deploy/utils/node_utils.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py b/tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py index 292de523394..5b2902d6422 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py @@ -140,6 +140,7 @@ def _insert_quantized_linear( The state_dict is also updated to contain the sharded weights. """ weight_nodes = extract_weight_nodes(node) + assert len(weight_nodes.weights) == 1, "Expected exactly one weight node" lin_weight = weight_nodes.weights[0] new_param = nn.Parameter(self.quantize_weight(lin_weight.tensor), requires_grad=False) modname, _, attrname = lin_weight.node_key.rpartition(".") diff --git a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py index 08e6e85b6c9..4412fa8afbf 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py @@ -231,7 +231,8 @@ def find_get_attr_node(weight_node: Node) -> Node: WeightNode( node=node.args[1], node_key=node.args[1].target, - tensor=gm.get_parameter(node.args[1].target), + tensor=get_const_tensor(node.args[1].target, gm), + submod=gm.get_submodule(node.args[1].target.rpartition(".")[0]), ) ], [], From a8df449b3d5715f816eafddef843207fe68398bf Mon Sep 17 00:00:00 2001 From: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> Date: Tue, 30 Dec 2025 04:34:43 -0800 Subject: [PATCH 13/18] fixed BMM weight finding Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> --- tensorrt_llm/_torch/auto_deploy/utils/node_utils.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py index 4412fa8afbf..1f6c87fa40c 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py @@ -226,16 +226,17 @@ def find_get_attr_node(weight_node: Node) -> Node: if is_op(node, torch.ops.aten.bmm): # no bias for bmm + weight_node = find_get_attr_node(node.args[1]) return WeightNodes( - [ + weights=[ WeightNode( node=node.args[1], - node_key=node.args[1].target, - tensor=get_const_tensor(node.args[1].target, gm), - submod=gm.get_submodule(node.args[1].target.rpartition(".")[0]), + node_key=weight_node.target, + tensor=get_const_tensor(weight_node.target, gm), + submod=gm.get_submodule(weight_node.target.rpartition(".")[0]), ) ], - [], + biases=[], ) # for other parametrized nodes, we need to find the weight node else: From f3116e554ae2b439a2b739476e31382753f1902c Mon Sep 17 00:00:00 2001 From: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> Date: Tue, 30 Dec 2025 09:46:47 -0800 Subject: [PATCH 14/18] load hook fix Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> --- .../models/custom/modeling_nemotron_h.py | 35 +++++++++++++------ .../auto_deploy/transform/library/sharding.py | 3 -- 2 files changed, 25 insertions(+), 13 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_nemotron_h.py b/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_nemotron_h.py index 23c79c55e44..9869c2e8b40 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_nemotron_h.py +++ b/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_nemotron_h.py @@ -128,8 +128,6 @@ def __init__(self, config, layer_idx: int): self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias) self.use_bias = config.use_bias - self.register_load_state_dict_pre_hook(self._load_state_dict_pre_hook) - def torch_forward(self, input_states): batch_size, seq_len, _ = input_states.shape dtype = input_states.dtype @@ -191,14 +189,6 @@ def torch_forward(self, input_states): def forward(self, hidden_states): return self.torch_forward(hidden_states) - @staticmethod - def _load_state_dict_pre_hook(module, state_dict, prefix, local_metadata, strict, - missing_keys, unexpected_keys, error_msgs) -> None: - A_log_key = prefix + "A_log" - A_minus_key = prefix + "A_minus" - if A_log_key in state_dict: - state_dict[A_minus_key] = -torch.exp(state_dict.pop(A_log_key).float()) - class NemotronHRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): @@ -592,6 +582,13 @@ def __init__(self, config): self.backbone = NemotronHModel(config) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + # Recursively iterate over all modules in self.backbone and list those with A_minus or A_log in their name + self.backbone_modules_with_A = [] + for module_name, module in self.backbone.named_modules(): + for param_name, _ in module.named_parameters(recurse=False): + if param_name in ("A_minus", "A_log"): + self.register_load_state_dict_pre_hook(self._a_log_pre_hook) + self.backbone_modules_with_A.append((module_name, param_name)) # Initialize weights and apply final processing self.post_init() @@ -622,5 +619,23 @@ def forward( return NemotronHCausalLMOutput(logits) + @staticmethod + def _a_log_pre_hook( + module, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) -> None: + all_keys = list(state_dict.keys()) + for key in all_keys: + if "A_log" in key: + A_log_key = key + A_minus_key = key.replace("A_log", "A_minus") + state_dict[A_minus_key] = -torch.exp(state_dict.pop(A_log_key).float()) + AutoModelForCausalLMFactory.register_custom_model_cls("NemotronHConfig", NemotronHForCausalLM) diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py index 2358368e533..c8d95f8233f 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py @@ -896,13 +896,10 @@ def _load_hook( # This is quite a hacky solution. A better solution would be to store extra_state in # the state_dict to identify whether the state_dict is sharded or not. key = prefix + param_key - ad_logger.debug(f"Sharder LOAD hook is called for '{key}'") if key not in state_dict: return p_to_load = state_dict[key] - p_to_load = p_to_load if param_shape == p_to_load.shape else f_split(p_to_load) - state_dict[key] = p_to_load From 7c32860d37f90d4653285669e14c5d5895bfbf7a Mon Sep 17 00:00:00 2001 From: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> Date: Wed, 31 Dec 2025 07:22:26 -0800 Subject: [PATCH 15/18] Updated sharding tests Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> --- .../unit/multigpu/transformations/library/test_tp_sharding.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py index df4a07e34ca..8472d272898 100644 --- a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py +++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py @@ -237,7 +237,7 @@ def _run_sharding_execution_job( ssm_state_size=16, # Scaled from 128 mamba_num_heads=num_heads, mamba_head_dim=num_features // num_heads, # 8 - n_groups=1, # Typical value + n_groups=num_heads, # Typical value chunk_size=256, conv_kernel=4, use_conv_bias=bias, @@ -388,7 +388,7 @@ def _run_pattern_detection_job( ssm_state_size=16, mamba_num_heads=num_heads, mamba_head_dim=num_features // num_heads, - n_groups=1, + n_groups=num_heads, chunk_size=256, conv_kernel=4, use_conv_bias=bias, From 23af0205f978885be37a7e1b0907342dc5d8bfef Mon Sep 17 00:00:00 2001 From: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> Date: Wed, 31 Dec 2025 07:38:44 -0800 Subject: [PATCH 16/18] fixed Mamba tp sharding test Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> --- .../transformations/library/test_tp_sharding.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py index 8472d272898..e226d85c61c 100644 --- a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py +++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py @@ -530,7 +530,13 @@ def _run_pattern_detection_job( else: dim = SplitDimension.COLUMN dist_op = None - fused_weight_dims = (num_features, num_features, 16, 16, num_heads) + fused_weight_dims = ( + num_features, + num_features, + 16 * num_heads, + 16 * num_heads, + num_heads, + ) expected_transformations.append( WeightShardingInfo( target_node=node.name, @@ -551,7 +557,7 @@ def _run_pattern_detection_job( dist_op=None, min_local_shape=1, layer_type=LayerType.SSM, - fused_weight_dims=(num_features, 16, 16), + fused_weight_dims=(num_features, 16 * num_heads, 16 * num_heads), ) ) if is_op(node, torch.ops.auto_deploy.torch_ssm): From f8dc0c2853c59409ec94abeea9e4bec4d2314b61 Mon Sep 17 00:00:00 2001 From: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> Date: Wed, 31 Dec 2025 10:12:31 -0800 Subject: [PATCH 17/18] fixed integration tests Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> --- tests/integration/test_lists/test-db/l0_h100.yml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/integration/test_lists/test-db/l0_h100.yml b/tests/integration/test_lists/test-db/l0_h100.yml index efa4e381291..f29c120bd2b 100644 --- a/tests/integration/test_lists/test-db/l0_h100.yml +++ b/tests/integration/test_lists/test-db/l0_h100.yml @@ -117,8 +117,7 @@ l0_h100: - accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype[True-1] - accuracy/test_llm_api_autodeploy.py::TestNemotronH::test_auto_dtype[False] - accuracy/test_llm_api_autodeploy.py::TestNemotronH::test_auto_dtype[True] - - accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_fp8[1] - - accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_fp8[4] + - accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_fp8 - accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_bf16 - examples/test_ad_speculative_decoding.py::test_autodeploy_spec_dec_output[draft_target] - examples/test_ad_speculative_decoding.py::test_autodeploy_spec_dec_output[eagle3] From 5ac90447752d21a95051891371e83582da944675 Mon Sep 17 00:00:00 2001 From: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> Date: Thu, 1 Jan 2026 06:47:22 -0800 Subject: [PATCH 18/18] Removed integration test update Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com> --- .../integration/defs/accuracy/test_llm_api_autodeploy.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/tests/integration/defs/accuracy/test_llm_api_autodeploy.py b/tests/integration/defs/accuracy/test_llm_api_autodeploy.py index 851f5d628ef..c8adaa96849 100644 --- a/tests/integration/defs/accuracy/test_llm_api_autodeploy.py +++ b/tests/integration/defs/accuracy/test_llm_api_autodeploy.py @@ -218,19 +218,18 @@ def test_bf16(self): task.evaluate(llm) @pytest.mark.skip_less_device_memory(32000) - @pytest.mark.parametrize("world_size", [1, 4]) - def test_fp8(self, world_size): + def test_fp8(self): kwargs = self.get_default_kwargs() + sampling_params = self.get_default_sampling_params() with AutoDeployLLM(model=self.MODEL_PATH_FP8, tokenizer=self.MODEL_PATH_FP8, - world_size=world_size, **kwargs) as llm: # Manually set quant_config for FP8 model to get the accuracy threshold llm.args.quant_config.quant_algo = QuantAlgo.FP8 llm.args.quant_config.kv_cache_quant_algo = QuantAlgo.FP8 - # task = MMLU(self.MODEL_NAME) - # task.evaluate(llm, sampling_params=sampling_params) + task = MMLU(self.MODEL_NAME) + task.evaluate(llm, sampling_params=sampling_params) task = GSM8K(self.MODEL_NAME) task.evaluate(llm)