
Conversation

bmarimuthu-nv commented Jan 8, 2026

Summary by CodeRabbit

  • New Features

    • Added export support for MiniMax-M2 MoE models with automatic module patching
    • Introduced graph dumping capability for debugging transformed models via environment variable
  • Improvements

    • Enhanced fused-weight detection to handle multiple patterns (split, slice, chunk)
    • Added intermediate attention weight sharding support for distributed inference
    • Improved node renaming with hierarchical module structure information
    • Enhanced error messages with diagnostic context
  • Tests

    • Added validation tests for MiniMax-M2 MoE patch functionality


Description

Fixes #10245

  • Adds MoE patch for Minimax M2 MoE layer
  • Fixes qk_norm weight sharding
    • The layer flow is q/k/v -> qk_norm -> attn -> o_proj, so the qk_norm weights also need to be sharded on the head dim, the same way as the qkv weights (a sketch follows this list).
  • Adds debug utils to dump the graph IR after every transform, controlled by the env var AD_DUMP_GRAPHS_DIR=<dir to dump graphs>
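
A minimal sketch of the head-dim sharding idea (assuming the q_norm/k_norm weight spans num_heads * head_dim; the helper name and TP arguments are illustrative, not the PR's actual code):

import torch

def shard_qk_norm_weight(norm_weight: torch.Tensor, head_dim: int, tp_rank: int, tp_size: int) -> torch.Tensor:
    # Keep only this rank's contiguous block of heads, mirroring how the fused
    # q/k projections are column-sharded.
    num_heads = norm_weight.numel() // head_dim
    heads_per_rank = num_heads // tp_size
    start = tp_rank * heads_per_rank * head_dim
    end = start + heads_per_rank * head_dim
    return norm_weight[start:end].contiguous()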

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user-friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

Details

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline, or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will always be ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purposes. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.
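
For example, a hypothetical combined invocation using the options above:

/bot run --disable-fail-fast --stage-list "A10-PyTorch-1"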

kill

kill

Kill all running builds associated with pull request.

skip

skip --comment COMMENT

Skip testing for latest commit on pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

# check if we need to re-combine outputs
if num_prefill > 0 and num_decode > 0:
-    y = torch.empty_like(q)
+    y = torch.empty(q.shape, dtype=q.dtype, device=q.device)
bmarimuthu-nv (Author) commented:

FakeTensor promises a contiguous tensor, but q can be non-contiguous (depending on the model arch / presence of a transpose) and only becomes contiguous by the end of the layer. Allocating with torch.empty(q.shape, ...) instead of empty_like means we don't inherit q's strides, so y stays contiguous.
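
A minimal illustration of the stride behaviour behind this (plain eager tensors for clarity, not the PR code or FakeTensors):

import torch

# A transposed view is dense but non-contiguous; empty_like preserves its strides,
# while empty(shape, ...) always allocates a fresh contiguous buffer.
q = torch.randn(2, 4, 8).transpose(1, 2)
assert not q.is_contiguous()

y_like = torch.empty_like(q)                                  # inherits q's strides
y_new = torch.empty(q.shape, dtype=q.dtype, device=q.device)  # contiguous
assert not y_like.is_contiguous() and y_new.is_contiguous()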

Member commented:

flashinfer doesn't work with non-contiguous inputs. If q is non-contiguous at this point, we need to make it contiguous before passing it into the kernels!
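
A hypothetical guard along those lines (not part of this PR):

# Force contiguity before handing q to the flashinfer kernels.
if not q.is_contiguous():
    q = q.contiguous()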

assert all(
    s.args[1] == 2 for s in filtered_nodes(linear_node.users, ops=torch.ops.aten.slice)
), "Expecting slice nodes to slice tensor over dim=2"
fused_weight_dims = [s.args[3] - s.args[2] for s in linear_node.users]
bmarimuthu-nv (Author) commented Jan 8, 2026:

Issue: the slice ops are expected to be direct users of linear_node, but the slice_nodes collected above can be anywhere downstream of the linear.

Fix: make the check strict and verify that each slice consumes the linear output directly, as the logic here assumes.
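
A hedged sketch of such a stricter check, assuming the fx layout shown above (not the PR's final fix):

# Only consider slice ops that consume the linear output directly as their
# sliced tensor (args[0]); slices further downstream are ignored.
slice_nodes = [
    u
    for u in linear_node.users
    if u.target == torch.ops.aten.slice.Tensor and u.args[0] is linear_node
]
assert all(s.args[1] == 2 for s in slice_nodes), (
    "Expecting slice nodes to slice the linear output over dim=2"
)
fused_weight_dims = [s.args[3] - s.args[2] for s in slice_nodes]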

canonicalize_graph(gm)


def _rename_nodes_with_module_hierarchy(gm: fx.GraphModule) -> None:
bmarimuthu-nv (Author) commented Jan 8, 2026:

A util to make the graph/IR easier to read: node names reflect the module hierarchy they belong to, e.g. model_layers_0_self_attn_q_norm_to_9. Example IR:

%model_layers_0_self_attn_q_proj_torch_linear_simple = auto_deploy.torch_linear_simple.default(%model_layers_0_input_layernorm_mul_3 : s44xs70x3072 : torch.bfloat16, %model_layers_0_self_attn_q_proj_weight : 6144x3072 : torch.bfloat16, None) : s44xs70x6144 : torch.bfloat16
%model_layers_0_self_attn_k_proj_torch_linear_simple_1 = auto_deploy.torch_linear_simple.default(%model_layers_0_input_layernorm_mul_3 : s44xs70x3072 : torch.bfloat16, %model_layers_0_self_attn_k_proj_weight : 1024x3072 : torch.bfloat16, None) : s44xs70x1024 : torch.bfloat16
%model_layers_0_self_attn_v_proj_torch_linear_simple_2 = auto_deploy.torch_linear_simple.default(%model_layers_0_input_layernorm_mul_3 : s44xs70x3072 : torch.bfloat16, %model_layers_0_self_attn_v_proj_weight : 1024x3072 : torch.bfloat16, None) : s44xs70x1024 : torch.bfloat16
%model_layers_0_self_attn_q_norm_to_9 = aten.to.dtype(%model_layers_0_self_attn_q_proj_torch_linear_simple : s44xs70x6144 : torch.bfloat16, torch.float32) : s44xs70x6144 : torch.float32
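
A rough sketch of how such a renaming pass can work (illustrative only, not the PR's implementation; it assumes torch.export has populated node.meta["nn_module_stack"] with (qualified_name, module_class) pairs):

import torch.fx as fx

def rename_with_module_hierarchy_sketch(gm: fx.GraphModule) -> None:
    for node in gm.graph.nodes:
        if node.op != "call_function":
            continue
        stack = node.meta.get("nn_module_stack")
        if not stack:
            continue
        # The last entry is the innermost owning module, e.g. "model.layers.0.self_attn.q_proj".
        qualname, _cls = list(stack.values())[-1]
        prefix = qualname.replace(".", "_")
        # Appending the original (unique) name keeps node names unique.
        node.name = f"{prefix}_{node.name}"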

Member commented:

do we think this could get too verbose? Otherwise, I think it's a great idea

self._transform_counter = 0
self._dump_dir_initialized = False

def dump_graph(self, mod: nn.Module, transform_name: str, stage: str) -> None:
bmarimuthu-nv (Author) commented Jan 8, 2026:

Dumps an SSA-style IR with shape and dtype annotations (the default torch graph printout doesn't show them).
Example:

%model_layers_0_self_attn_q_proj_torch_linear_simple = auto_deploy.torch_linear_simple.default(%model_layers_0_input_layernorm_mul_3 : s44xs70x3072 : torch.bfloat16, %model_layers_0_self_attn_q_proj_weight : 6144x3072 : torch.bfloat16, None) : s44xs70x6144 : torch.bfloat16
%model_layers_0_self_attn_k_proj_torch_linear_simple_1 = auto_deploy.torch_linear_simple.default(%model_layers_0_input_layernorm_mul_3 : s44xs70x3072 : torch.bfloat16, %model_layers_0_self_attn_k_proj_weight : 1024x3072 : torch.bfloat16, None) : s44xs70x1024 : torch.bfloat16
%model_layers_0_self_attn_v_proj_torch_linear_simple_2 = auto_deploy.torch_linear_simple.default(%model_layers_0_input_layernorm_mul_3 : s44xs70x3072 : torch.bfloat16, %model_layers_0_self_attn_v_proj_weight : 1024x3072 : torch.bfloat16, None) : s44xs70x1024 : torch.bfloat16
%model_layers_0_self_attn_q_norm_to_9 = aten.to.dtype(%model_layers_0_self_attn_q_proj_torch_linear_simple : s44xs70x6144 : torch.bfloat16, torch.float32) : s44xs70x6144 : torch.float32
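
A rough sketch of producing such a dump from node.meta["val"] (illustrative; not the actual dump_ssa_with_meta() implementation):

import torch.fx as fx

def _annot(val) -> str:
    # Shape/dtype suffix taken from a node's FakeTensor metadata,
    # e.g. " : s44xs70x6144 : torch.bfloat16".
    if hasattr(val, "shape") and hasattr(val, "dtype"):
        return f" : {'x'.join(str(d) for d in val.shape)} : {val.dtype}"
    return ""

def _fmt_arg(arg) -> str:
    if isinstance(arg, fx.Node):
        return f"%{arg.name}{_annot(arg.meta.get('val'))}"
    return repr(arg)

def dump_ssa_with_meta_sketch(gm: fx.GraphModule) -> str:
    lines = []
    for node in gm.graph.nodes:
        if node.op != "call_function":
            continue
        args = ", ".join(_fmt_arg(a) for a in node.args)
        lines.append(f"%{node.name} = {node.target}({args}){_annot(node.meta.get('val'))}")
    return "\n".join(lines)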

bmarimuthu-nv changed the title from "[AutoDeploy] Minimax M2 support" to "##10245 [feat] AutoDeploy: Add Minimax M2 support" on Jan 8, 2026
bmarimuthu-nv changed the title from "##10245 [feat] AutoDeploy: Add Minimax M2 support" to "#10245 [feat] AutoDeploy: Add Minimax M2 support" on Jan 8, 2026
bmarimuthu-nv changed the title from "#10245 [feat] AutoDeploy: Add Minimax M2 support" to "[#10245][feat] AutoDeploy: Add Minimax M2 support" on Jan 8, 2026
bmarimuthu-nv (Author) commented:

@coderabbitai summary

coderabbitai bot commented Jan 8, 2026

✅ Actions performed

Summary regeneration triggered.

coderabbitai bot commented Jan 8, 2026

📝 Walkthrough

This PR introduces Torch export infrastructure for MiniMax-M2 MoE models by adding hierarchical node renaming, module-level dynamic patching, graph dumping for debugging, and enhanced weight sharding detection. It includes corresponding unit tests for MoE patch validation.

Changes

Cohort / File(s) — Summary:

  • Export and Node Hierarchy — tensorrt_llm/_torch/auto_deploy/export/export.py
    Added a _rename_nodes_with_module_hierarchy() helper to construct hierarchical node names from nn_module_stack metadata; the helper is invoked at multiple sites in the export flow. Note: duplicate function definition detected in the same file. Added an import of unique from enum.
  • Model Patching Infrastructure — tensorrt_llm/_torch/auto_deploy/models/patches/minimax_m2.py
    Introduced a runtime patching mechanism for MiniMax-M2 MoE export compatibility. Added a minimax_m2_moe() forward implementation with torch-export-friendly routing (sigmoid-based weights, top-k expert selection, torch.moe invocation; see the sketch below). The global patch is applied via a get_model_from_config_patched() wrapper around AutoModelForCausalLM.from_config.
  • Transform and Utility Enhancements — tensorrt_llm/_torch/auto_deploy/transform/interface.py, tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py, tensorrt_llm/_torch/auto_deploy/utils/logger.py
    Added graph dumping via ad_logger.dump_graph() in the transform interface. Expanded fused-weight detection to handle split/slice/chunk patterns. Introduced a _shard_intermediate_attention_weights() helper for column-wise sharding of intermediate weights (q_norm, k_norm). New logging infrastructure: dump_ssa_with_meta(), a dtype extraction helper, and environment-driven graph dump configuration.
  • Diagnostic and Testing — tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py, tensorrt_llm/_torch/auto_deploy/utils/node_utils.py, tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_minimax_m2_patches.py
    Refactored tensor creation in flashinfer attention from empty_like() to an explicit empty() call. Enhanced an assertion message in node utilities with diagnostic context. Added a new test module validating MiniMax-M2 MoE patch functional equivalence against the original HuggingFace implementation.
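
A rough sketch of the export-friendly routing step described above (the signature is illustrative, re-normalization of the top-k weights is omitted, and the real patch hands the result to the torch.moe custom op):

import torch
import torch.nn.functional as F

def moe_routing_sketch(hidden_states: torch.Tensor, router_weight: torch.Tensor, top_k: int):
    # hidden_states: (batch, seq, hidden); router_weight: (num_experts, hidden).
    bsz, seq_len, hidden_dim = hidden_states.shape
    x = hidden_states.reshape(-1, hidden_dim)                # flatten tokens
    router_logits = F.linear(x, router_weight)               # (tokens, num_experts)
    routing_weights = torch.sigmoid(router_logits.float())   # sigmoid, not softmax
    top_k_weights, top_k_ids = torch.topk(routing_weights, top_k, dim=-1)
    return x, top_k_weights, top_k_ids, router_logits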

Sequence Diagrams

sequenceDiagram
    participant App as Application
    participant Config as Config/Model
    participant Patch as Patcher
    participant MoE as MoE Module
    participant Experts as Expert Weights
    participant Output as Output

    App->>Config: AutoModelForCausalLM.from_config()
    Config->>Patch: get_model_from_config_patched()
    Patch->>Config: _from_config_previous()
    Config-->>Patch: model instance
    Patch->>MoE: iterate modules & patch forward
    Patch-->>App: patched model
    App->>MoE: forward(hidden_states)
    MoE->>MoE: flatten input, compute router_logits
    MoE->>MoE: routing_weights = sigmoid(router_logits)
    MoE->>MoE: select top_k_experts by scores
    MoE->>Experts: extract expert weights
    MoE->>MoE: torch.moe(flattened, top_k_weights, experts)
    MoE->>Output: reshape & return (hidden_states, router_logits)
sequenceDiagram
    participant Export as Export Flow
    participant GraphMod as GraphModule
    participant Rename as Node Renamer
    participant Logger as Graph Logger
    participant Dump as Dump File

    Export->>GraphMod: transform graph
    GraphMod->>Rename: _rename_nodes_with_module_hierarchy()
    Rename->>Rename: iterate call_function nodes
    Rename->>Rename: extract nn_module_stack metadata
    Rename->>Rename: construct hierarchical name
    Rename-->>GraphMod: update node.name
    GraphMod-->>Logger: transformed module
    Logger->>Logger: check AD_DUMP_GRAPHS_DIR env
    Logger->>Logger: dump_ssa_with_meta(mod)
    Logger->>Dump: write SSA IR with dtype annotations
    Dump-->>Export: debug artifacts

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes


📜 Recent review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 7187afe and b94e609.

📒 Files selected for processing (8)
  • tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py
  • tensorrt_llm/_torch/auto_deploy/export/export.py
  • tensorrt_llm/_torch/auto_deploy/models/patches/minimax_m2.py
  • tensorrt_llm/_torch/auto_deploy/transform/interface.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
  • tensorrt_llm/_torch/auto_deploy/utils/logger.py
  • tensorrt_llm/_torch/auto_deploy/utils/node_utils.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_minimax_m2_patches.py


bmarimuthu-nv (Author) commented:

/bot run

Signed-off-by: Balamurugan Marimuthu <[email protected]>
bmarimuthu-nv (Author) commented:

/bot run

Signed-off-by: Balamurugan Marimuthu <[email protected]>


Comment on lines +1906 to +1910
def _shard_intermediate_attention_weights(
    layer_subgraph: LayerSubgraph,
    linear_nodes: List[Node],
    transform_container: ShardingTransformContainer,
) -> int:
Member commented:

can we make this more conservative? Let's just shard weights associated with torch_rmsnorm
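
A hedged sketch of the more conservative variant suggested here (the string-matching heuristic and names are illustrative, not the PR's API):

import torch.fx as fx

def collect_rmsnorm_weights(intermediate_nodes) -> list:
    # Among the nodes between the q/k/v projections and the attention op, keep
    # only parameters that feed an rmsnorm call; other intermediate weights are
    # left unsharded.
    targets = []
    for node in intermediate_nodes:
        if node.op != "call_function" or "rmsnorm" not in str(node.target):
            continue
        for arg in node.args:
            if isinstance(arg, fx.Node) and arg.op == "get_attr":
                targets.append(arg.target)  # e.g. "model.layers.0.self_attn.q_norm.weight"
    return targets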




Development

Successfully merging this pull request may close these issues.

[Feature]: Support Minimax model in AutoDeploy

2 participants