@Nekofish-L Nekofish-L commented Dec 29, 2025

Background

The recent update to support per-block quantization on Blackwell architectures requires converting quantization scales to FP8_E8M0 format and adjusting the weights accordingly. However, this conversion was implemented in a memory-inefficient way and created a large temporary memory footprint.

For example, when loading a model such as Qwen3-32B-FP8 with a TP2 configuration, weight loading alone consumed approximately 18 GB of memory per GPU, and the additional resmoothing step raised peak memory usage by a further ~15 GB per GPU. As a result, the model could not be loaded within the 32 GB of memory on an RTX 5090.

Solution

This PR implements a fused Triton kernel for E8M0 resmoothing. The kernel updates the weights and scales in place, which reduces the peak memory footprint.

Summary by CodeRabbit

  • Refactor
    • The FP8 resmoothing utility has been rewritten as a fused Triton kernel that updates weights and scales in place, reducing peak memory usage.
    • resmooth_to_fp8_e8m0 now takes (weight, weight_scale, block_size), and per_block_cast_to_fp8_e8m0 has been removed from the public API. Code that calls either function may need updates.


Description

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

Details

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline, or from the last pipeline if no pipeline-id is given. If the Git commit ID has changed, this option is always ignored. By DEFAULT, the bot reuses build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages that don't match the specified backends. Only [pytorch, cpp, tensorrt, triton] are supported. Examples: "pytorch, cpp" (does not run test stages with the tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purposes. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.

kill

kill

Kill all running builds associated with the pull request.

skip

skip --comment COMMENT

Skip testing for the latest commit on the pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous, because a lack of user care and validation can break the top of tree.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate the current commit. This action also kills all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous, because a lack of user care and validation can break the top of tree.


coderabbitai bot commented Dec 29, 2025

📝 Walkthrough


The file replaces Python-based FP8 resmoothing logic with a Triton kernel implementation. The public resmooth_to_fp8_e8m0 function signature changes from (weight, sf) to (weight, weight_scale, block_size), and per_block_cast_to_fp8_e8m0 is removed. A new private _resmooth_kernel performs in-place GPU computation.

Changes

Cohort: FP8 Resmoothing Kernel Replacement
File: tensorrt_llm/quantization/utils/fp8_utils.py
Summary: Removed public functions per_block_cast_to_fp8_e8m0 and old resmooth_to_fp8_e8m0. Added private Triton kernel _resmooth_kernel that loads FP8 weights, converts to FP32, computes block-wise max absolute value, derives new scale via exp2(ceil(log2(block_amax/448.0))), and updates weight and scale tensors in-place. New public resmooth_to_fp8_e8m0 accepts (weight, weight_scale, block_size=(128,128)) and launches the kernel.
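The scale rule in the summary can be checked in isolation. Below is a plain-Python model of the per-block formula that needs neither Triton nor a GPU. The function name e8m0_block_scale is hypothetical, and the 1e-4 floor mirrors the tl.maximum(..., 1e-4) clamp in the reviewed kernel. Because the exponent is rounded up to a power of two, in exact arithmetic any block whose amax is at least the floor ends up with a largest magnitude in (224, 448] after requantization. It therefore never exceeds the float8_e4m3fn maximum of 448.

```python
import math

FP8_E4M3_MAX = 448.0  # largest finite float8_e4m3fn value


def e8m0_block_scale(block_amax: float, eps: float = 1e-4) -> float:
    """Return the power-of-two scale 2 ** ceil(log2(amax / 448)) for one block.

    Hypothetical helper modelling the kernel's formula; amax is floored at
    eps, matching tl.maximum(..., 1e-4) in the reviewed kernel.
    """
    amax = max(block_amax, eps)
    # Rounding the exponent up guarantees amax / scale <= 448 (exact arithmetic).
    return 2.0 ** math.ceil(math.log2(amax / FP8_E4M3_MAX))


if __name__ == "__main__":
    for amax in (448.0, 449.0, 100.0, 0.0):
        print(amax, e8m0_block_scale(amax))
```

For example, an amax of 100 gives a scale of 0.25, and 100 / 0.25 = 400 stays below 448. An all-zero block falls back to the floor and gets 2 ** -22.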

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Pre-merge checks and finishing touches

❌ Failed checks (2 warnings)

  • Docstring Coverage (⚠️ Warning): Docstring coverage is 0.00%, which is below the required threshold of 80.00%. Resolution: run @coderabbitai generate docstrings to improve docstring coverage.
  • Description check (⚠️ Warning): The PR description is incomplete. It provides background and problem context, but the required 'Description' and 'Test Coverage' sections are empty. Resolution: complete the 'Description' section with a short explanation of the solution, and fill in 'Test Coverage' with the test cases that safeguard the changes.

✅ Passed checks (1 passed)

  • Title check (✅ Passed): The title clearly summarizes the main change: implementing a fused Triton kernel for e8m0 resmoothing to reduce memory footprint, which aligns with the file changes and PR objectives.



@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tensorrt_llm/quantization/utils/fp8_utils.py (1)

1-1: Add the required NVIDIA copyright header.

This file is missing the NVIDIA copyright header required by the coding guidelines. The header should include the year 2025 (or an appropriate year range) for this modification.

As per coding guidelines, all TensorRT-LLM code should contain an NVIDIA copyright header.

🔎 Suggested copyright header

Add at the beginning of the file:

+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from typing import Optional, Tuple
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 965578c and 84a0240.

📒 Files selected for processing (1)
  • tensorrt_llm/quantization/utils/fp8_utils.py
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Code developed for TensorRT-LLM should conform to Python 3.8+
Indent Python code with 4 spaces. Do not use tabs
Always maintain the namespace when importing in Python, even if only one class or function from a module is used
Python files should use snake_case naming: some_file.py
Python classes should use PascalCase naming: class SomeClass
Python functions and methods should use snake_case naming: def my_awesome_function():
Python local variables should use snake_case naming: my_variable = ...
Python variable names that start with a number should be prefixed with 'k': k_99th_percentile = ...
Python global variables should use upper snake_case with prefix 'G': G_MY_GLOBAL = ...
Python constants should use upper snake_case naming: MY_CONSTANT = ...
Avoid shadowing variables declared in an outer scope in Python
Initialize all externally visible members of a Python class in the constructor
For Python interfaces that may be used outside a file, prefer docstrings over comments
Python comments should be reserved for code within a function, or interfaces that are local to a file
Use Google style docstrings in Python for classes and functions, which can be parsed by Sphinx
Python attributes and variables can be documented inline with type and description
Avoid using reflection in Python when functionality can be easily achieved without reflection
When using try-except blocks in Python, limit the except to the smallest set of errors possible
When using try-except blocks in Python to handle multiple possible variable types (duck-typing), keep the body of the try as small as possible, using the else block for logic

Files:

  • tensorrt_llm/quantization/utils/fp8_utils.py
**/*.{cpp,h,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the year of its latest meaningful modification

Files:

  • tensorrt_llm/quantization/utils/fp8_utils.py
🧠 Learnings (2)
📓 Common learnings
Learnt from: jhaotingc
Repo: NVIDIA/TensorRT-LLM PR: 7856
File: cpp/tensorrt_llm/thop/fp8BlockScaleMoe.cpp:159-166
Timestamp: 2025-09-19T21:28:13.751Z
Learning: In TensorRT-LLM blockScaleMoe routing (cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.cu), the DeepSeek routing method performs reinterpret_cast<float*>(routingLogits) at line 89, which could cause issues if routing_logits are BF16. However, Qwen3-FP8 models use RenormalizeNaive routing method and are not affected by this dtype casting issue.
Learnt from: venkywonka
Repo: NVIDIA/TensorRT-LLM PR: 6029
File: .github/pull_request_template.md:45-53
Timestamp: 2025-08-27T17:50:13.264Z
Learning: For PR templates in TensorRT-LLM, avoid suggesting changes that would increase developer overhead, such as converting plain bullets to mandatory checkboxes. The team prefers guidance-style bullets that don't require explicit interaction to reduce friction in the PR creation process.
Learnt from: djns99
Repo: NVIDIA/TensorRT-LLM PR: 6915
File: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu:4010-4012
Timestamp: 2025-08-14T23:23:27.449Z
Learning: For MOE (Mixture of Experts) code reviews in TensorRT-LLM, avoid repeatedly suggesting finalize fusion validation checks and safety assertions. The user djns99 has indicated these suggestions are repetitive and unwanted across multiple MOE-related changes.
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check

Comment on lines +54 to +92
@triton.jit
def _resmooth_kernel(
    w_ptr,
    s_ptr,
    M,
    K,
    stride_wm,
    stride_wk,
    stride_sm,
    stride_sk,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    pid = tl.program_id(0)
    num_pid_k = tl.cdiv(K, BLOCK_K)
    pid_m = pid // num_pid_k
    pid_k = pid % num_pid_k

    s_offset = pid_m * stride_sm + pid_k * stride_sk
    old_scale = tl.load(s_ptr + s_offset)

    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rk = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)

    w_mask = (rm[:, None] < M) & (rk[None, :] < K)
    w_offsets = rm[:, None] * stride_wm + rk[None, :] * stride_wk
    w_fp8 = tl.load(w_ptr + w_offsets, mask=w_mask, other=0.0)
    w_fp32 = w_fp8.to(tl.float32)

    w_val = w_fp32 * old_scale
    block_amax = tl.maximum(tl.max(tl.abs(w_val)), 1e-4)

    # E8M0 sf = 2 ^ ceil(log2(sf))
    new_scale = tl.math.exp2(tl.math.ceil(tl.math.log2(block_amax / 448.0)))
    w_requant = w_val * (1.0 / new_scale)

    tl.store(w_ptr + w_offsets, w_requant, mask=w_mask)
    tl.store(s_ptr + s_offset, new_scale)
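To make the per-block control flow of _resmooth_kernel concrete, here is a plain-Python model of what one launch does to a single 2D matrix. It is a sketch under stated assumptions, not the PR's code. The name resmooth_reference is hypothetical, weights are ordinary Python floats, and the conversion to float8_e4m3fn that happens when the result is stored through the FP8 weight pointer is not modelled. The clipped ranges play the role of w_mask, so ragged edge tiles (M or K not a multiple of the block size) are handled the same way. The kernel loads masked-out elements as 0, and those zeros cannot raise the block's absolute maximum.

```python
import math
from typing import List, Tuple

FP8_E4M3_MAX = 448.0  # largest finite float8_e4m3fn value

Matrix = List[List[float]]


def resmooth_reference(
    w: Matrix,
    s: Matrix,
    block: Tuple[int, int] = (128, 128),
    eps: float = 1e-4,
) -> Tuple[Matrix, Matrix]:
    """Hypothetical pure-Python model of _resmooth_kernel on one (M, K) matrix.

    w: FP8 weight values as floats (FP8 rounding on store is not modelled).
    s: one scale per (BLOCK_M, BLOCK_K) tile, shape ceil(M/BM) x ceil(K/BK).
    Returns new (w, s); the inputs are left untouched.
    """
    bm, bk = block
    m, k = len(w), len(w[0])
    new_w = [row[:] for row in w]
    new_s = [row[:] for row in s]
    for pid_m in range(math.ceil(m / bm)):
        for pid_k in range(math.ceil(k / bk)):
            # Clipping the ranges plays the role of w_mask (rm < M, rk < K).
            rows = range(pid_m * bm, min((pid_m + 1) * bm, m))
            cols = range(pid_k * bk, min((pid_k + 1) * bk, k))
            old_scale = s[pid_m][pid_k]
            # Dequantize the tile: w_val = w_fp32 * old_scale.
            vals = {(i, j): w[i][j] * old_scale for i in rows for j in cols}
            amax = max(max(abs(v) for v in vals.values()), eps)
            new_scale = 2.0 ** math.ceil(math.log2(amax / FP8_E4M3_MAX))
            # Requantize against the power-of-two scale.
            for (i, j), v in vals.items():
                new_w[i][j] = v * (1.0 / new_scale)
            new_s[pid_m][pid_k] = new_scale
    return new_w, new_s
```

Every element keeps the same dequantized value w * s, which is the property resmoothing is meant to preserve (up to the FP8 rounding this model omits). Only the split between weight and scale changes.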


⚠️ Potential issue | 🔴 Critical

Critical: Batch dimension handling is broken in the kernel.

The kernel receives batched 3D tensors (w_view with shape (num_batches, M, K) and the matching blocked s_view) but is passed only their stride(1) and stride(2) values. The grid on Lines 111-112 includes a factor of num_batches, but the kernel neither extracts a batch ID from pid nor uses batch strides, so multi-batch inputs are indexed incorrectly.
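The indexing problem can be reproduced without a GPU by enumerating the program ids. The sketch below uses hypothetical helper names and plain Python to compare the flat-pid decomposition in the reviewed kernel with the one in the proposed fix. Under the original mapping, every program whose pid falls in batch 1 or later gets pid_m >= cdiv(M, BLOCK_M), so its row mask rm < M is all false and it never touches a weight. Such a program still stores a scale: its masked loads return 0, block_amax is clamped to 1e-4, and it writes 2 ** ceil(log2(1e-4 / 448)). Assuming contiguous scales of shape (num_batches, cdiv(M, BLOCK_M), cdiv(K, BLOCK_K)), that store lands on the matching scale entry of the later batch. Those weights are left unchanged while their scales are overwritten.

```python
import math
from typing import Callable, List, Tuple

Block = Tuple[int, int, int]  # (batch, pid_m, pid_k)


def original_mapping(pid: int, m: int, k: int, bm: int, bk: int) -> Block:
    """Decomposition used by the reviewed kernel: no batch index is extracted.

    m and bm are unused; they are kept so both mappings share one signature.
    """
    num_pid_k = math.ceil(k / bk)
    return 0, pid // num_pid_k, pid % num_pid_k


def batched_mapping(pid: int, m: int, k: int, bm: int, bk: int) -> Block:
    """Decomposition used by the proposed fix: split off the batch first."""
    num_pid_m = math.ceil(m / bm)
    num_pid_k = math.ceil(k / bk)
    per_batch = num_pid_m * num_pid_k
    rem = pid % per_batch
    return pid // per_batch, rem // num_pid_k, rem % num_pid_k


def launch(num_batches: int, m: int, k: int, bm: int, bk: int,
           mapping: Callable[[int, int, int, int, int], Block]) -> List[Block]:
    """Enumerate every program of the 1D grid the wrapper launches."""
    grid = num_batches * math.ceil(m / bm) * math.ceil(k / bk)
    return [mapping(pid, m, k, bm, bk) for pid in range(grid)]


def rows_in_bounds(pid_m: int, m: int, bm: int) -> int:
    """Number of tile rows that pass the kernel's row mask rm < M."""
    return sum(1 for r in range(pid_m * bm, (pid_m + 1) * bm) if r < m)
```

Take two 256 x 256 matrices with 128 x 128 blocks. The original mapping yields pid_m values 0 to 3, so the last four of the eight programs see no in-bounds rows. The batched mapping visits each (batch, pid_m, pid_k) tile exactly once.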

🔎 Proposed fix

Add batch handling to the kernel:

 @triton.jit
 def _resmooth_kernel(
     w_ptr,
     s_ptr,
     M,
     K,
+    stride_wb,
     stride_wm,
     stride_wk,
+    stride_sb,
     stride_sm,
     stride_sk,
     BLOCK_M: tl.constexpr,
     BLOCK_K: tl.constexpr,
 ):
     pid = tl.program_id(0)
+    num_pid_m = tl.cdiv(M, BLOCK_M)
     num_pid_k = tl.cdiv(K, BLOCK_K)
-    pid_m = pid // num_pid_k
-    pid_k = pid % num_pid_k
+    num_pid_per_batch = num_pid_m * num_pid_k
+    
+    pid_batch = pid // num_pid_per_batch
+    pid_remainder = pid % num_pid_per_batch
+    pid_m = pid_remainder // num_pid_k
+    pid_k = pid_remainder % num_pid_k
 
-    s_offset = pid_m * stride_sm + pid_k * stride_sk
+    s_offset = pid_batch * stride_sb + pid_m * stride_sm + pid_k * stride_sk
     old_scale = tl.load(s_ptr + s_offset)
 
     rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
     rk = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)
 
     w_mask = (rm[:, None] < M) & (rk[None, :] < K)
-    w_offsets = rm[:, None] * stride_wm + rk[None, :] * stride_wk
+    w_offsets = pid_batch * stride_wb + rm[:, None] * stride_wm + rk[None, :] * stride_wk
     w_fp8 = tl.load(w_ptr + w_offsets, mask=w_mask, other=0.0)

Update the kernel call on Lines 114-125:

 _resmooth_kernel[grid](
     w_view,
     s_view,
     M,
     K,
+    w_view.stride(0),
     w_view.stride(1),
     w_view.stride(2),
+    s_view.stride(0),
     s_view.stride(1),
     s_view.stride(2),
     BLOCK_M=BLOCK_M,
     BLOCK_K=BLOCK_K,
 )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
@triton.jit
def _resmooth_kernel(
    w_ptr,
    s_ptr,
    M,
    K,
    stride_wb,
    stride_wm,
    stride_wk,
    stride_sb,
    stride_sm,
    stride_sk,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    pid = tl.program_id(0)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_k = tl.cdiv(K, BLOCK_K)
    num_pid_per_batch = num_pid_m * num_pid_k
    pid_batch = pid // num_pid_per_batch
    pid_remainder = pid % num_pid_per_batch
    pid_m = pid_remainder // num_pid_k
    pid_k = pid_remainder % num_pid_k
    s_offset = pid_batch * stride_sb + pid_m * stride_sm + pid_k * stride_sk
    old_scale = tl.load(s_ptr + s_offset)
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rk = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)
    w_mask = (rm[:, None] < M) & (rk[None, :] < K)
    w_offsets = pid_batch * stride_wb + rm[:, None] * stride_wm + rk[None, :] * stride_wk
    w_fp8 = tl.load(w_ptr + w_offsets, mask=w_mask, other=0.0)
    w_fp32 = w_fp8.to(tl.float32)
    w_val = w_fp32 * old_scale
    block_amax = tl.maximum(tl.max(tl.abs(w_val)), 1e-4)
    # E8M0 sf = 2 ^ ceil(log2(sf))
    new_scale = tl.math.exp2(tl.math.ceil(tl.math.log2(block_amax / 448.0)))
    w_requant = w_val * (1.0 / new_scale)
    tl.store(w_ptr + w_offsets, w_requant, mask=w_mask)
    tl.store(s_ptr + s_offset, new_scale)
🤖 Prompt for AI Agents
In tensorrt_llm/quantization/utils/fp8_utils.py around lines 54 to 92, the
kernel ignores the batch dimension so multi-batch inputs index incorrectly;
update the kernel to extract a batch id from tl.program_id(0) and compute
pid_m/pid_k relative to the per‑batch grid, add batch strides (e.g. stride_wb
and stride_sb) as additional kernel arguments, and incorporate them into
s_offset and w_offsets (compute a batch_base for w_ptr and s_ptr using
batch*stride_wb / batch*stride_sb before adding pid_m/pid_k offsets). Also
update the kernel launch to pass the new batch strides (the grid already accounts for num_batches) so each
program id indexes the correct (batch, m, k) block.

Comment on lines +94 to +127
def resmooth_to_fp8_e8m0(
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    block_size: tuple[int, int] = (128, 128),
):
    assert weight.dtype == torch.float8_e4m3fn
    assert weight_scale.dtype == torch.float32

    orig_shape = weight.shape
    M, K = orig_shape[-2:]
    w_view = weight.view(-1, M, K)
    s_view = weight_scale.view(-1, weight_scale.shape[-2],
                               weight_scale.shape[-1])

    num_batches = w_view.shape[0]
    BLOCK_M, BLOCK_K = block_size

    grid = (num_batches * (triton.cdiv(M, BLOCK_M)) *
            (triton.cdiv(K, BLOCK_K)), )

    _resmooth_kernel[grid](
        w_view,
        s_view,
        M,
        K,
        w_view.stride(1),
        w_view.stride(2),
        s_view.stride(1),
        s_view.stride(2),
        BLOCK_M=BLOCK_M,
        BLOCK_K=BLOCK_K,
    )
    # this is an in-place operation, however, we return for simplicity
    return weight, weight_scale
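The wrapper folds any leading dimensions into one batch axis with view(-1, M, K) and sizes the grid as num_batches * cdiv(M, BLOCK_M) * cdiv(K, BLOCK_K). It does not check that weight_scale actually holds one entry per tile. The sketch below is a hypothetical, stricter pre-launch check; the name expected_launch and its exact-shape rule are assumptions and are not part of the PR. It derives the grid size from the shapes alone, so it runs without torch or a GPU.

```python
import math
from typing import Sequence, Tuple


def expected_launch(
    weight_shape: Sequence[int],
    scale_shape: Sequence[int],
    block_size: Tuple[int, int] = (128, 128),
) -> Tuple[int, Tuple[int, int, int]]:
    """Hypothetical pre-launch shape check for resmooth_to_fp8_e8m0.

    Returns (grid_size, (num_batches, blocks_m, blocks_k)) and raises
    ValueError if the scale tensor does not hold one entry per tile.
    """
    *batch_dims, m, k = weight_shape
    bm, bk = block_size
    blocks_m, blocks_k = math.ceil(m / bm), math.ceil(k / bk)
    # Leading dims are folded into one batch axis, like view(-1, M, K).
    num_batches = math.prod(batch_dims)  # 1 for a plain 2D weight
    want = (*batch_dims, blocks_m, blocks_k)
    if tuple(scale_shape) != want:
        raise ValueError(f"scale shape {tuple(scale_shape)} != expected {want}")
    return num_batches * blocks_m * blocks_k, (num_batches, blocks_m, blocks_k)
```

For a dense 4096 x 5120 weight with the default 128 x 128 blocks, this expects a 32 x 40 scale tensor and a grid of 1280 programs.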

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Find all calls to resmooth_to_fp8_e8m0
echo "=== Searching for resmooth_to_fp8_e8m0 calls ==="
rg -n 'resmooth_to_fp8_e8m0' --type=py -A 3 -B 1

echo ""
echo "=== Checking file header for copyright ==="
head -20 tensorrt_llm/quantization/utils/fp8_utils.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 3302


Add NVIDIA copyright header to the file.

The file tensorrt_llm/quantization/utils/fp8_utils.py is missing the required NVIDIA copyright header. According to coding guidelines, all TensorRT-LLM code must contain an NVIDIA copyright header with the year of latest meaningful modification. Add the header at the top of the file before the imports.

Regarding the function signature change: All call sites in the codebase correctly use the new signature and are compatible with the default block_size=(128, 128) parameter. No signature updates are required at call sites.

🤖 Prompt for AI Agents
In tensorrt_llm/quantization/utils/fp8_utils.py around lines 94 to 127: add the
required NVIDIA copyright header at the very top of the file (before any
imports) with the correct year of latest meaningful modification and the
standard NVIDIA license phrasing used in this repo; ensure the header exactly
matches the project's header format and encoding, then save—no changes to the
function signature or any call sites are required since they are already
compatible with the default block_size.

svc-trtllm-gh-bot added the "Community want to contribute" label (PRs initiated from Community) on Dec 29, 2025