Skip to content

Commit 8f624e9

Browse files
peaceh-nv and claude committed
[TRTLLM-11289][feat] Replace DeepSeek router GEMM with CuTe DSL BF16 GEMM (FP32 output)
Enable CuTe DSL BF16 GEMM kernel for DeepseekV3Gate router GEMM on Blackwell. The router computes BF16 input @ BF16 weight -> FP32 logits, which our persistent GEMM kernel already supports via FP32 accumulator and FP32 output. Key changes: - Support FP32 output dtype in CuteDSLBf16BlackwellGemmRunner (detect from output tensor instead of hardcoding BF16, add c_dtype to kernel cache key) - Relax cute_dsl_bf16_gemm_blackwell custom op to accept BF16 or FP32 output - Add CuTe DSL dispatch in DeepseekV3Gate.forward() gated by use_cute_dsl_bf16_gemm flag, with fallback to dsv3_router_gemm_op Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> Signed-off-by: peaceh <103117813+peaceh-nv@users.noreply.github.com>
1 parent b7a5e72 commit 8f624e9

File tree

2 files changed

+35
-10
lines changed

2 files changed

+35
-10
lines changed

tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4052,11 +4052,14 @@ def get_valid_tactics(
40524052
)
40534053
return []
40544054

4055-
# input: [M, K], weight: [N, K]
4055+
# input: [M, K], weight: [N, K], output: [M, N]
40564056
m, k = inputs[0].shape[0], inputs[0].shape[1]
40574057
n = inputs[1].shape[0]
40584058
batch_size = 1
40594059

4060+
# Detect output dtype from the output tensor (supports BF16 and FP32)
4061+
c_dtype_cutlass = _TORCH_TO_CUTLASS_DTYPE[inputs[2].dtype]
4062+
40604063
# Layouts: A is [M, K] K-major, B is [N, K] K-major
40614064
a_major = "k"
40624065
b_major = "k"
@@ -4083,7 +4086,7 @@ def get_valid_tactics(
40834086
if self.__class__.kernel_class.can_implement(
40844087
cutlass.BFloat16, # ab_dtype
40854088
cutlass.Float32, # acc_dtype
4086-
cutlass.BFloat16, # c_dtype
4089+
c_dtype_cutlass, # c_dtype
40874090
use_2cta_instrs,
40884091
mma_tiler_mn,
40894092
cluster_shape_mn,
@@ -4109,7 +4112,7 @@ def forward(
41094112
inputs (List[torch.Tensor]):
41104113
inputs[0]: Input tensor of shape (m, k), dtype: bf16.
41114114
inputs[1]: Weight tensor of shape (n, k), dtype: bf16.
4112-
inputs[2]: Output tensor of shape (m, n), dtype: bf16.
4115+
inputs[2]: Output tensor of shape (m, n), dtype: bf16 or fp32.
41134116
tactic: Tiling and cluster strategy, typically a tuple
41144117
(use_2cta_instrs, mma_tiler_mn, cluster_shape_mn).
41154118
"""
@@ -4146,6 +4149,9 @@ def forward(
41464149
# c_buf is [M, N], permute to [M, N, 1] for cute layout
41474150
c_tmp = c_buf.unsqueeze(-1) # [M, N, 1]
41484151

4152+
# Detect output dtype (supports BF16 and FP32)
4153+
c_dtype_cutlass = _TORCH_TO_CUTLASS_DTYPE[c_tensor.dtype]
4154+
41494155
if not self.use_tvm_ffi:
41504156
a_ptr = make_ptr(
41514157
cutlass.BFloat16,
@@ -4169,6 +4175,7 @@ def forward(
41694175
mma_tiler_mn,
41704176
cluster_shape_mn,
41714177
self.use_tvm_ffi,
4178+
c_dtype_cutlass,
41724179
)
41734180
if cache_key not in self.__class__.kernel_cache:
41744181
if self.use_tvm_ffi:
@@ -4290,5 +4297,6 @@ def _(
42904297
) -> None:
42914298
m, k = mat_a.shape[0], mat_a.shape[1]
42924299
n = mat_b.shape[0]
4293-
assert output.dtype == torch.bfloat16, "CuTe DSL bf16 gemm output dtype must be bf16"
4300+
assert output.dtype in (torch.bfloat16, torch.float32), \
4301+
"CuTe DSL bf16 gemm output dtype must be bf16 or fp32"
42944302
assert output.shape == (m, n), "CuTe DSL bf16 gemm output shape is incorrect"

tensorrt_llm/_torch/models/modeling_deepseekv3.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
from tensorrt_llm._ipc_utils import can_access_peer
4242
from tensorrt_llm._torch.models.checkpoints.base_weight_loader import \
4343
ConsumableWeightsDict
44-
from tensorrt_llm._utils import get_sm_version
44+
from tensorrt_llm._utils import get_sm_version, is_sm_100f
4545
from tensorrt_llm.functional import PositionEmbeddingType
4646
from tensorrt_llm.mapping import Mapping
4747
from tensorrt_llm.models.modeling_utils import QuantConfig
@@ -852,8 +852,10 @@ def __init__(
852852
fuse_routing_kernel: bool = True,
853853
apply_routing: bool = False,
854854
moe_backend: str = 'CUTLASS',
855+
use_cute_dsl_bf16_gemm: bool = False,
855856
):
856857
super().__init__()
858+
self.use_cute_dsl_bf16_gemm = use_cute_dsl_bf16_gemm
857859
self.weight = nn.Parameter(torch.empty((num_experts, hidden_size),
858860
dtype=dtype),
859861
requires_grad=False)
@@ -878,10 +880,24 @@ def __init__(
878880
is_fused=fuse_routing_kernel)
879881

880882
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
881-
logits = torch.ops.trtllm.dsv3_router_gemm_op(hidden_states,
882-
self.weight.t(),
883-
bias=None,
884-
out_dtype=torch.float32)
883+
if (self.use_cute_dsl_bf16_gemm and is_sm_100f()
884+
and self.weight.dtype == torch.bfloat16):
885+
input_2d = hidden_states.view(-1, hidden_states.shape[-1])
886+
m, k = input_2d.shape
887+
n = self.weight.shape[0]
888+
output = torch.empty(m,
889+
n,
890+
dtype=torch.float32,
891+
device=hidden_states.device)
892+
torch.ops.trtllm.cute_dsl_bf16_gemm_blackwell(
893+
input_2d.contiguous(), self.weight, output)
894+
logits = output.view(*hidden_states.shape[:-1], n)
895+
else:
896+
logits = torch.ops.trtllm.dsv3_router_gemm_op(
897+
hidden_states,
898+
self.weight.t(),
899+
bias=None,
900+
out_dtype=torch.float32)
885901
return logits
886902

887903
def load_weights(self, weights: List[Dict]):
@@ -947,7 +963,8 @@ def __init__(self,
947963
dtype=dtype,
948964
fuse_routing_kernel=True,
949965
apply_routing=False,
950-
moe_backend=model_config.moe_backend)
966+
moe_backend=model_config.moe_backend,
967+
use_cute_dsl_bf16_gemm=model_config.use_cute_dsl_bf16_gemm)
951968
self.experts = create_moe(
952969
num_experts=num_experts,
953970
routing_method=self.gate.routing_method,

0 commit comments

Comments (0)