File tree Expand file tree Collapse file tree 1 file changed +10
-9
lines changed
tensorrt_llm/bench/dataclasses Expand file tree Collapse file tree 1 file changed +10
-9
lines changed Original file line number Diff line number Diff line change @@ -90,15 +90,16 @@ def get_llm_args(self) -> Dict:
9090 if self .backend == "pytorch" :
9191 cuda_graph_config = updated_llm_args .pop (
9292 "cuda_graph_config" , llm_args ["cuda_graph_config" ])
93- # Use runtime max_batch_size as cuda_graph_config.max_batch_size
94- # if both max_batch_size and batch_sizes are not set.
95- batch_sizes_set = cuda_graph_config .get ("batch_sizes" ,
96- None ) is not None
97- max_batch_size_set = cuda_graph_config .get ("max_batch_size" ,
98- None ) is not None
99- if not batch_sizes_set and not max_batch_size_set :
100- cuda_graph_config [
101- "max_batch_size" ] = self .settings_config .max_batch_size
93+ if cuda_graph_config :
94+ # Use runtime max_batch_size as cuda_graph_config.max_batch_size
95+ # if both max_batch_size and batch_sizes are not set.
96+ batch_sizes_set = cuda_graph_config .get ("batch_sizes" ,
97+ None ) is not None
98+ max_batch_size_set = cuda_graph_config .get (
99+ "max_batch_size" , None ) is not None
100+ if not batch_sizes_set and not max_batch_size_set :
101+ cuda_graph_config [
102+ "max_batch_size" ] = self .settings_config .max_batch_size
102103 updated_llm_args ["cuda_graph_config" ] = cuda_graph_config
103104
104105 return updated_llm_args
You can’t perform that action at this time.
0 commit comments