Commit 7eee9a9

doc: Update doc for Deepseek min latency (NVIDIA#3717)
* Tidy code

  Signed-off-by: Zongfei Jing <[email protected]>

* Update doc for min latency deepseek

  Signed-off-by: Zongfei Jing <[email protected]>

* Throw exception for RouterKernel when not running on sm90+

  Signed-off-by: Zongfei Jing <[email protected]>

---------

Signed-off-by: Zongfei Jing <[email protected]>
1 parent 0ae7017 commit 7eee9a9

File tree: 3 files changed, +20 -10 lines


cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/trtllmGenSrc/RoutingKernel.cu

Lines changed: 7 additions & 0 deletions

```diff
@@ -698,6 +698,7 @@ __global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) __launch_bounds__(Nu
 #else
 __global__ void routingIndicesClusterKernel(KernelParams params)
 {
+    assert(false && "routingIndicesClusterKernel is only supported on SM90+ architectures");
 }
 #endif
 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -886,6 +887,8 @@ __global__ void __launch_bounds__(NumThreads) routingIndicesCoopKernel(KernelPar
         params.mPtrPermutedIdxToTokenIdx[permutedIdx] = tokenIdx;
     }
 }
+#else
+    assert(false && "routingIndicesCoopKernel is only supported on SM90+ architectures");
 #endif
 }
 
@@ -973,6 +976,8 @@ __global__ void __launch_bounds__(NumThreads) routingIndicesHistogramKernel(Kern
     // Reduce histograms with atomics.
     int32_t const localExpertCount = smemExpertCount[threadIdx.x];
     atomicAdd(&params.mPtrExpertCounts[threadIdx.x], localExpertCount);
+#else
+    assert(false && "routingIndicesHistogramKernel is only supported on SM90+ architectures");
 #endif
 }
 
@@ -1204,6 +1209,8 @@ __global__ void __launch_bounds__(NumThreads) routingIndicesOffsetsKernel(Kernel
     {
         cudaTriggerProgrammaticLaunchCompletion();
     }
+#else
+    assert(false && "routingIndicesOffsetsKernel is only supported on SM90+ architectures");
 #endif
 }
 
```
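
The new device-side asserts make an unsupported launch fail loudly in debug builds instead of silently executing an empty kernel body on pre-Hopper GPUs. Below is a minimal pre-flight check a caller could run before selecting a backend that launches these kernels, assuming PyTorch is available (a hypothetical helper for illustration, not TensorRT-LLM's actual dispatch logic):

```python
# Hypothetical pre-flight check (not TensorRT-LLM's dispatch code): the
# routing kernels above assert on pre-SM90 GPUs, so verify the compute
# capability before selecting a backend that would launch them.
import torch


def is_sm90_or_newer(device: int = 0) -> bool:
    """Return True if the GPU is Hopper-class (compute capability >= 9.0)."""
    major, _minor = torch.cuda.get_device_capability(device)
    return major >= 9


if torch.cuda.is_available() and not is_sm90_or_newer():
    print("Routing kernels need SM90+; fall back to a non-TRTLLM MoE backend.")
```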

docs/source/blogs/Best_perf_practice_on_DeepSeek-R1_in_TensorRT-LLM.md

Lines changed: 9 additions & 7 deletions

````diff
@@ -110,6 +110,7 @@ cat >./extra-llm-api-config.yml<<EOF
 pytorch_backend_config:
   enable_overlap_scheduler: true
   use_cuda_graph: true
+  moe_backend: TRTLLM
 speculative_config:
   decoding_type: MTP
   num_nextn_predict_layers: 3
@@ -125,7 +126,7 @@ trtllm-bench --model nvidia/DeepSeek-R1-FP4 \
   --concurrency 1 \
   --max_batch_size 1 \
   --tp 8 \
-  --ep 4 \
+  --ep 2 \
   --extra_llm_api_options ./extra-llm-api-config.yml
 ```
 
@@ -147,12 +148,13 @@ The perf can be different when using different datasets and different machines.
 ===========================================================
 = PERFORMANCE OVERVIEW
 ===========================================================
-Request Throughput (req/sec):                     0.1244
-Total Output Throughput (tokens/sec):             254.5535
-Per User Output Throughput (tokens/sec/user):     254.7634
-Per GPU Output Throughput (tokens/sec/gpu):       31.8192
-Total Latency (ms):                               80368.1616
-Average request latency (ms):                     8036.7546
+Request Throughput (req/sec):                     0.1341
+Total Output Throughput (tokens/sec):             274.4168
+Per User Output Throughput (tokens/sec/user):     274.7188
+Per GPU Output Throughput (tokens/sec/gpu):       34.3021
+Total Token Throughput (tokens/sec):              414.0461
+Total Latency (ms):                               74561.7520
+Average request latency (ms):                     7456.1219
 ```
 
 ### B200 max-throughput
````
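
As a quick sanity check on the updated numbers (not part of the commit): the per-GPU figure is the total output throughput divided by the 8 GPUs from `--tp 8`, and the request throughput times the total latency implies the run completed about 10 requests:

```python
# Sanity-check the updated benchmark figures from the diff above.
total_output_tps = 274.4168            # Total Output Throughput (tokens/sec)
num_gpus = 8                           # --tp 8
print(round(total_output_tps / num_gpus, 4))    # 34.3021, the per-GPU figure

total_latency_s = 74561.7520 / 1000.0  # Total Latency (ms) -> seconds
req_throughput = 0.1341                # Request Throughput (req/sec)
print(round(total_latency_s * req_throughput))  # ~10 requests in the run
```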

tensorrt_llm/_torch/models/modeling_deepseekv3.py

Lines changed: 4 additions & 3 deletions

```diff
@@ -624,6 +624,9 @@ def _compute_mlp_tp_size(self, intermediate_size: int,
         )
         return mlp_tp_size
 
+    def _enable_latency_mode(self, num_tokens: int):
+        return num_tokens <= 128 and self.fusion_config.POST_MOE_FUSION and self.is_nvfp4 and self.model_config.moe_backend == 'CUTLASS'
+
     def forward(
         self,
         position_ids: torch.LongTensor,
@@ -650,9 +653,7 @@ def forward(
         using_prev_fusion = self.deepseek_allreduce_disabled or hidden_states.size(
             0) > 128
 
-        min_latency_mode = True if hidden_states.size(
-            0
-        ) <= 128 and self.fusion_config.POST_MOE_FUSION and self.is_nvfp4 and self.model_config.moe_backend == 'CUTLASS' else False
+        min_latency_mode = self._enable_latency_mode(hidden_states.size(0))
 
         if self.fusion_config.PRE_MOE_FUSION:
             # Custom AR Fusion for DeepseekV3
```
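
The Python change is a pure refactor: the inline `True if ... else False` expression (the ternary was redundant, since the condition already evaluates to a bool) moves into the named `_enable_latency_mode` helper. Here is a self-contained sketch of the extracted predicate, using stub objects in place of the real TensorRT-LLM configs (names are illustrative only):

```python
# Self-contained sketch of the refactored gating predicate; SimpleNamespace
# stubs stand in for the real model/fusion config objects.
from types import SimpleNamespace


class LayerSketch:
    def __init__(self):
        self.fusion_config = SimpleNamespace(POST_MOE_FUSION=True)
        self.model_config = SimpleNamespace(moe_backend='CUTLASS')
        self.is_nvfp4 = True

    def _enable_latency_mode(self, num_tokens: int) -> bool:
        # Min-latency MoE path requires a small batch (<= 128 tokens),
        # post-MoE fusion, NVFP4 weights, and the CUTLASS MoE backend.
        return (num_tokens <= 128 and self.fusion_config.POST_MOE_FUSION
                and self.is_nvfp4
                and self.model_config.moe_backend == 'CUTLASS')


layer = LayerSketch()
assert layer._enable_latency_mode(64)        # small batch -> min-latency path
assert not layer._enable_latency_mode(256)   # large batch -> standard path
```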
