
Commit d61893d

[fix] Update to properly set cuda graphs in trtllm-bench overrides. (NVIDIA#5634)
Signed-off-by: Frank Di Natale <3429989+FrankD412@users.noreply.github.com>
1 parent d1112aa commit d61893d
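
For context, this fix concerns the --extra_llm_api_options flag of trtllm-bench, which points at a YAML file whose keys override the benchmark's generated LLM-API settings. The snippet below is a minimal, self-contained sketch of that override path as it looks after this commit; the file name and option values are illustrative, not taken from the commit.

import yaml

# Illustrative overrides file, as a user might pass via --extra_llm_api_options.
with open("extra_llm_api_options.yaml", "w") as f:
    f.write("kv_cache_dtype: fp8\nenable_chunked_prefill: true\n")

# Defaults computed by the benchmark before overrides are applied.
kv_cache_dtype = "auto"
enable_chunked_prefill = False

with open("extra_llm_api_options.yaml", "r") as f:
    llm_args_dict = yaml.safe_load(f)

# Mirrors the post-fix logic in get_settings: only keys actually present
# in the YAML replace the computed defaults.
if "kv_cache_dtype" in llm_args_dict:
    kv_cache_dtype = llm_args_dict["kv_cache_dtype"]
enable_chunked_prefill = llm_args_dict.get("enable_chunked_prefill",
                                           enable_chunked_prefill)

print(kv_cache_dtype, enable_chunked_prefill)  # -> fp8 True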

File tree

1 file changed: +12 -14 lines


tensorrt_llm/bench/benchmark/utils/general.py

Lines changed: 12 additions & 14 deletions
@@ -86,17 +86,15 @@ def get_settings(params: dict, dataset_metadata: DatasetMetadata, model: str,
     enable_chunked_prefill = params.get("enable_chunked_prefill", False)
 
     kv_cache_dtype = "auto"
-    cuda_graph_batch_sizes = None
     if extra_llm_api_options:
         with open(extra_llm_api_options, 'r') as f:
             llm_args_dict = yaml.safe_load(f)
-        if "kv_cache_dtype" in llm_args_dict:
-            kv_cache_dtype = llm_args_dict["kv_cache_dtype"]
-        if "cuda_graph_batch_sizes" in llm_args_dict:
-            cuda_graph_batch_sizes = llm_args_dict["cuda_graph_batch_sizes"]
 
-        enable_chunked_prefill = llm_args_dict.get("enable_chunked_prefill",
-                                                   enable_chunked_prefill)
+        if "kv_cache_dtype" in llm_args_dict:
+            kv_cache_dtype = llm_args_dict["kv_cache_dtype"]
+
+        enable_chunked_prefill = llm_args_dict.get("enable_chunked_prefill",
+                                                   enable_chunked_prefill)
 
     world_config = {
         "pp_size": params.get("pp"),

@@ -152,17 +150,17 @@
         # Expecting this to be the max of chunk block and max_num_tokens.
         pass
 
+    cuda_graph_config = {
+        "padding_enabled": True,
+        "max_batch_size": max_batch_size
+    }
+
     pyt_options = {
-        "cuda_graph_config": {
-            "padding_enabled":
-            True,
-            "max_batch_size":
-            max_batch_size if cuda_graph_batch_sizes is None else 0,
-        },
+        "cuda_graph_config": cuda_graph_config,
         "kv_cache_dtype": kv_cache_dtype,
     }
-    backend = params.get("backend", "pytorch")
 
+    backend = params.get("backend", "pytorch")
     return {
         "sw_version": version("tensorrt_llm"),
         "model_path": model_path,
