Commit 8f624e9
[TRTLLM-11289][feat] Replace DeepSeek router GEMM with CuTe DSL BF16 GEMM (FP32 output)
Enable CuTe DSL BF16 GEMM kernel for DeepseekV3Gate router GEMM on Blackwell.
The router computes BF16 input @ BF16 weight -> FP32 logits, which our
persistent GEMM kernel already supports via FP32 accumulator and FP32 output.
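The numeric contract described above (BF16 operands, FP32 accumulation, FP32 logits) can be illustrated without a GPU. The following is a minimal pure-Python sketch, not the CuTe DSL kernel: `to_bf16`, `to_fp32`, and `router_logits` are hypothetical helper names, and the `[num_experts, hidden_size]` weight layout is an assumption made for illustration.

```python
import struct


def to_fp32(x: float) -> float:
    """Round a Python float (FP64) to the nearest FP32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bf16(x: float) -> float:
    """Round a finite float to BF16 (round-to-nearest-even), returned as a float."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Add the rounding bias, then drop the low 16 bits of the FP32 pattern.
    bits += 0x7FFF + ((bits >> 16) & 1)
    bits &= 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def router_logits(hidden, weight):
    """Emulate BF16 input @ BF16 weight -> FP32 logits.

    hidden: [num_tokens][hidden_size], weight: [num_experts][hidden_size]
    (layout assumed for this sketch).
    """
    logits = []
    for token in hidden:
        row = []
        for expert in weight:
            acc = 0.0  # FP32 accumulator
            for h, w in zip(token, expert):
                acc = to_fp32(acc + to_bf16(h) * to_bf16(w))
            row.append(acc)  # kept in FP32, not rounded back to BF16
        logits.append(row)
    return logits


# One token, one expert: 1.0 * 1.0 + 2**-8 * 1.0
logits = router_logits([[1.0, 0.00390625]], [[1.0, 1.0]])
fp32_logit = logits[0][0]
bf16_logit = to_bf16(fp32_logit)
```

In this example the FP32 logit keeps the `2**-8` contribution, while rounding the same value to BF16 would collapse it to `1.0`, which is one reason a router that ranks experts by logit may prefer FP32 output.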
Key changes:
- Support FP32 output dtype in CuteDSLBf16BlackwellGemmRunner (detect from
output tensor instead of hardcoding BF16, add c_dtype to kernel cache key)
- Relax cute_dsl_bf16_gemm_blackwell custom op to accept BF16 or FP32 output
- Add CuTe DSL dispatch in DeepseekV3Gate.forward() gated by
use_cute_dsl_bf16_gemm flag, with fallback to dsv3_router_gemm_op
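The key changes above can be sketched roughly as follows. This is a hedged, pure-Python illustration of the two ideas (keying the kernel cache on the output dtype, and flag-gated dispatch with a fallback), not the actual TensorRT-LLM code: `get_kernel`, `_compile_kernel`, `select_router_gemm`, and the tile parameter are hypothetical names, and dtypes are represented as plain strings.

```python
from typing import Callable, Dict, Tuple

# Hypothetical cache: (output dtype, tile shape) -> compiled kernel.
_KERNEL_CACHE: Dict[Tuple[str, Tuple[int, int]], Callable[[], str]] = {}
compile_count = 0


def _compile_kernel(c_dtype: str, tile: Tuple[int, int]) -> Callable[[], str]:
    """Stand-in for an expensive kernel compilation step."""
    global compile_count
    compile_count += 1
    return lambda: f"gemm<c_dtype={c_dtype}, tile={tile}>"


def get_kernel(out_dtype: str, tile: Tuple[int, int] = (128, 128)) -> Callable[[], str]:
    """Detect the output dtype from the caller's tensor instead of hardcoding BF16."""
    if out_dtype not in ("bfloat16", "float32"):
        raise ValueError(f"unsupported output dtype: {out_dtype}")
    # c_dtype is part of the key, so a BF16-output kernel is never
    # reused for an FP32 output buffer (or vice versa).
    key = (out_dtype, tile)
    if key not in _KERNEL_CACHE:
        _KERNEL_CACHE[key] = _compile_kernel(out_dtype, tile)
    return _KERNEL_CACHE[key]


def select_router_gemm(use_cute_dsl_bf16_gemm: bool) -> str:
    """Pick the router GEMM path; fall back when the flag is off."""
    if use_cute_dsl_bf16_gemm:
        return "cute_dsl_bf16_gemm_blackwell"
    return "dsv3_router_gemm_op"


bf16_kernel = get_kernel("bfloat16")
fp32_kernel = get_kernel("float32")
fp32_again = get_kernel("float32")  # served from the cache, no recompile
```

Without the output dtype in the key, the second lookup in a sketch like this would return the kernel compiled for the first dtype, which is the kind of mismatch the cache-key change guards against.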
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: peaceh <103117813+peaceh-nv@users.noreply.github.com>
1 parent b7a5e72 commit 8f624e9
File tree
2 files changed, +35 -10 lines changed
- tensorrt_llm/_torch
  - custom_ops
  - models
Diff (line contents not preserved; changed line numbers only):

| File | Added (new line numbers) | Removed (original line numbers) | Total |
|---|---|---|---|
| First file | 4055, 4060-4062, 4089, 4115, 4152-4154, 4178, 4300-4301 | 4055, 4086, 4112, 4293 | +12 -4 |
| Second file | 44, 855, 858, 883-900, 966-967 | 44, 881-884, 950 | +23 -6 |