
[TRTLLM-8201][feat] TP sharding of Mamba layers #8688

Closed
greg-kwasniewski1 wants to merge 22 commits into NVIDIA:main from nv-auto-deploy:gk/sharding_mamba_rebased

Conversation


@greg-kwasniewski1 greg-kwasniewski1 commented Oct 27, 2025

Supports head parallelism for Mamba layers.
Currently, the transform assumes a rigid module structure:

  • exactly one conv1d node
  • exactly one torch_ssm node
  • exactly two split operations (after in_proj and after conv1d)

When applied, it:

  • shards the fused weights of in_proj, conv1d, rmsnorm, and out_proj
  • updates shape parameters for splits, views, convs, etc.

Sharding Mamba layers is currently supported both through sharding heuristics and through factory sharding, by specifying mamba as the mode for the in_proj linear node.
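As a rough illustration of the head-parallel column split of fused weights described above (helper names and shapes here are hypothetical, not the PR's actual API):

```python
# Toy sketch of head-parallel column sharding of a fused projection weight.
# Names and shapes are illustrative only, not TensorRT-LLM's API.

def shard_rows(weight, rank, world_size):
    """Return this rank's contiguous row slice of an [out, in] weight."""
    out_features = len(weight)
    assert out_features % world_size == 0, "heads must divide evenly across ranks"
    local = out_features // world_size
    return weight[rank * local:(rank + 1) * local]

# Fused in_proj weight: 8 output rows (e.g. 4 heads x head_dim 2), 3 inputs.
weight = [[r] * 3 for r in range(8)]

shards = [shard_rows(weight, r, world_size=2) for r in range(2)]
assert [len(s) for s in shards] == [4, 4]
# Concatenating the shards in rank order recovers the original weight.
assert shards[0] + shards[1] == weight
```

Because the split is per head, downstream split/view sizes must shrink by the same factor, which is what the "updates shape parameters" step accounts for.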

Summary by CodeRabbit

  • New Features

    • Added support for SSM-based tensor parallelism sharding workflows
    • Introduced parameter update transforms for dynamic adjustment during model sharding
    • Enhanced sharding configuration with heuristic and factory-based sources
  • Bug Fixes

    • Fixed unpacking of return values in transform pipeline
    • Corrected parameter extraction logic for linear operations
  • Refactor

    • Migrated sharding infrastructure from TP-centric to flexible weight-sharding model
    • Generalized weight handling utilities for broader node support
    • Restructured model patching system for improved modularity

Description

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provides a user-friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

Details

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline, or from the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will always be ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.

kill

kill

Kill all running builds associated with the pull request.

skip

skip --comment COMMENT

Skip testing for the latest commit on the pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

@greg-kwasniewski1 greg-kwasniewski1 requested a review from a team as a code owner October 27, 2025 12:41

coderabbitai bot commented Oct 27, 2025

📝 Walkthrough

Refactored sharding infrastructure from tensor-parallel (TP)-specific constructs to a generalized weight-based sharding system. Introduced new enums (ShardingSource, ShardingDim, LayerType) and types (WeightShardingInfo, ParameterUpdateInfo) for flexible transform composition. Added SSM-based sharding workflow scaffolding, enhanced node utility functions for parametrized node handling, and updated configuration to support heuristic-based sharding with expanded dimensions.
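For orientation, the new enums can be pictured roughly as follows. Member values are inferred from the walkthrough and the default.yaml changes, so treat this as an approximation rather than the exact source:

```python
from enum import Enum

class ShardingSource(Enum):
    HEURISTIC = "heuristic"   # pattern detection on the traced graph
    FACTORY = "factory"       # model-factory-provided TP plan

class ShardingDim(Enum):
    SSM = "ssm"   # Mamba/SSM layers
    TP = "tp"     # classic tensor-parallel linear layers
    EP = "ep"     # expert parallelism (MoE)
    BMM = "bmm"   # batched-matmul sharding

class LayerType(Enum):
    MAMBA = "mamba"

# Membership checks must use enum members; raw strings never match:
dims = [ShardingDim.SSM, ShardingDim.TP]
assert ShardingDim.SSM in dims
assert "ssm" not in dims
```

The last two lines also preview a pitfall flagged later in the review: comparing a `List[ShardingDim]` against plain strings is always False.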

Changes

Cohort / File(s) Change Summary
Configuration
tensorrt_llm/_torch/auto_deploy/config/default.yaml
Replaced TP-fixed sharding semantics with heuristic-based source; expanded sharding_dims to include ['ssm', 'tp', 'ep', 'bmm'] and enabled partial config support.
Model Patches
tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py
Removed MOE forward implementation and patching; introduced _set_sharding_config_patched for head dimension and TP plan configuration; applied sharding config patch to factory.
Transform Interface
tensorrt_llm/_torch/auto_deploy/transform/interface.py
Added __add__ operator to TransformInfo for combining transform fields (logical AND for flags, sum for counts).
Transform Library—MOE
tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py
Adjusted BFS call sites to unpack tuple return values (node, extra) and discard auxiliary values.
Transform Library—Fusion/Quantization
tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py, tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py
Replaced calls to extract_param_names_from_lin_node with extract_param_names_from_node in fusion and quantization parameter extraction.
Transform Library—Sharding
tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
Major refactor: replaced TPShardingInfo with WeightShardingInfo, introduced ShardingSource and ShardingDim enums in config, added post-sharding parameter update phase, refactored _process_simple_shard to return count, added SSM sharding workflow (detect_ssm_shard, _process_ssm_sharding), expanded column-row sharding flow with weight discovery and view/reshape updates.
Node Utilities
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py
Renamed functions to generalize from MM-specific to parametrized nodes: extract_weight_node, num_users_of_weight_node, extract_param_names_from_lin_node → extract_param_names_from_node; enhanced bfs() to return (Node, int) tuple with include_root parameter; added identify_layer_subgraphs(), subgraph(), and draw_graph() utilities; improved predecessors() and successors() to track visited nodes.
Quantization Utilities
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py
Updated import and call site: extract_param_names_from_lin_node → extract_param_names_from_node.
Sharding Utilities
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py
Introduced new types: WeightShardingInfo, ParameterUpdateInfo, LayerType, ShardingSource, ShardingDim enums; added shard_weight_tensor() and get_all_weights_in_subgraph() public APIs; added _update_node_args() and _insert_sharded_mamba(); refactored _validate_sharded_shapes() (formerly _update_view_nodes); updated _resolve_tp_cls_from_node() to return WeightShardingInfo; extended ShardingConfig with new transform lists (weight_sharding_transforms, parameter_update_transforms) and add() method.
Tests
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py
Replaced all TPShardingInfo usages with WeightShardingInfo; updated access from tp_transforms to weight_sharding_transforms.
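The "parameter update" phase listed above (adjusting split sizes and view/reshape shapes after weights shrink) essentially divides the sharded dimension by the world size. A toy version, with a hypothetical helper name rather than the PR's actual _update_node_args:

```python
def shard_shape(shape, dim, world_size):
    """Divide shape[dim] by world_size, e.g. for a view/reshape shape argument."""
    assert shape[dim] % world_size == 0, "sharded dim must divide evenly"
    new_shape = list(shape)
    new_shape[dim] //= world_size
    return tuple(new_shape)

# A view that exposed 8 heads now exposes 4 per rank under TP=2.
assert shard_shape((2, 16, 8, 64), dim=2, world_size=2) == (2, 16, 4, 64)
```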

Sequence Diagram

sequenceDiagram
    participant Config as ShardingConfig
    participant Factory as detect_sharding<br/>from_factory
    participant Heuristic as detect_sharding<br/>from_heuristic
    participant SSM as detect_ssm_shard
    participant Param as Parameter<br/>Updates
    participant Result as weight_sharding_<br/>transforms

    rect rgb(220, 240, 255)
    Note over Config,Result: New Sharding Flow (Post-Refactor)
    end

    Config->>Factory: sharding_source includes FACTORY
    Factory->>Result: populate WeightShardingInfo<br/>(colwise, rowwise, mamba, etc.)

    Config->>Heuristic: sharding_source includes HEURISTIC
    Heuristic->>Result: detect patterns<br/>populate WeightShardingInfo

    Config->>SSM: sharding_dims includes SSM
    SSM->>SSM: identify SSM regions<br/>extract subgraph
    SSM->>Result: create WeightShardingInfo<br/>for SSM patterns

    Factory->>Param: parameter_update_transforms phase
    Heuristic->>Param: parameter_update_transforms phase
    SSM->>Param: parameter_update_transforms phase
    
    Param->>Param: update split sizes,<br/>view/reshape shapes
    Param->>Result: append ParameterUpdateInfo

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

  • Primary areas requiring attention:
    • tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py — substantial refactor introducing SSM workflow, WeightShardingInfo-based flow, and parameter update phase; logic density and new flow paths require careful verification
    • tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py — extensive additions of new classes (enums, types), public APIs for weight sharding, and fused-weight handling; verify public API contracts and backward compatibility
    • tensorrt_llm/_torch/auto_deploy/utils/node_utils.py — generalization of MM-specific utilities to parametrized nodes; BFS signature change from returning single value to tuple; ensure all call sites correctly unpack and handle new return type
    • Cross-file consistency: verify all renamed functions (extract_param_names_from_lin_node → extract_param_names_from_node, etc.) are consistently updated across transform libraries (fusion, quantization, MOE)
    • Configuration compatibility: validate that default.yaml changes align with new ShardingSource/ShardingDim enum values in code

Pre-merge checks and finishing touches

❌ Failed checks (2 warnings)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 56.92% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
Description Check ⚠️ Warning The PR description is incomplete and does not adequately follow the required template. While the author provided an initial paragraph describing what the PR does (supporting tensor-parallel sharding of Mamba layers with assumptions about module structure), the two critical required sections are entirely unfilled: the "Description" section lacks explanation of the issue being solved and the solution approach, and the "Test Coverage" section is empty with no information about what tests validate these changes. Although the opening text provides some technical details, it does not substitute for the structured sections specified in the template that are essential for understanding the motivation and validation of the PR. The author should fill in the "Description" section with a clear explanation of the problem being solved and the solution approach, and should populate the "Test Coverage" section with a list of relevant tests that safeguard these changes. The opening paragraph about Mamba layer sharding is useful but should be expanded with context about why this feature is needed and should be paired with the test information to ensure reviewers understand the scope and validation of the changes.
✅ Passed checks (1 passed)
Check name Status Explanation
Title Check ✅ Passed The PR title "[TRTLLM-8201][feat] TP sharding of Mamba layers" follows the template format with a valid JIRA ticket, proper type designation ([feat]), and a clear summary. The title directly aligns with the primary objective of the changeset, which is to add tensor-parallel sharding support for Mamba layers. The title is concise, specific, and clearly communicates the main change without being vague or misleading.



@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 14

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (8)
tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py (1)

457-462: Fix incorrect unpacking with .args[0] accessor.

Line 457 attempts to unpack the bfs return tuple while simultaneously calling .args[0] on it, which is incorrect. The bfs function returns (Node, int), so calling .args[0] on the tuple itself will raise an AttributeError.

Apply this diff to fix the issue:

-            selected_experts, _ = bfs(
+            selected_experts_node, _ = bfs(
                 common_ancessor2,
                 lambda node: is_op(node, torch.ops.aten.one_hot),
                 attr_next="all_input_nodes",
                 boundary=start_boundary,
-            ).args[0]
+            )
+            selected_experts = selected_experts_node.args[0]

Alternatively, if you prefer a more compact form:

-            selected_experts, _ = bfs(
+            selected_experts = bfs(
                 common_ancessor2,
                 lambda node: is_op(node, torch.ops.aten.one_hot),
                 attr_next="all_input_nodes",
                 boundary=start_boundary,
-            ).args[0]
+            )[0].args[0]
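The contract both fixes rely on — bfs returning a (node, depth) pair instead of a bare node — can be mimicked on a plain adjacency map. This is a hypothetical standalone helper, not the project's node_utils.bfs:

```python
from collections import deque

def bfs(start, target, adj, include_root=True):
    """Return (node, depth) of the first node satisfying `target`.

    Mirrors the (Node, int) tuple contract: callers must unpack,
    e.g. `node, _ = bfs(...)`, rather than using the result directly.
    """
    queue = deque([(start, 0)])
    visited = {start}
    while queue:
        node, depth = queue.popleft()
        if target(node) and (include_root or depth > 0):
            return node, depth
        for nxt in adj.get(node, ()):
            if nxt not in visited:
                visited.add(nxt)
                queue.append((nxt, depth + 1))
    raise RuntimeError("no matching node found")

adj = {"a": ["b", "c"], "b": ["d"], "c": ["d"]}
node, depth = bfs("a", lambda n: n == "d", adj)
assert (node, depth) == ("d", 2)
```

With this shape, chaining an attribute access onto the call (`bfs(...).args[0]`) fails on the tuple, which is exactly the bug flagged above.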
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1)

1-1: Missing NVIDIA Apache-2.0 header (2025).

Per coding guidelines, prepend the NVIDIA Apache-2.0 header to all source files.

Apply:

+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

[Based on coding guidelines]

tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (3)

1-1: Missing NVIDIA Apache-2.0 header (2025).

Same header snippet as in the other file should be prepended. [As per coding guidelines]


683-701: ShardingDim membership checks use strings; should use Enum.

sharding_dims is a List[ShardingDim]. String checks will always be False.

-        if "ep" in sharding_config.sharding_dims:
+        if ShardingDim.EP in sharding_config.sharding_dims:
@@
-        if "bmm" in sharding_config.sharding_dims:
+        if ShardingDim.BMM in sharding_config.sharding_dims:

575-610: Factory config: add support for “mamba” and prefer add() to avoid duplicates.

  • Validation currently rejects “mamba” (see sharding_utils.py).
  • Use sharding_config.add(...) consistently.
-                    sharding_config.weight_sharding_transforms.append(
-                        WeightShardingInfo.from_node(
+                    sharding_config.add(
+                        WeightShardingInfo.from_node(
                             lin_node,
                             split_dim=SplitDimension.COLUMN,
                             rank=rank,
                             world_size=world_size,
                             dist_op=None,
                             min_local_shape=min_local_shape,
                             layer_type=LayerType.MAMBA,
                         )
                     )
-                    num_row_col_shards += 1
+                    num_row_col_shards += 1

And ensure “mamba” is allowed in supported_modes in ShardingConfig.validate_config. See separate comment in sharding_utils.py.

tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (3)

1-1: Missing NVIDIA Apache-2.0 header (2025).

Prepend the standard header as in other files. [As per coding guidelines]


829-837: Return type annotation uses Python 3.9+ generic.

Use typing.Type for Python 3.8.

-def _resolve_tp_cls_from_node(node: Node):
+def _resolve_tp_cls_from_node(node: Node) -> Type[WeightShardingInfo]:

1176-1184: Use typing.Type for Python 3.8.

-def _resolve_ep_cls_from_node(node: Node) -> type[EPShardingInfo]:
+def _resolve_ep_cls_from_node(node: Node) -> Type[EPShardingInfo]:
🧹 Nitpick comments (5)
tensorrt_llm/_torch/auto_deploy/transform/interface.py (1)

176-183: Consider documenting the distinction between __add__ and __and__.

The new __add__ operator has identical semantics to __and__ (both use logical AND for boolean fields). While this may be intentional for semantic clarity—where + suggests sequential accumulation and & suggests logical conjunction—having two operators with identical behavior could be confusing.

Consider one of the following:

  1. Add a docstring explaining when to use + vs &:
 def __add__(self, other: "TransformInfo") -> "TransformInfo":
+    """Combine transform infos sequentially (same semantics as __and__).
+    
+    Use this operator when accumulating results from sequential transforms.
+    Use __and__ (&) when combining parallel/independent transform results.
+    """
     return TransformInfo(
         skipped=self.skipped and other.skipped,
         num_matches=self.num_matches + other.num_matches,
         is_clean=self.is_clean and other.is_clean,
         has_valid_shapes=self.has_valid_shapes and other.has_valid_shapes,
     )
  2. Or differentiate the semantics if sequential composition should be more optimistic:
def __add__(self, other: "TransformInfo") -> "TransformInfo":
    """Accumulate results sequentially (more optimistic than __and__)."""
    return TransformInfo(
        skipped=self.skipped and other.skipped,
        num_matches=self.num_matches + other.num_matches,
        is_clean=self.is_clean or other.is_clean,  # optimistic
        has_valid_shapes=self.has_valid_shapes or other.has_valid_shapes,  # optimistic
    )
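Restating the current semantics from the diff as a runnable sketch (a stand-in dataclass, not the actual TransformInfo):

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class TransformInfo:
    skipped: bool = False
    num_matches: int = 0
    is_clean: bool = True
    has_valid_shapes: bool = True

    def __add__(self, other: "TransformInfo") -> "TransformInfo":
        # Logical AND for boolean flags, sum for match counts.
        return TransformInfo(
            skipped=self.skipped and other.skipped,
            num_matches=self.num_matches + other.num_matches,
            is_clean=self.is_clean and other.is_clean,
            has_valid_shapes=self.has_valid_shapes and other.has_valid_shapes,
        )

combined = TransformInfo(num_matches=2) + TransformInfo(num_matches=3, is_clean=False)
assert combined.num_matches == 5 and not combined.is_clean
```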
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1)

102-104: Prefer deterministic/efficient user selection.

Avoid list(...) slice; use next(iter(...)) and optionally guard type.

-        output_params = modelopt_quant_params.get_quant_params_from_quantize_node(
-            list(linear_op.users.keys())[0]
-        )
+        output_user = next(iter(linear_op.users)) if linear_op.users else None
+        output_params = modelopt_quant_params.get_quant_params_from_quantize_node(output_user)

[Based on static analysis hint RUF015]

tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (1)

381-382: Prefer deterministic user pick for weight users.

Use next(iter(...)) to avoid list(...) and follow Ruff guidance.

-                list(weight_node.users)[0],
+                next(iter(weight_node.users)),

[Based on static analysis hint RUF015]

tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (2)

51-53: Simplify dict access in load hook.

Use get() to avoid double lookup.

-    if key not in state_dict:
-        return
-    p_to_load = state_dict[key]
+    p_to_load = state_dict.get(key)
+    if p_to_load is None:
+        return

[Based on static analysis hint RUF019]


502-510: Debug message uses mutated node.args twice.

The message claims to show original and new args but prints the same; optional nit.

-    ad_logger.debug(
-        f"Updated node {node}: replaced original arguments {node.args} with sharded arguments {args}."
-    )
+    ad_logger.debug(f"Updated node {node} with sharded arguments {args}.")
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 0019d99 and 74a8a10.

📒 Files selected for processing (11)
  • tensorrt_llm/_torch/auto_deploy/config/default.yaml (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py (2 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/interface.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py (3 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py (3 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py (2 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (18 hunks)
  • tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (5 hunks)
  • tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (2 hunks)
  • tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (19 hunks)
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (5 hunks)
🧰 Additional context used
📓 Path-based instructions (3)
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Use only spaces, no tabs; indent with 4 spaces.

Files:

  • tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py
  • tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py
  • tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
  • tensorrt_llm/_torch/auto_deploy/transform/interface.py
  • tensorrt_llm/_torch/auto_deploy/utils/node_utils.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py
  • tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Python code must target Python 3.8+.
Indent Python code with 4 spaces; do not use tabs.
Maintain module namespace when importing; prefer 'from package.subpackage import foo' then 'foo.SomeClass()' instead of importing the class directly.
Python filenames should be snake_case (e.g., some_file.py).
Python classes use PascalCase names.
Functions and methods use snake_case names.
Local variables use snake_case; prefix 'k' for variables that start with a number (e.g., k_99th_percentile).
Global variables use upper SNAKE_CASE prefixed with 'G' (e.g., G_MY_GLOBAL).
Constants use upper SNAKE_CASE (e.g., MY_CONSTANT).
Avoid shadowing variables from an outer scope.
Initialize all externally visible members of a class in the constructor.
Prefer docstrings for interfaces that may be used outside a file; comments for in-function or file-local interfaces.
Use Google-style docstrings for classes and functions (Sphinx-parsable).
Document attributes and variables inline so they render under the class/function docstring.
Avoid reflection when a simpler, explicit approach suffices (e.g., avoid dict(**locals()) patterns).
In try/except, catch the most specific exceptions possible.
For duck-typing try/except, keep the try body minimal and use else for the main logic.

Files:

  • tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py
  • tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py
  • tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
  • tensorrt_llm/_torch/auto_deploy/transform/interface.py
  • tensorrt_llm/_torch/auto_deploy/utils/node_utils.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py
  • tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py
**/*.{cpp,cxx,cc,h,hpp,hh,hxx,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend the NVIDIA Apache-2.0 copyright header with current year to the top of all source files (e.g., .cpp, .h, .cu, .py).

Files:

  • tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py
  • tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py
  • tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
  • tensorrt_llm/_torch/auto_deploy/transform/interface.py
  • tensorrt_llm/_torch/auto_deploy/utils/node_utils.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py
  • tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py
🧠 Learnings (1)
📚 Learning: 2025-10-20T17:07:18.745Z
Learnt from: nvchenghaoz
PR: NVIDIA/TensorRT-LLM#8469
File: tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py:98-116
Timestamp: 2025-10-20T17:07:18.745Z
Learning: In NemotronH models (tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py), the gate (self.gate) returns topk_indices and topk_weights that are already in the correct shape to be passed directly to torch_ops.auto_deploy.torch_moe without needing to reshape them when hidden_states is flattened.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py
🧬 Code graph analysis (8)
tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py (1)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1)
  • extract_param_names_from_node (155-180)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (1)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1)
  • extract_param_names_from_node (155-180)
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (1)
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (1)
  • WeightShardingInfo (567-627)
tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py (1)
tensorrt_llm/_torch/auto_deploy/models/hf.py (1)
  • AutoModelForCausalLMFactory (87-455)
tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (3)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (5)
  • bfs (386-419)
  • filtered_nodes (219-267)
  • is_linear_op (270-280)
  • is_op (193-216)
  • subgraph (534-588)
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (16)
  • BMMShardingInfo (839-935)
  • EPShardingInfo (1096-1119)
  • LayerType (560-564)
  • ParameterUpdateInfo (630-644)
  • ShardingConfig (1203-1301)
  • ShardingDim (1194-1200)
  • ShardingSource (1187-1191)
  • ShardingTransformInfo (524-557)
  • SplitDimension (513-521)
  • WeightShardingInfo (567-627)
  • check_and_apply (548-557)
  • add (1240-1250)
  • from_node (578-583)
  • from_node (1103-1108)
  • validate_config (1252-1298)
  • get_predefined_config (1300-1301)
tensorrt_llm/_torch/auto_deploy/transform/interface.py (2)
  • TransformInfo (121-183)
  • get (528-530)
tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py (1)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1)
  • bfs (386-419)
tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py (1)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (3)
  • extract_param_names_from_node (155-180)
  • is_linear_op (270-280)
  • is_op (193-216)
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (2)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (6)
  • bfs (386-419)
  • extract_param_names_from_node (155-180)
  • is_linear_op (270-280)
  • is_op (193-216)
  • num_users_of_weight_node (149-152)
  • subgraph (534-588)
tensorrt_llm/logger.py (2)
  • debug (144-145)
  • warning (132-133)
🪛 Ruff (0.14.1)
tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py

116-116: Unused function argument: args

(ARG001)


116-116: Unused function argument: kwargs

(ARG001)

tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py

381-381: Prefer next(iter(weight_node.users)) over single element slice

Replace with next(iter(weight_node.users))

(RUF015)


430-430: Unused function argument: gm

(ARG001)


454-454: Unpacked variable depth is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


486-486: Prefer next(iter(node.users)) over single element slice

Replace with next(iter(node.users))

(RUF015)

tensorrt_llm/_torch/auto_deploy/utils/node_utils.py

572-572: Parenthesize a and b expressions when chaining and and or together, to make the precedence clear

Parenthesize the and subexpression

(RUF021)

tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py

68-68: PEP 484 prohibits implicit Optional

Convert to T | None

(RUF013)


83-83: Unpacked variable depth is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


92-92: Unnecessary key check before dictionary access

Replace with dict.get

(RUF019)


234-234: Unused function argument: add_dist

(ARG001)


239-239: Unused function argument: quantization_cb

(ARG001)


265-265: Unpacked variable depth is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


504-504: Unnecessary key check before dictionary access

Replace with dict.get

(RUF019)


638-638: Unused method argument: gm

(ARG002)


642-642: Unused method argument: gm

(ARG002)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (8)
tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py (1)

17-17: LGTM!

The rename from extract_param_names_from_lin_node to extract_param_names_from_node reflects the generalization of parameter extraction to handle parametrized nodes beyond just linear operations. The function signature and usage patterns remain compatible.

Also applies to: 139-139

tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py (1)

16-16: LGTM!

The rename from extract_param_names_from_lin_node to extract_param_names_from_node has been applied consistently across both _insert_fused_gemm and _insert_fused_quant_gemm functions. This aligns with the broader refactor to generalize parameter extraction.

Also applies to: 39-39, 131-131

tensorrt_llm/_torch/auto_deploy/config/default.yaml (1)

77-83: LGTM!

The configuration changes enable the new sharding workflow described in the PR objectives:

  • sharding_source: ['heuristic'] enables automatic sharding detection
  • support_partial_config: true allows partial sharding configurations
  • sharding_dims: ['ssm', 'tp', 'ep', 'bmm'] adds support for SSM (Mamba) layer sharding

These changes align with the PR's goal of adding tensor-parallel sharding support for Mamba layers.

tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py (2)

321-323: LGTM!

Correctly updated to unpack the new bfs return value (Node, int) and discard the depth.


387-387: LGTM!

Correctly updated to unpack the new bfs return value.

tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (1)

18-18: LGTM!

The test file has been consistently updated to reflect the API change from TPShardingInfo to WeightShardingInfo. All constructor calls and field accesses (tp_transforms → weight_sharding_transforms) have been updated appropriately. The test logic and expectations remain unchanged.

Also applies to: 275-283, 296-304, 310-318, 354-354

tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (1)

11-11: LGTM!

The rename from extract_param_names_from_lin_node to extract_param_names_from_node has been applied consistently with the changes in other files. The usage in should_skip_quantization remains compatible.

Also applies to: 120-120

tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (1)

742-752: Quantized TP apply paths OK.

The FP8/FP4 sharding callbacks are correctly threaded into _shard_parameter_node.

Also applies to: 810-820

Comment on lines +116 to +127
def _set_sharding_config_patched(self, *args, **kwargs):
self._sharding_config["head_dim"] = 128
self._sharding_config["tp_plan"] = {
"in_proj": "mamba",
"out_proj": "rowwise",
# "up_proj": "colwise",
# "down_proj": "rowwise",
# "*": "gather",
}


AutoModelForCausalLMFactory._set_sharding_config = _set_sharding_config_patched

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

Consider calling the original implementation or using the model_config.

This patch completely replaces _set_sharding_config without calling the original implementation. The original method in AutoModelForCausalLMFactory._set_sharding_config (tensorrt_llm/_torch/auto_deploy/models/hf.py:86-454) performs several steps: it initializes head_dim, checks for base_model_tp_plan, computes head_dim from model_config, and sets num_hidden_layers. This patch discards all that logic.

Additionally, the unused *args, **kwargs indicate that the model_config parameter is being ignored, and the hardcoded head_dim=128 may not be appropriate for all model variants.

Consider one of the following approaches:

  1. Call the original implementation first, then override specific fields:
+# Capture the original before it is replaced by the patch below; looking it up
+# inside the patched function would find the patch itself and recurse forever.
+_original_set_sharding_config = AutoModelForCausalLMFactory._set_sharding_config
+
 def _set_sharding_config_patched(self, *args, **kwargs):
+    # Call original to get base config
+    _original_set_sharding_config(self, *args, **kwargs)
+    # Then override specific fields for Nemotron-H
     self._sharding_config["head_dim"] = 128
     self._sharding_config["tp_plan"] = {
         "in_proj": "mamba",
         "out_proj": "rowwise",
-        # "up_proj": "colwise",
-        # "down_proj": "rowwise",
-        # "*": "gather",
     }
  1. Or extract and use the model_config parameter:
-def _set_sharding_config_patched(self, *args, **kwargs):
+def _set_sharding_config_patched(self, model_config):
     self._sharding_config["head_dim"] = 128
+    # Validate head_dim matches model_config if available
+    if hasattr(model_config, 'head_dim') and model_config.head_dim is not None:
+        assert model_config.head_dim == 128, f"Expected head_dim=128, got {model_config.head_dim}"
     self._sharding_config["tp_plan"] = {
         "in_proj": "mamba",
         "out_proj": "rowwise",
     }
+    # Preserve num_hidden_layers from original logic
+    if hasattr(model_config, 'num_hidden_layers'):
+        self._sharding_config["num_hidden_layers"] = model_config.num_hidden_layers

As per coding guidelines.


🏁 Script executed:

#!/bin/bash
# Find and examine the nemotron_h.py file
fd "nemotron_h.py" -x cat {}

Length of output: 34816


🏁 Script executed:

#!/bin/bash
# Find and examine the original hf.py file with context around _set_sharding_config
fd "hf.py" -path "*/auto_deploy/models/*" -x cat {}

Length of output: 168


🏁 Script executed:

#!/bin/bash
# Search for _set_sharding_config method definitions in the codebase
rg -n "_set_sharding_config" --type py -B5 -A15

Length of output: 8010


🏁 Script executed:

#!/bin/bash
# Check what uses _sharding_config
rg "_sharding_config" --type py -B3 -A3 | head -100

Length of output: 8577


🏁 Script executed:

#!/bin/bash
# Search for num_hidden_layers usage in the codebase
rg "num_hidden_layers" --type py -B2 -A2 | grep -E "(sharding_config|nemotron)" | head -50

Length of output: 4640


🏁 Script executed:

#!/bin/bash
# Check NemotronHConfig and related model config
rg -n "class NemotronHConfig|mamba_head_dim|hidden_size" tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py -B2 -A2

Length of output: 0


🏁 Script executed:

#!/bin/bash
# Search for uses of sharding_config keys in auto_deploy path
rg "sharding_config\[" tensorrt_llm/_torch/auto_deploy/ --type py -B2 -A2

Length of output: 3926


🏁 Script executed:

#!/bin/bash
# Check if NemotronH model config has mamba_head_dim or head_dim attribute
rg -n "mamba_head_dim|head_dim" tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py -B3 -A3 | head -40

Length of output: 230


🏁 Script executed:

#!/bin/bash
# Check AutoModelForCausalLMFactory class to understand sharding_config initialization
rg -n "class AutoModelForCausalLMFactory" tensorrt_llm/_torch/auto_deploy/models/hf.py -A30

Length of output: 1349


Missing num_hidden_layers configuration and unused model_config parameter.

This patch completely replaces the original _set_sharding_config method, losing critical functionality. The original method (hf.py:214–226) initializes head_dim, retrieves base_model_tp_plan from model config, computes head_dim dynamically if needed, and sets num_hidden_layers—all of which are discarded by this patch.

The unused *args, **kwargs and hardcoded head_dim=128 indicate the model_config parameter is being ignored. Additionally, num_hidden_layers is no longer set, which the original method provides for models that define it.

Recommended fix: Capture the original implementation before patching, then override specific fields. (A zero-argument super() call will not work here because the patched function is defined outside the class body; looking the method up on the class at call time would find the patch itself and recurse.)

_original_set_sharding_config = AutoModelForCausalLMFactory._set_sharding_config

def _set_sharding_config_patched(self, model_config):
    # Call original to preserve base configuration
    _original_set_sharding_config(self, model_config)
    # Override for Nemotron-H
    self._sharding_config["head_dim"] = 128
    self._sharding_config["tp_plan"] = {
        "in_proj": "mamba",
        "out_proj": "rowwise",
    }

This preserves num_hidden_layers while applying Nemotron-H-specific overrides.

🧰 Tools
🪛 Ruff (0.14.1)

116-116: Unused function argument: args

(ARG001)


116-116: Unused function argument: kwargs

(ARG001)

🤖 Prompt for AI Agents
tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py lines 116-127: the
patched _set_sharding_config replaces the original implementation, ignores the
model_config parameter, hardcodes head_dim=128, and drops num_hidden_layers and
any dynamic head_dim logic; fix by accepting the model_config argument, calling
the original _set_sharding_config first to preserve base_model_tp_plan, dynamic
head_dim calculation and num_hidden_layers, then override only the Nemotron-H
specifics (set head_dim to 128 and replace tp_plan entries for in_proj and
out_proj), and remove unused *args/**kwargs or forward them if needed so nothing
from the original configuration is lost.

Comment on lines +281 to +288
split_sizes_0 = split_nodes[0].args[1]
split_sizes_1 = split_nodes[1].args[1]
if split_sizes_0[1] != sum(split_sizes_1):
ad_logger.warning(
f"Split nodes have different sizes. "
f"Skipping Mamba sharding. split_sizes_1={split_sizes_0}, split_sizes_2={split_sizes_1}"
)
return 0

⚠️ Potential issue | 🔴 Critical

split vs split_with_sizes handling is brittle.

Args[1] is an int for aten.split but a list for aten.split_with_sizes. Current code assumes lists and will crash.

-    split_sizes_0 = split_nodes[0].args[1]
-    split_sizes_1 = split_nodes[1].args[1]
-    if split_sizes_0[1] != sum(split_sizes_1):
+    s0, s1 = split_nodes[0], split_nodes[1]
+    sizes0 = s0.args[1]
+    sizes1 = s1.args[1]
+    if not isinstance(sizes0, (list, tuple)) or not isinstance(sizes1, (list, tuple)):
+        ad_logger.warning("Expected split_with_sizes for Mamba. Skipping.")
+        return 0
+    if sizes0[1] != sum(sizes1):
         ad_logger.warning(
-            f"Split nodes have different sizes. "
-            f"Skipping Mamba sharding. split_sizes_1={split_sizes_0}, split_sizes_2={split_sizes_1}"
+            f"Split nodes have different sizes. Skipping Mamba sharding. "
+            f"split_sizes_1={sizes0}, split_sizes_2={sizes1}"
         )
         return 0
-    fused_weight_dims = {
-        "in_proj": split_sizes_0[0:1] + split_sizes_1 + split_sizes_0[2:],
-        "conv1d": split_sizes_1,
+    fused_weight_dims = {
+        "in_proj": sizes0[0:1] + list(sizes1) + sizes0[2:],
+        "conv1d": list(sizes1),
     }

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py around lines
281–288, the code assumes split node arg[1] is always a list (split_with_sizes)
and indexes into it, but aten.split uses an int for arg[1], causing a crash;
update the code to first detect the type of split_sizes_0 and split_sizes_1 (int
vs list), and handle both: if both are lists, keep current sum comparison; if
both are ints, compare them as ints; if one is an int and the other a list,
either normalize by wrapping the int into a sensible list only if you can derive
chunk counts from available tensor metadata, otherwise log a clear warning and
return 0 (skip Mamba sharding) to avoid crashing.
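
The robust handling described above can be sketched as a standalone helper (pure Python; `normalize_split_sizes` and its signature are illustrative, and in the real pass `total_dim` would come from the split node's tensor metadata):

```python
def normalize_split_sizes(arg, total_dim):
    """Return an explicit list of split sizes, or None if it cannot be derived.

    aten.split_with_sizes stores a list in args[1]; aten.split stores an int
    chunk size, where the last chunk may be smaller than the rest.
    """
    if isinstance(arg, (list, tuple)):
        return list(arg)
    if isinstance(arg, int) and arg > 0:
        full, rem = divmod(total_dim, arg)
        return [arg] * full + ([rem] if rem else [])
    return None  # unknown form: caller should warn and skip sharding
```

With both forms normalized to explicit lists, the existing `sizes0[1] != sum(sizes1)` comparison applies uniformly, and a `None` result maps to the warn-and-return-0 path.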

Comment on lines +321 to +336
conv1d_nodes = [
n
for n in subgraph_nodes
if is_op(n, [torch.ops.aten.conv1d, torch.ops.auto_deploy.torch_causal_conv1d])
]
assert len(conv1d_nodes) == 1, "Expecting exactly one conv1d node"
conv1d_node = conv1d_nodes[0]
# conv1d_node last argument is the number of output channels.
# This one is also sharded, so we need to update this parameter
conv_args = list(conv1d_node.args)
conv_args[-1] = conv1d_node.args[-1] // world_size
sharding_config.add(
ParameterUpdateInfo(
rank=rank, world_size=world_size, target_node=conv1d_node.name, args=tuple(conv_args)
)
)

⚠️ Potential issue | 🔴 Critical

Likely wrong conv1d arg update.

aten.conv1d has no “out_channels” arg; last arg is groups. Only update if the op is the custom op with explicit out_channels.

-    # conv1d_node last argument is the number of output channels.
-    # This one is also sharded, so we need to update this parameter
-    conv_args = list(conv1d_node.args)
-    conv_args[-1] = conv1d_node.args[-1] // world_size
-    sharding_config.add(
-        ParameterUpdateInfo(
-            rank=rank, world_size=world_size, target_node=conv1d_node.name, args=tuple(conv_args)
-        )
-    )
+    if is_op(conv1d_node, torch.ops.auto_deploy.torch_causal_conv1d):
+        conv_args = list(conv1d_node.args)
+        conv_args[-1] = conv1d_node.args[-1] // world_size  # out_channels
+        sharding_config.add(
+            ParameterUpdateInfo(
+                rank=rank, world_size=world_size, target_node=conv1d_node.name, args=tuple(conv_args)
+            )
+        )
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py around lines
321 to 336, the code unconditionally treats the conv1d node's last arg as
out_channels and divides it by world_size; but aten.conv1d's last arg is groups,
so only the custom op torch.ops.auto_deploy.torch_causal_conv1d (which encodes
out_channels) should be adjusted. Change the logic to detect which op it is: if
the node is the custom torch.ops.auto_deploy.torch_causal_conv1d then compute
conv_args = list(conv1d_node.args), update conv_args[-1] = conv1d_node.args[-1]
// world_size and add the ParameterUpdateInfo; otherwise do not modify the last
arg (leave aten.conv1d unchanged). Also update the inline comment to reflect
this conditional behavior.
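
A minimal sketch of that conditional, with string targets standing in for the real op objects (the actual check would use `is_op` against `torch.ops.auto_deploy.torch_causal_conv1d`):

```python
# Only the custom causal-conv op carries an explicit out_channels as its last
# positional argument; for aten.conv1d the last positional argument is
# `groups` and must be left alone.
CUSTOM_CONV_TARGETS = {"auto_deploy::torch_causal_conv1d"}  # illustrative ids

def maybe_shard_conv_args(target, args, world_size):
    if target not in CUSTOM_CONV_TARGETS:
        return args  # aten.conv1d: do not touch groups
    new_args = list(args)
    new_args[-1] //= world_size  # shard out_channels
    return tuple(new_args)
```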

Comment on lines +454 to +459
next_lin_node, depth = bfs(linear_nodes[0], is_linear_op, include_root=False)
subgraph_nodes = subgraph(
[linear_nodes],
[next_lin_node],
)


⚠️ Potential issue | 🔴 Critical

Invalid subgraph source type.

Passing [linear_nodes] creates List[List[Node]]; also unhashable in subgraph. Use the first node.

-    subgraph_nodes = subgraph(
-        [linear_nodes],
-        [next_lin_node],
-    )
+    subgraph_nodes = subgraph([linear_nodes[0]], [next_lin_node])
🧰 Tools
🪛 Ruff (0.14.1)

454-454: Unpacked variable depth is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

🤖 Prompt for AI Agents
In tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py around lines
454 to 459, the call to subgraph passes [linear_nodes] which creates a
List[List[Node]] and is unhashable; replace that with the first node
(linear_nodes[0]) and ensure you pass proper node lists (e.g.,
subgraph([linear_nodes[0]], [next_lin_node])) so the subgraph source is a list
of Node objects rather than a nested list.

Comment on lines +733 to +741
ssm_nodes = filtered_nodes(gm.graph.nodes, ops=torch.ops.auto_deploy.torch_ssm)
# ssm_nodes = list(ssm_nodes)[1:2]
num_ssm_shards = 0
for ssm_node in ssm_nodes:
# We assume that one ssm node defines a subgraph corresponding
# to a single Mamba layer.
# Find defining previous (in_proj) and next (out_proj) linear nodes.
in_proj_node, _ = bfs(ssm_node, is_linear_op, attr_next="args", include_root=False)


⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

Operator symbol verification for torch_ssm.

Please confirm the exact op symbol emitted in your FX graphs (e.g., torch.ops.auto_deploy.torch_ssm vs mamba variant). The placeholder in utils uses a different symbol and sets has_ssm=True.


🌐 Web query:

What is the canonical PyTorch custom op name for Mamba SSM used in NVIDIA TensorRT-LLM graphs? (torch_ssm vs torch_ssm_transform)

💡 Result:

The canonical PyTorch custom op name is torch_ssm. [1][2]

Sources:

  • mamba-ssm package (official implementation/exposed ops). [1]
  • NVIDIA TensorRT-LLM functional docs referencing the Mamba SSM functional. [2]


---

🏁 Script executed:

#!/bin/bash
# Search for all torch_ssm variants in the codebase
rg -n "torch_ssm" --type py -C 3

Length of output: 16276


🏁 Script executed:

#!/bin/bash
# Search specifically for torch_ssm_transform (if it exists)
rg -n "torch_ssm_transform" --type py -C 3

Length of output: 2877


🏁 Script executed:

#!/bin/bash
# Find utils files and check for symbol/has_ssm definitions
rg -n "has_ssm|torch_ssm" --type py -C 2 | grep -E "(utils|has_ssm)" -A 2 -B 2

Length of output: 2071


🏁 Script executed:

#!/bin/bash
# Look for auto_deploy module structure and op definitions
fd "utils" --type f --extension py | grep -E "(auto_deploy|utils)" | head -20

Length of output: 776


🏁 Script executed:

#!/bin/bash
# Search for op definitions or registrations in auto_deploy
rg -n "ops.auto_deploy" --type py -C 2

Length of output: 168847


Symbol is correct in sharding.py; fix incomplete validation in sharding_utils.py (lines 276–277).

The torch.ops.auto_deploy.torch_ssm symbol used in sharding.py (line 733) is correct and canonical. However, the validation logic in sharding_utils.py (lines 276–277) is incomplete: the check for SSM nodes is commented out and has_ssm is unconditionally set to True, bypassing detection. Replace the hardcoded assignment with an actual check:

# has_ssm = any(is_op(n, torch.ops.auto_deploy.torch_ssm) for n in subgraph_nodes)
has_ssm = any(is_op(n, torch.ops.auto_deploy.torch_ssm) for n in subgraph_nodes)

Also update the warning message at line 285 to reference torch_ssm instead of the outdated torch_ssm_transform.

🤖 Prompt for AI Agents
In tensorrt_llm/_torch/auto_deploy/transform/library/sharding_utils.py around
lines 276–277 (and update at line 285), replace the hardcoded has_ssm = True
with an actual detection: set has_ssm = any(is_op(n,
torch.ops.auto_deploy.torch_ssm) for n in subgraph_nodes) so SSM nodes are
properly detected, and change the warning message at line 285 to reference
"torch_ssm" (the canonical symbol) instead of the outdated
"torch_ssm_transform".

Comment on lines +67 to 70
def _validate_sharded_shapes(
node: Node, fused_weight_dims: Optional[list] = None, world_size: int = None
) -> None:
"""

⚠️ Potential issue | 🟡 Minor

Typing fixes for Python 3.8 and unused variable.

  • Use Optional[int] for world_size.
  • Prefix unused depth with underscore.
-def _validate_sharded_shapes(
-    node: Node, fused_weight_dims: Optional[list] = None, world_size: int = None
-) -> None:
+def _validate_sharded_shapes(
+    node: Node, fused_weight_dims: Optional[list] = None, world_size: Optional[int] = None
+) -> None:
@@
-    next_lin_node, depth = bfs(node, is_linear_op, include_root=False)
+    next_lin_node, _depth = bfs(node, is_linear_op, include_root=False)

[Based on static analysis hints RUF013, RUF059]

Also applies to: 83-83

🧰 Tools
🪛 Ruff (0.14.1)

68-68: PEP 484 prohibits implicit Optional

Convert to T | None

(RUF013)

🤖 Prompt for AI Agents
In tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py around lines 67 to 70
(and similarly at line 83), the function signature uses world_size: int and
there is an unused variable depth; change the type annotation to world_size:
Optional[int] to be compatible with Python 3.8 typing, and rename the unused
variable to _depth (prefix with an underscore) wherever it appears to silence
unused-variable warnings; update any corresponding type hints/imports if needed.

Comment on lines +124 to +136
def shard_weight_tensor(
gm: GraphModule,
node: Node,
weight_tensor: torch.Tensor,
param_key: str,
dim: int,
rank: int,
world_size: int,
add_dist: bool = False,
min_local_shape: int = 1,
quantization_cb: Optional[
Callable[[GraphModule, nn.Module, Node, str, torch.Size, int, int, int], None]
] = None,
) -> None:
"""Replace the matmul node with a new matmul node that accepts sharded weights.

The state_dict is also updated to contain the sharded weights.
fused_weight_dims: Optional[list] = None,
requires_grad: bool = False,
update_param: bool = True,
) -> Tuple[torch.Tensor, torch.Size]:
"""Shard a weight tensor across ranks and register load hook.

⚠️ Potential issue | 🔴 Critical

Add custom_shard_fn support and wire it.

BMMShardingInfo.apply calls shard_weight_tensor(custom_shard_fn=...), but the function doesn’t accept it. This will raise at runtime.

-def shard_weight_tensor(
+def shard_weight_tensor(
     gm: GraphModule,
     weight_tensor: torch.Tensor,
     param_key: str,
     dim: int,
     rank: int,
     world_size: int,
     min_local_shape: int = 1,
     fused_weight_dims: Optional[list] = None,
     requires_grad: bool = False,
     update_param: bool = True,
-) -> Tuple[torch.Tensor, torch.Size]:
+    custom_shard_fn: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
+) -> Tuple[torch.Tensor, torch.Size]:
@@
-    # Handle fused weights
-    if fused_weight_dims is not None:
+    # Choose split function: custom > fused > simple
+    if custom_shard_fn is not None:
+        f_split = custom_shard_fn
+    elif fused_weight_dims is not None:
@@
-    else:
-        f_split = split_tensor
+    else:
+        f_split = split_tensor
@@
     gm._register_load_state_dict_pre_hook(
         partial(
             _load_hook,
             f_split=f_split,
             param_key=param_key,
             param_shape=sharded_shape,
         )
     )

Also applies to: 173-193, 197-215

🤖 Prompt for AI Agents
In tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py around lines 124-136
(and similarly update the functions at 173-193 and 197-215), the
shard_weight_tensor signature and related helper functions must accept a new
optional custom_shard_fn parameter because callers (e.g., BMMShardingInfo.apply)
pass custom_shard_fn which currently causes a runtime error; add
custom_shard_fn: Optional[Callable[..., torch.Tensor]] = None to the function
signatures, thread that parameter through to where the tensor is sliced/sharded
and to any hook registration code, and if provided call
custom_shard_fn(weight_tensor, dim, rank, world_size, ...) to produce the local
shard (fall back to existing slicing logic when None); ensure type
hints/defaults and any docstrings are updated and tests cover both default and
custom-shard cases.
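
The priority order custom > fused > simple can be sketched with plain Python lists standing in for tensors (`shard_rows` is a hypothetical name; the real code slices `torch.Tensor`s along a configurable `dim`):

```python
def shard_rows(weight, rank, world_size, custom_shard_fn=None, fused_dims=None):
    """Pick this rank's row shard. Priority: custom > fused > even split."""
    if custom_shard_fn is not None:
        return custom_shard_fn(weight, rank, world_size)
    if fused_dims is not None:
        # Shard each fused segment independently, then concatenate local parts.
        out, offset = [], 0
        for d in fused_dims:
            seg = weight[offset:offset + d]
            local = d // world_size
            out.extend(seg[rank * local:(rank + 1) * local])
            offset += d
        return out
    local = len(weight) // world_size
    return weight[rank * local:(rank + 1) * local]
```

The same function would also back the load-state-dict pre-hook, so checkpoint weights are split identically to the in-memory parameters.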

Comment on lines +217 to +225
def get_all_weights_in_subgraph(
sources: list[Node],
sinks: list[Node],
):
"""Get all weight nodes (get_attr nodes) in the subgraph between sources and sinks."""
weight_nodes = subgraph(
sources, sinks, include_boundary_nodes=False, include=lambda n: n.op == "get_attr"
)
return weight_nodes

🛠️ Refactor suggestion | 🟠 Major

Python 3.8 typing and boundary filtering dependency.

Use List[Node] for 3.8; relies on subgraph fix in node_utils.

-def get_all_weights_in_subgraph(
-    sources: list[Node],
-    sinks: list[Node],
-):
+def get_all_weights_in_subgraph(
+    sources: List[Node],
+    sinks: List[Node],
+):

Committable suggestion skipped: line range outside the PR's diff.

Comment on lines +319 to +325
conv1d_node = conv1d_nodes[0]
# conv1d_node last argument is the number of output channels.
# This one is also sharded, so we need to update this parameter
conv_args = list(conv1d_node.args)
conv_args[-1] = conv1d_node.args[-1] // world_size
conv1d_node.args = tuple(conv_args)


⚠️ Potential issue | 🔴 Critical

Avoid changing aten.conv1d args.

Same issue as in library: only update out_channels for custom op; skip for aten.conv1d.

-    conv_args = list(conv1d_node.args)
-    conv_args[-1] = conv1d_node.args[-1] // world_size
-    conv1d_node.args = tuple(conv_args)
+    if is_op(conv1d_node, torch.ops.auto_deploy.torch_causal_conv1d):
+        conv_args = list(conv1d_node.args)
+        conv_args[-1] = conv1d_node.args[-1] // world_size  # out_channels
+        conv1d_node.args = tuple(conv_args)
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py around lines 319 to
325, the code unconditionally modifies conv1d_node.args to divide the last arg
by world_size; change this so we do not alter aten.conv1d arguments: detect the
operator (e.g., skip if conv1d_node.target == "aten::conv1d" or similar) and
only update the out_channels argument for the custom conv1d op(s) that require
sharding; leave aten.conv1d args untouched (no in-place mutation) and ensure any
custom-op path performs the // world_size adjustment as currently implemented.

Comment on lines +1187 to +1201
class ShardingSource(Enum):
"""Enum for sharding source."""

HEURISTIC = "heuristic"
FACTORY = "factory"


class ShardingDim(Enum):
"""Enum for sharding dimension."""

SSM = "ssm"
TP = "tp"
EP = "ep"
BMM = "bmm"


⚠️ Potential issue | 🟠 Major

Config model consistency and validation gaps.

  • sharding_dims should be List[ShardingDim] (Enums), not List[str].
  • Add the “mamba” mode to supported_modes to allow factory-driven Mamba sharding.
  • Prefer add() over direct appends to avoid duplicates (used elsewhere).
-    sharding_dims: List[str] = Field(default_factory=list)
+    sharding_dims: List[ShardingDim] = Field(default_factory=list)
@@
-        supported_modes = {
+        supported_modes = {
             "colwise",
             "rowwise",
             "gather",
+            "mamba",
         }

Also ensure call sites use Enum membership (see sharding.py fix). [Based on learnings]

Also applies to: 1213-1220, 1240-1251

@greg-kwasniewski1 greg-kwasniewski1 self-assigned this Oct 27, 2025
@greg-kwasniewski1 greg-kwasniewski1 marked this pull request as draft October 27, 2025 12:55
@greg-kwasniewski1 greg-kwasniewski1 added the AutoDeploy <NV> AutoDeploy Backend label Oct 27, 2025
@greg-kwasniewski1
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #22647 [ run ] triggered by Bot. Commit: 74a8a10

@tensorrt-cicd
Collaborator

PR_Github #22647 [ run ] completed with state SUCCESS. Commit: 74a8a10
/LLM/main/L0_MergeRequest_PR pipeline #17072 completed with status: 'FAILURE'

Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com>
…cache/NVFP4 + BF16 KV cache (NVIDIA#8405)"

This reverts commit e47c787.

Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com>
@greg-kwasniewski1 greg-kwasniewski1 force-pushed the gk/sharding_mamba_rebased branch from 8d6780c to 8cd6208 Compare October 28, 2025 16:05
Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com>
@greg-kwasniewski1
Copy link
Collaborator Author

/bot run

Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com>
@lucaslie
Member

see #8744

@lucaslie lucaslie closed this Oct 29, 2025
@github-project-automation github-project-automation bot moved this from Backlog to Done in AutoDeploy Board Oct 29, 2025

Labels

AutoDeploy <NV> AutoDeploy Backend

Projects

Archived in project

Development

Successfully merging this pull request may close these issues.

3 participants