@@ -26,10 +26,10 @@ def parse_args():
2626 '--dir' ,
2727 type = str ,
2828 help = 'Directory containing YAML configuration files' )
29- group .add_argument ('--log-dir' ,
30- type = str ,
31- default = None ,
32- help = 'Log directory' )
29+ parser .add_argument ('--log-dir' ,
30+ type = str ,
31+ default = None ,
32+ help = 'Log directory' )
3333 return parser .parse_args ()
3434
3535
@@ -154,16 +154,20 @@ def submit_job(config, log_dir):
154154 {}).get ('num_nextn_predict_layers' , 0 )
155155
156156 # Calculate nodes based on world sizes
157- ctx_tp_size = config ['worker_config' ]['ctx' ]['tensor_parallel_size' ]
158- ctx_cp_size = config ['worker_config' ]['ctx' ]['context_parallel_size' ]
159- ctx_pp_size = config ['worker_config' ]['ctx' ]['pipeline_parallel_size' ]
157+ ctx_tp_size = config ['worker_config' ]['ctx' ].get ('tensor_parallel_size' , 1 )
158+ ctx_cp_size = config ['worker_config' ]['ctx' ].get ('context_parallel_size' , 1 )
159+ ctx_pp_size = config ['worker_config' ]['ctx' ].get ('pipeline_parallel_size' ,
160+ 1 )
160161 ctx_world_size = ctx_tp_size * ctx_cp_size * ctx_pp_size
161162 ctx_nodes = calculate_nodes (ctx_world_size , ctx_num , gpus_per_node )
162- gen_tp_size = config ['worker_config' ]['gen' ]['tensor_parallel_size' ]
163- gen_cp_size = config ['worker_config' ]['gen' ]['context_parallel_size' ]
164- gen_pp_size = config ['worker_config' ]['gen' ]['pipeline_parallel_size' ]
163+
164+ gen_tp_size = config ['worker_config' ]['gen' ].get ('tensor_parallel_size' , 1 )
165+ gen_cp_size = config ['worker_config' ]['gen' ].get ('context_parallel_size' , 1 )
166+ gen_pp_size = config ['worker_config' ]['gen' ].get ('pipeline_parallel_size' ,
167+ 1 )
165168 gen_world_size = gen_tp_size * gen_cp_size * gen_pp_size
166169 gen_nodes = calculate_nodes (gen_world_size , gen_num , gpus_per_node )
170+
167171 total_nodes = ctx_nodes + gen_nodes
168172 total_tasks = total_nodes * gpus_per_node
169173
@@ -259,7 +263,7 @@ def submit_job(config, log_dir):
259263 str (allocation ["port" ]),
260264 config ['benchmark' ]['mode' ],
261265 config ['benchmark' ]['concurrency_list' ],
262- str (slurm_config ['numa_bind' ]),
266+ str (slurm_config ['numa_bind' ]). lower () ,
263267 log_dir ,
264268 str (profiling_config ['nsys_on' ]).lower (),
265269 profiling_config ['gen_profile_range' ]
@@ -303,6 +307,7 @@ def submit_job(config, log_dir):
303307 '--benchmark-ratio' , str (config ['benchmark' ]['benchmark_ratio' ]),
304308 '--streaming' , str (config ['benchmark' ]['streaming' ]).lower (),
305309 '--use-nv-sa-benchmark' , str (config ['benchmark' ]['use_nv_sa_benchmark' ]).lower (),
310+ '--benchmark-mode' , config ['benchmark' ]['mode' ],
306311
307312 # Environment and paths
308313 '--dataset-file' , config ['benchmark' ]['dataset_file' ],
0 commit comments