
Commit 523a17d

lfr-0531 and litaotju authored
[https://nvbugs/5485325][fix] Cherry-pick #7373: fix the CUDA graph warmup issue when using speculative decoding (#7734)
Signed-off-by: Fanrong Li <[email protected]>
Co-authored-by: Tao Li @ NVIDIA <[email protected]>
1 parent 3924832 commit 523a17d
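
For context, the following is a minimal, hedged sketch of the configuration this commit guards against: max_seq_len equal to max_num_tokens combined with MTP speculative decoding and padded CUDA graphs. It is condensed from the regression test added in the diff below; the import paths and the checkpoint path are assumptions for illustration, not part of this commit.

# Sketch only (not the committed test). The tensorrt_llm.llmapi import paths
# and the checkpoint path are assumptions; see the diff below for the
# authoritative regression test.
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import (CudaGraphConfig, KvCacheConfig, MoeConfig,
                                 MTPDecodingConfig)

kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.80,
                                dtype="fp8",
                                enable_block_reuse=False)

llm = LLM("DeepSeek-R1/DeepSeek-R1-FP4",  # assumed local checkpoint path
          tensor_parallel_size=8,
          moe_expert_parallel_size=8,
          kv_cache_config=kv_cache_config,
          cuda_graph_config=CudaGraphConfig(enable_padding=True,
                                            max_batch_size=1024),
          moe_config=MoeConfig(backend="TRTLLM"),
          speculative_config=MTPDecodingConfig(num_nextn_predict_layers=1),
          max_seq_len=5120,     # equal to max_num_tokens: the combination that
          max_num_tokens=5120)  # previously starved CUDA graph warmup of KV blocks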

File tree

1 file changed: +36 -0 lines changed


tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 36 additions & 0 deletions
@@ -2040,6 +2040,42 @@ def test_nvfp4_multi_gpus_corner_case(self):
         task = GSM8K(self.MODEL_NAME)
         task.evaluate(llm)
 
+    def test_nvfp4_multi_gpus_corner_case(self):
+        """
+        This test covers a corner case of the NVFP4 model.
+        When max_seq_len and max_num_tokens are set to the same value, there are
+        not enough KV blocks for the dummy requests in CUDA graph warmup when
+        creating the py_executor before estimating the KV cache. CUDA graph
+        capture is then triggered while estimating the KV cache, which may cause
+        errors.
+        More info in https://nvbugs/5485325.
+        """
+        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.80,
+                                        dtype="fp8",
+                                        enable_block_reuse=False)
+        pytorch_config = dict(disable_overlap_scheduler=False,
+                              cuda_graph_config=CudaGraphConfig(
+                                  enable_padding=True, max_batch_size=1024),
+                              moe_config=MoeConfig(backend="TRTLLM"))
+
+        mtp_config = MTPDecodingConfig(num_nextn_predict_layers=1)
+        with LLM(f"{llm_models_root()}/DeepSeek-R1/DeepSeek-R1-FP4",
+                 tensor_parallel_size=8,
+                 pipeline_parallel_size=1,
+                 moe_expert_parallel_size=8,
+                 kv_cache_config=kv_cache_config,
+                 **pytorch_config,
+                 enable_attention_dp=False,
+                 speculative_config=mtp_config,
+                 max_seq_len=5120,
+                 max_num_tokens=5120) as llm:
+
+            assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4
+
+            task = MMLU(self.MODEL_NAME)
+            task.evaluate(llm)
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)
+
     @pytest.mark.skip_less_mpi_world_size(8)
     @skip_pre_hopper
     @pytest.mark.parametrize(
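
To run just the new test locally, a programmatic pytest invocation along these lines should work; the node selection is inferred from the file path and test name in this diff, and 8 GPUs are assumed to be required because the test sets tensor_parallel_size=8.

# Sketch of a local invocation via pytest's Python API; the -k expression and
# the hardware requirement (8 GPUs) are inferred from the diff above, not
# stated by this commit.
import sys

import pytest

ret = pytest.main([
    "tests/integration/defs/accuracy/test_llm_api_pytorch.py",
    "-k", "test_nvfp4_multi_gpus_corner_case",
    "-v", "-s",
])
sys.exit(ret)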
