Skip to content

Commit 68b7900

Browse files
authored
[https://nvbugs/5531963][fix] cherry pick #7725 (#7907)
Signed-off-by: junq <[email protected]>
1 parent bc4136f commit 68b7900

File tree

1 file changed

+10
-9
lines changed

1 file changed

+10
-9
lines changed

tensorrt_llm/bench/dataclasses/configuration.py

Lines changed: 10 additions & 9 deletions
Original file line number · Diff line number · Diff line change
@@ -90,15 +90,16 @@ def get_llm_args(self) -> Dict:
9090
if self.backend == "pytorch":
9191
cuda_graph_config = updated_llm_args.pop(
9292
"cuda_graph_config", llm_args["cuda_graph_config"])
93-
# Use runtime max_batch_size as cuda_graph_config.max_batch_size
94-
# if both max_batch_size and batch_sizes are not set.
95-
batch_sizes_set = cuda_graph_config.get("batch_sizes",
96-
None) is not None
97-
max_batch_size_set = cuda_graph_config.get("max_batch_size",
98-
None) is not None
99-
if not batch_sizes_set and not max_batch_size_set:
100-
cuda_graph_config[
101-
"max_batch_size"] = self.settings_config.max_batch_size
93+
if cuda_graph_config:
94+
# Use runtime max_batch_size as cuda_graph_config.max_batch_size
95+
# if both max_batch_size and batch_sizes are not set.
96+
batch_sizes_set = cuda_graph_config.get("batch_sizes",
97+
None) is not None
98+
max_batch_size_set = cuda_graph_config.get(
99+
"max_batch_size", None) is not None
100+
if not batch_sizes_set and not max_batch_size_set:
101+
cuda_graph_config[
102+
"max_batch_size"] = self.settings_config.max_batch_size
102103
updated_llm_args["cuda_graph_config"] = cuda_graph_config
103104

104105
return updated_llm_args

0 commit comments

Comments (0)