
[#9236][feature] Make sharing of activation_type across SW layers more robust #9238

Merged
nzmora-nvidia merged 6 commits into NVIDIA:main from nzmora-nvidia:users/nzmora/robust_moe_activation
Nov 20, 2025

Conversation

@nzmora-nvidia
Collaborator

@nzmora-nvidia nzmora-nvidia commented Nov 17, 2025

The C++ kernels, the Python utilities, and the Python MoE layer all share the definition of ActivationType.
Currently this sharing is done through redefinition, which is fragile and can break when new activation function types are added.

tensorrt_llm/_torch/utils.py
cpp/tensorrt_llm/kernels/cutlass_kernels/include/common.h
=>
tensorrt_llm/layers/moe.py
cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu
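
To illustrate why redefinition is fragile, here is a minimal Python sketch. It is not code from the repository: the enum is a reduced stand-in, and `stale_map` and `find_drift` are hypothetical names. A hand-maintained copy of the numbers silently goes stale when the enum is renumbered, and a small check can surface the mismatch.

```python
from enum import IntEnum


class ActivationType(IntEnum):
    """Illustrative subset; values follow the renumbering described in this PR."""
    InvalidType = 0
    Gelu = 1
    Relu = 2


# Hypothetical hand-maintained copy that still uses the old numbering.
stale_map = {"gelu": 0, "relu": 1}


def find_drift(mapping, enum_cls):
    """Return the names whose integer no longer matches the enum member of the same name."""
    by_name = {member.name.lower(): int(member) for member in enum_cls}
    return sorted(name for name, value in mapping.items() if by_name.get(name) != value)
```

Calling `find_drift(stale_map, ActivationType)` reports both entries as out of date, which is the kind of breakage that deriving the integers from the enum avoids.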

Summary by CodeRabbit

  • New Features

    • Introduced Relu2 as a new activation function option for Mixture of Experts (MOE) layer configurations, expanding available activation choices.
  • Tests

    • Re-enabled multiple integration tests that were previously skipped, broadening automated test coverage and validation scope.
  • Chores

    • Refactored internal activation type value assignments and kernel dispatch mechanisms to ensure consistency across the framework.

Description

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user-friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

Details

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will always be ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purposes. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.

kill

kill

Kill all running builds associated with pull request.

skip

skip --comment COMMENT

Skip testing for latest commit on pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

@coderabbitai
Contributor

coderabbitai bot commented Nov 17, 2025

📝 Walkthrough

Walkthrough

The changes reorder and reindex the ActivationType enum across C++ and Python, moving InvalidType to position 0 and renumbering all activation types accordingly. The MOE kernel activation dispatch in C++ is refactored from a function-pointer array lookup to explicit switch-case logic. Python activation mappings are updated to use enum-based values, and a new "relu2" mapping is added. Test skip entries are removed to enable those tests.

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| ActivationType enum reordering: cpp/tensorrt_llm/kernels/cutlass_kernels/include/common.h, tensorrt_llm/_torch/utils.py | Reindexed ActivationType enum: moved InvalidType from position 8 to 0, shifted all activation types (Gelu, Relu, Silu, Swiglu, Geglu, SwigluBias, Identity, Relu2) to positions 1–8. Changes numeric representation while preserving semantic order. |
| MOE kernel activation dispatch refactoring: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu | Replaced function-pointer array-based dispatch with explicit switch-case logic to select doActivationKernel specializations. Lambda now has explicit return type and includes default case with runtime validation. |
| Activation mapping enum adoption: tensorrt_llm/layers/moe.py | Updated activation_str_to_int_map to use ActivationType enum values instead of hardcoded integers. Added new "relu2" activation mapping entry. |
| Test enabling: tests/integration/test_lists/waives.txt | Removed multiple SKIP entries to re-enable previously skipped tests. |
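
The renumbering summarized above can be written out as a Python enum. This is a reconstruction from the summary only, not a copy of tensorrt_llm/_torch/utils.py; in particular, the use of `IntEnum` as the base class is an assumption.

```python
from enum import IntEnum


class ActivationType(IntEnum):
    """Sketch of the new numbering: InvalidType first, the rest at 1 through 8."""
    InvalidType = 0
    Gelu = 1
    Relu = 2
    Silu = 3
    Swiglu = 4
    Geglu = 5
    SwigluBias = 6
    Identity = 7
    Relu2 = 8
```

Placing InvalidType at 0 means new activation types can be appended at the end without moving the sentinel value again.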

Sequence Diagram

sequenceDiagram
    participant Caller as Caller Code
    participant OldDispatch as Old: Array Lookup
    participant NewDispatch as New: Switch-Case
    participant Kernel as doActivationKernel

    Note over OldDispatch: Previous Implementation
    Caller->>OldDispatch: fn_list[activation_type]()
    OldDispatch->>Kernel: Function pointer from array
    Kernel-->>Caller: Result

    Note over NewDispatch: New Implementation
    Caller->>NewDispatch: switch(activation_type)
    NewDispatch->>Kernel: Case branch returns<br/>specialized kernel
    Kernel-->>Caller: Result

    Note over NewDispatch: Invalid case
    NewDispatch->>NewDispatch: default: assert/error

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

  • Special attention areas:
    • Verify ActivationType enum value changes propagate correctly across all dependent code paths (C++ and Python layers).
    • Review the switch-case dispatch logic in moe_kernels.cu to ensure all activation types are properly mapped and no cases are missed.
    • Confirm that the enum reindexing does not break serialization, deserialization, or cross-module communication where numeric values are persisted.
    • Validate that the removed test skips will pass with the enum changes in place.
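
One way to address the persistence concern in the list above is to store the member name rather than its number. The sketch below uses hypothetical helper names (`dump_config`, `load_config`) and a reduced enum; it does not describe how TensorRT-LLM serializes configurations.

```python
import json
from enum import IntEnum


class ActivationType(IntEnum):
    """Reduced, illustrative subset of the enum."""
    InvalidType = 0
    Swiglu = 4
    Relu2 = 8


def dump_config(activation: ActivationType) -> str:
    # Persist the symbolic name, which survives renumbering of the enum.
    return json.dumps({"activation": activation.name})


def load_config(text: str) -> ActivationType:
    # Look the member up by name; an unknown name raises KeyError.
    return ActivationType[json.loads(text)["activation"]]
```

A value written as "Swiglu" before the reindexing still loads as ActivationType.Swiglu afterwards, whereas a stored raw integer would now refer to a different member.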

Possibly related issues

Pre-merge checks and finishing touches

❌ Failed checks (2 warnings)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 50.00% which is insufficient. The required threshold is 80.00%. | You can run @coderabbitai generate docstrings to improve docstring coverage. |
| Description check | ⚠️ Warning | PR description is minimal and incomplete. Required sections missing: no formal PR title, empty Description and Test Coverage sections, and PR Checklist incomplete. | Provide a proper PR title following [type] format, fill Description section explaining the issue and solution, document relevant test coverage, and complete the PR Checklist items. |

✅ Passed checks (1 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Title check | ✅ Passed | The title clearly and specifically identifies the main change: making activation_type sharing across software layers more robust, which aligns with the actual changeset involving ActivationType enum reordering and refactoring. |

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 2d6289b and 1617395.

📒 Files selected for processing (5)
  • cpp/tensorrt_llm/kernels/cutlass_kernels/include/common.h (1 hunks)
  • cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu (1 hunks)
  • tensorrt_llm/_torch/utils.py (1 hunks)
  • tensorrt_llm/layers/moe.py (2 hunks)
  • tests/integration/test_lists/waives.txt (0 hunks)
💤 Files with no reviewable changes (1)
  • tests/integration/test_lists/waives.txt
🧰 Additional context used
🧠 Learnings (17)
📓 Common learnings
Learnt from: djns99
Repo: NVIDIA/TensorRT-LLM PR: 6915
File: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu:4010-4012
Timestamp: 2025-08-14T23:23:27.449Z
Learning: For MOE (Mixture of Experts) code reviews in TensorRT-LLM, avoid repeatedly suggesting finalize fusion validation checks and safety assertions. The user djns99 has indicated these suggestions are repetitive and unwanted across multiple MOE-related changes.
📚 Learning: 2025-08-27T14:23:55.566Z
Learnt from: ixlmar
Repo: NVIDIA/TensorRT-LLM PR: 7294
File: tensorrt_llm/_torch/modules/rms_norm.py:17-17
Timestamp: 2025-08-27T14:23:55.566Z
Learning: The TensorRT-LLM project requires Python 3.10+ as evidenced by the use of TypeAlias from typing module, match/case statements, and union type | syntax throughout the codebase, despite some documentation still mentioning Python 3.8+.

Applied to files:

  • tensorrt_llm/layers/moe.py
📚 Learning: 2025-08-14T15:38:01.771Z
Learnt from: MatthiasKohl
Repo: NVIDIA/TensorRT-LLM PR: 6904
File: cpp/tensorrt_llm/pybind/thop/bindings.cpp:55-57
Timestamp: 2025-08-14T15:38:01.771Z
Learning: In TensorRT-LLM Python bindings, tensor parameter collections like mla_tensor_params and spec_decoding_tensor_params are kept as required parameters without defaults to maintain API consistency, even when it might affect backward compatibility.

Applied to files:

  • tensorrt_llm/layers/moe.py
📚 Learning: 2025-08-06T13:58:07.506Z
Learnt from: galagam
Repo: NVIDIA/TensorRT-LLM PR: 6487
File: tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py:1-12
Timestamp: 2025-08-06T13:58:07.506Z
Learning: In TensorRT-LLM, test files (files under tests/ directories) do not require NVIDIA copyright headers, unlike production source code files. Test files typically start directly with imports, docstrings, or code.

Applied to files:

  • tensorrt_llm/layers/moe.py
📚 Learning: 2025-07-28T17:06:08.621Z
Learnt from: moraxu
Repo: NVIDIA/TensorRT-LLM PR: 6303
File: tests/integration/test_lists/qa/examples_test_list.txt:494-494
Timestamp: 2025-07-28T17:06:08.621Z
Learning: In TensorRT-LLM testing, it's common to have both CLI flow tests (test_cli_flow.py) and PyTorch API tests (test_llm_api_pytorch.py) for the same model. These serve different purposes: CLI flow tests validate the traditional command-line workflow, while PyTorch API tests validate the newer LLM API backend. Both are legitimate and should coexist.

Applied to files:

  • tensorrt_llm/layers/moe.py
📚 Learning: 2025-08-09T02:04:49.623Z
Learnt from: Fridah-nv
Repo: NVIDIA/TensorRT-LLM PR: 6760
File: tensorrt_llm/_torch/auto_deploy/models/quant_config_reader.py:81-98
Timestamp: 2025-08-09T02:04:49.623Z
Learning: In TensorRT-LLM's auto_deploy module, torch.dtype values in configuration dictionaries must be stored as string representations (e.g., "float16" instead of torch.float16) because OmegaConf.merge does not support torch.dtype types. These string representations are converted to actual torch.dtype objects in downstream code.

Applied to files:

  • tensorrt_llm/layers/moe.py
📚 Learning: 2025-08-19T12:45:11.997Z
Learnt from: amitz-nv
Repo: NVIDIA/TensorRT-LLM PR: 7033
File: tensorrt_llm/_torch/pyexecutor/model_engine.py:0-0
Timestamp: 2025-08-19T12:45:11.997Z
Learning: In tensorrt_llm/_torch/pyexecutor/model_engine.py, DoRA (Delta Orthogonal Rank Adaptation) functionality was removed from the PyTorch flow to eliminate issues with inverted DoRA detection logic. The original is_dora condition was checking if scaling_vec_pointer == 0, which was potentially incorrect.

Applied to files:

  • tensorrt_llm/layers/moe.py
📚 Learning: 2025-09-09T09:40:45.658Z
Learnt from: fredricz-20070104
Repo: NVIDIA/TensorRT-LLM PR: 7645
File: tests/integration/test_lists/qa/llm_function_core.txt:648-648
Timestamp: 2025-09-09T09:40:45.658Z
Learning: In TensorRT-LLM test lists, it's common and intentional for the same test to appear in multiple test list files when they serve different purposes (e.g., llm_function_core.txt for comprehensive core functionality testing and llm_function_core_sanity.txt for quick sanity checks). This duplication allows tests to be run in different testing contexts.

Applied to files:

  • tensorrt_llm/layers/moe.py
📚 Learning: 2025-08-26T09:37:10.463Z
Learnt from: jiaganc
Repo: NVIDIA/TensorRT-LLM PR: 7031
File: tensorrt_llm/bench/dataclasses/configuration.py:90-104
Timestamp: 2025-08-26T09:37:10.463Z
Learning: In TensorRT-LLM, the `get_pytorch_perf_config()` method returns `self.pytorch_config` which can contain default `cuda_graph_config` values, so `llm_args` may already have this config before the extra options processing.

Applied to files:

  • tensorrt_llm/layers/moe.py
📚 Learning: 2025-08-14T15:43:23.107Z
Learnt from: MatthiasKohl
Repo: NVIDIA/TensorRT-LLM PR: 6904
File: tensorrt_llm/_torch/attention_backend/trtllm.py:259-262
Timestamp: 2025-08-14T15:43:23.107Z
Learning: In TensorRT-LLM's attention backend, tensor parameters in the plan() method are assigned directly without validation (dtype, device, contiguity checks). This maintains consistency across all tensor inputs and follows the pattern of trusting callers to provide correctly formatted tensors.

Applied to files:

  • tensorrt_llm/layers/moe.py
📚 Learning: 2025-08-21T02:39:12.009Z
Learnt from: djns99
Repo: NVIDIA/TensorRT-LLM PR: 7104
File: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu:1475-1480
Timestamp: 2025-08-21T02:39:12.009Z
Learning: The min latency mode functionality in TensorRT-LLM MOE kernels (cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu) is deprecated and no longer being maintained/updated, as confirmed by djns99. Bug reports and optimization suggestions for the computeStridesTmaWarpSpecializedLowLatencyKernel and related min latency code paths should be deprioritized.

Applied to files:

  • cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu
📚 Learning: 2025-08-08T22:03:40.707Z
Learnt from: sklevtsov-nvidia
Repo: NVIDIA/TensorRT-LLM PR: 3294
File: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu:1198-1209
Timestamp: 2025-08-08T22:03:40.707Z
Learning: In the CUTLASS MoE kernels (cpp/tensorrt_llm/cutlass_extensions), when `layout_info.fusion` is set to `TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE`, the `router_scales` parameter must be non-null by design. The fused finalize kernel epilogue does not perform nullptr checks and requires valid router scales to function correctly. This is an implicit contract that callers must satisfy when enabling the FINALIZE fusion mode.

Applied to files:

  • cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu
📚 Learning: 2025-08-09T20:57:04.084Z
Learnt from: sklevtsov-nvidia
Repo: NVIDIA/TensorRT-LLM PR: 3294
File: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu:118-127
Timestamp: 2025-08-09T20:57:04.084Z
Learning: In the CUTLASS MoE finalize fusion implementation (cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu), when setting `fused_finalize_epilogue.stride_final_output` with shape `(hidden_size, num_output_tokens, 1)`, the `num_rows_in_final_output` should be set to `num_output_tokens` (not `hidden_size`) because of a swap+transpose operation that maps rows of the output tensor to `hidden_size` and columns to `num_output_tokens`.

Applied to files:

  • cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu
📚 Learning: 2025-08-19T03:35:20.866Z
Learnt from: djns99
Repo: NVIDIA/TensorRT-LLM PR: 6915
File: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu:4616-4626
Timestamp: 2025-08-19T03:35:20.866Z
Learning: In the MOE profiler TMA workspace preparation (cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu), the overlapping of TMA WS regions for NONE and FINALIZE variants is deliberate design to save memory space, as confirmed by djns99. The comment "reuse the same pointers to save space" reflects this intentional behavior.

Applied to files:

  • cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu
📚 Learning: 2025-08-21T02:41:10.565Z
Learnt from: djns99
Repo: NVIDIA/TensorRT-LLM PR: 7104
File: cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h:141-145
Timestamp: 2025-08-21T02:41:10.565Z
Learning: In TensorRT-LLM MOE GEMM kernels (cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h), the stride_act and stride_weight pointers in TmaWarpSpecializedGroupedGemmInput are intentionally declared as void* rather than typed pointers because the actual stride type is determined at runtime based on factors like the swap_ab flag and layout decisions. This runtime type determination makes compile-time type safety impossible, so void* is the correct approach.

Applied to files:

  • cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu
📚 Learning: 2025-08-08T05:06:31.596Z
Learnt from: sklevtsov-nvidia
Repo: NVIDIA/TensorRT-LLM PR: 3294
File: cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/fusion/sm90_visitor_scatter.hpp:36-36
Timestamp: 2025-08-08T05:06:31.596Z
Learning: CUTLASS extension files (under cpp/tensorrt_llm/cutlass_extensions/) follow CUTLASS coding style conventions, including using #pragma once instead of TRTLLM_ prefixed header guards, even though they are .hpp files.

Applied to files:

  • cpp/tensorrt_llm/kernels/cutlass_kernels/include/common.h
📚 Learning: 2025-08-22T01:54:35.850Z
Learnt from: djns99
Repo: NVIDIA/TensorRT-LLM PR: 7104
File: cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h:999-1000
Timestamp: 2025-08-22T01:54:35.850Z
Learning: The `internal_cutlass_kernels` directory in TensorRT-LLM is a mirror of an internal NVIDIA repository and maintains its own implementation and API that may diverge from the public `cutlass_kernels` version. API inconsistencies between these two directories are intentional and by design, not bugs to be fixed.

Applied to files:

  • cpp/tensorrt_llm/kernels/cutlass_kernels/include/common.h
🧬 Code graph analysis (1)
tensorrt_llm/layers/moe.py (2)
cpp/tensorrt_llm/kernels/cutlass_kernels/include/common.h (2)
  • tensorrt_llm (19-37)
  • ActivationType (24-37)
tensorrt_llm/_torch/utils.py (1)
  • ActivationType (36-45)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (5)
tensorrt_llm/layers/moe.py (2)

23-23: LGTM: Centralized enum import improves maintainability.

Importing ActivationType from the centralized utils module enables enum-based activation mappings instead of hardcoded integers, making the code more robust when new activation types are added.


51-62: LGTM: Enum-based mapping with new relu2 support.

The activation mapping now uses int(ActivationType.*) values instead of hardcoded integers, making it robust to enum reorderings. The addition of "relu2" expands the supported activation functions.
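
A minimal sketch of such an enum-based mapping follows. The enum values follow the renumbering described in this PR; the set of string keys other than "relu2" and the `activation_to_int` helper are assumptions for illustration, not the contents of tensorrt_llm/layers/moe.py.

```python
from enum import IntEnum


class ActivationType(IntEnum):
    InvalidType = 0
    Gelu = 1
    Relu = 2
    Silu = 3
    Swiglu = 4
    Geglu = 5
    SwigluBias = 6
    Identity = 7
    Relu2 = 8


# Integers are derived from the enum, so renumbering the enum updates the map.
activation_str_to_int_map = {
    "gelu": int(ActivationType.Gelu),
    "relu": int(ActivationType.Relu),
    "silu": int(ActivationType.Silu),
    "relu2": int(ActivationType.Relu2),
}


def activation_to_int(name: str) -> int:
    """Hypothetical helper: case-insensitive lookup with a clear error."""
    try:
        return activation_str_to_int_map[name.lower()]
    except KeyError:
        raise ValueError(f"Unsupported activation: {name!r}") from None
```

Because no integer literal appears in the map, adding or reordering enum members requires no edit here beyond adding a new key.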

tensorrt_llm/_torch/utils.py (1)

36-45: Verification complete: ActivationType enum reordering is safe.

The integer value reassignment poses no breaking change risk. Verification confirmed:

  • All Python code uses symbolic enum member names (e.g., ActivationType.Swiglu), not hardcoded integers—no code hardcodes values 0-8 for activation types
  • Serialization code does not store raw activation type integers
  • C++ code uses symbolic enum members in switch statements and will work correctly after recompilation
  • Integration tests and model loading code all reference activation types symbolically

The enum member names will automatically resolve to the new integer values upon recompilation, and existing code will continue to function correctly.

cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu (2)

2297-2330: Good refactoring to switch-case based activation dispatch.

The change from array-based lookup to explicit switch-case dispatch improves code maintainability and clarity. The switch statement makes it immediately clear which activation types are supported and how they map to kernel specializations. The explicit function pointer return type, while verbose, eliminates any ambiguity about the lambda's contract.


2325-2327: Relu2 activation support added correctly.

The addition of ActivationType::Relu2 follows the established pattern for ungated activation functions and is consistent with the PR's objective to make activation type sharing more robust across layers.

@nzmora-nvidia
Collaborator Author

/bot run

@nzmora-nvidia nzmora-nvidia enabled auto-merge (squash) November 17, 2025 23:42
@tensorrt-cicd
Collaborator

PR_Github #24809 [ run ] triggered by Bot. Commit: cae980d

@tensorrt-cicd
Collaborator

PR_Github #24809 [ run ] completed with state SUCCESS. Commit: cae980d
/LLM/main/L0_MergeRequest_PR pipeline #18722 completed with status: 'SUCCESS'
Pipeline passed with automatic retried tests. Check the rerun report for details.

Collaborator

@djns99 djns99 left a comment

Some comments in the mixtureOfExpertsBackendBenchmark are now incorrect. It's not a big deal to fix, since I'm really the only one who uses it, but it might be good to update them. It's probably worth replacing it with a string lookup instead.

@nzmora-nvidia
Collaborator Author

There are some comments in the mixtureOfExpertsBackendBenchmark that are incorrect now. Not a big deal to fix since Im really the only one who uses it, but might be good to update that, probably worth replacing it with a string lookup instead

@djns99 are you referring to this?
"- \"act_fn\" - The activation function to use, 0 = identity, 1 = relu, 2 = gelu, 3 = silu, 4 = geglu, 5 = " "swiglu\n"

kaiyux and others added 4 commits November 19, 2025 04:16
Signed-off-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com>
Signed-off-by: Neta Zmora <96238833+nzmora-nvidia@users.noreply.github.com>
…s layers

Signed-off-by: Neta Zmora <96238833+nzmora-nvidia@users.noreply.github.com>
Signed-off-by: Neta Zmora <96238833+nzmora-nvidia@users.noreply.github.com>
@nzmora-nvidia nzmora-nvidia force-pushed the users/nzmora/robust_moe_activation branch from cae980d to ad6b955 Compare November 19, 2025 12:17
Signed-off-by: Neta Zmora <96238833+nzmora-nvidia@users.noreply.github.com>
@nzmora-nvidia
Collaborator Author

There are some comments in the mixtureOfExpertsBackendBenchmark that are incorrect now. Not a big deal to fix since Im really the only one who uses it, but might be good to update that, probably worth replacing it with a string lookup instead

@djns99 are you referring to this? "- \"act_fn\" - The activation function to use, 0 = identity, 1 = relu, 2 = gelu, 3 = silu, 4 = geglu, 5 = " "swiglu\n"

Updated the user instruction.

@nzmora-nvidia
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #25047 [ run ] triggered by Bot. Commit: f91baf0

@tensorrt-cicd
Collaborator

PR_Github #25047 [ run ] completed with state SUCCESS. Commit: f91baf0
/LLM/main/L0_MergeRequest_PR pipeline #18929 completed with status: 'FAILURE'

Signed-off-by: Neta Zmora <96238833+nzmora-nvidia@users.noreply.github.com>
@nzmora-nvidia
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #25098 [ run ] triggered by Bot. Commit: 2c5154d

@tensorrt-cicd
Collaborator

PR_Github #25098 [ run ] completed with state SUCCESS. Commit: 2c5154d
/LLM/main/L0_MergeRequest_PR pipeline #18973 completed with status: 'SUCCESS'
Pipeline passed with automatic retried tests. Check the rerun report for details.

@nzmora-nvidia nzmora-nvidia merged commit 1d6fbbf into NVIDIA:main Nov 20, 2025
5 checks passed
5 participants