@@ -203,17 +203,20 @@ def _process_env_vars(self, env_arg: str) -> dict[str, str]:
203203 else :
204204 print (f"WARNING: Could not parse env var: { line } " )
205205 return env_vars
206-
206+
207207 def _engine_check_override (self , params : dict [str , Any ]) -> None :
208208 """Check for engine override in CLI args and warn user.
209209
210210 Parameters
211211 ----------
212212 params : dict[str, Any]
213213 Dictionary of launch parameters to check
214- """
214+ """
215+
215216 def overwrite_engine_args (params : dict [str , Any ]) -> None :
216- engine_args = self ._process_engine_args (self .kwargs [f"{ self .engine } _args" ], self .engine )
217+ engine_args = self ._process_engine_args (
218+ self .kwargs [f"{ self .engine } _args" ], self .engine
219+ )
217220 for key , value in engine_args .items ():
218221 params ["engine_args" ][key ] = value
219222 del self .kwargs [f"{ self .engine } _args" ]
@@ -236,10 +239,9 @@ def overwrite_engine_args(params: dict[str, Any]) -> None:
236239 raise ValueError (
237240 f"Mismatch between provided engine '{ input_engine } ' and engine-specific args '{ extracted_engine } '"
238241 )
239- else :
240- self .engine = input_engine
241- params ["engine_args" ] = params [f"{ self .engine } _args" ]
242- overwrite_engine_args (params )
242+ self .engine = input_engine
243+ params ["engine_args" ] = params [f"{ self .engine } _args" ]
244+ overwrite_engine_args (params )
243245 elif input_engine :
244246 # Only engine arg in CLI, use default engine args from config
245247 self .engine = input_engine
@@ -255,8 +257,7 @@ def overwrite_engine_args(params: dict[str, Any]) -> None:
255257 self .engine = params .get ("engine" , "vllm" )
256258 params ["engine_args" ] = params [f"{ self .engine } _args" ]
257259
258- # Remove $ENGINE_NAME_args from params as we no longer need them, and they don't get
259- # populated to the job json.
260+ # Remove $ENGINE_NAME_args from params as they won't get populated to sjob json.
260261 for engine in SUPPORTED_ENGINES :
261262 del params [f"{ engine } _args" ]
262263
@@ -267,9 +268,9 @@ def _apply_cli_overrides(self, params: dict[str, Any]) -> None:
267268 ----------
268269 params : dict[str, Any]
269270 Dictionary of launch parameters to override
270- """
271+ """
271272 self ._engine_check_override (params )
272-
273+
273274 if self .kwargs .get ("env" ):
274275 env_vars = self ._process_env_vars (self .kwargs ["env" ])
275276 for key , value in env_vars .items ():
@@ -513,6 +514,53 @@ def _get_model_configurations(self) -> dict[str, ModelConfig]:
513514
514515 return model_configs_dict
515516
517+ def _validate_resource_and_parallel_settings (
518+ self ,
519+ config : ModelConfig ,
520+ model_engine_args : dict [str , Any ] | None ,
521+ model_name : str ,
522+ ) -> None :
523+ """Validate resource allocation and parallelization settings for each model.
524+
525+ Parameters
526+ ----------
527+ config : ModelConfig
528+ Configuration of the model to validate
529+ model_engine_args : dict[str, Any] | None
530+ Inference engine arguments of the model to validate
531+ model_name : str
532+ Name of the model to validate
533+
534+ Raises
535+ ------
536+ MissingRequiredFieldsError
537+ If tensor parallel size is not specified when using multiple GPUs
538+ ValueError
539+ If total # of GPUs requested is not a power of two
540+ If mismatch between total # of GPUs requested and parallelization settings
541+ """
542+ if (
543+ int (config .gpus_per_node ) > 1
544+ and (model_engine_args or {}).get ("--tensor-parallel-size" ) is None
545+ ):
546+ raise MissingRequiredFieldsError (
547+ f"--tensor-parallel-size is required when gpus_per_node > 1, check your configuration for { model_name } "
548+ )
549+
550+ total_gpus_requested = int (config .gpus_per_node ) * int (config .num_nodes )
551+ if not utils .is_power_of_two (total_gpus_requested ):
552+ raise ValueError (
553+ f"Total number of GPUs requested must be a power of two, check your configuration for { model_name } "
554+ )
555+
556+ total_parallel_sizes = int (
557+ (model_engine_args or {}).get ("--tensor-parallel-size" , "1" )
558+ ) * int ((model_engine_args or {}).get ("--pipeline-parallel-size" , "1" ))
559+ if total_gpus_requested != total_parallel_sizes :
560+ raise ValueError (
561+ f"Mismatch between total number of GPUs requested and parallelization settings, check your configuration for { model_name } "
562+ )
563+
516564 def _get_launch_params (
517565 self , account : Optional [str ] = None , work_dir : Optional [str ] = None
518566 ) -> dict [str , Any ]:
@@ -549,27 +597,9 @@ def _get_launch_params(
549597 del params ["models" ][model_name ][f"{ engine } _args" ]
550598
551599 # Validate resource allocation and parallelization settings
552- if (
553- int (config .gpus_per_node ) > 1
554- and (model_engine_args or {}).get ("--tensor-parallel-size" ) is None
555- ):
556- raise MissingRequiredFieldsError (
557- f"--tensor-parallel-size is required when gpus_per_node > 1, check your configuration for { model_name } "
558- )
559-
560- total_gpus_requested = int (config .gpus_per_node ) * int (config .num_nodes )
561- if not utils .is_power_of_two (total_gpus_requested ):
562- raise ValueError (
563- f"Total number of GPUs requested must be a power of two, check your configuration for { model_name } "
564- )
565-
566- total_parallel_sizes = int (
567- (model_engine_args or {}).get ("--tensor-parallel-size" , "1" )
568- ) * int ((model_engine_args or {}).get ("--pipeline-parallel-size" , "1" ))
569- if total_gpus_requested != total_parallel_sizes :
570- raise ValueError (
571- f"Mismatch between total number of GPUs requested and parallelization settings, check your configuration for { model_name } "
572- )
600+ self ._validate_resource_and_parallel_settings (
601+ config , model_engine_args , model_name
602+ )
573603
574604 # Convert gpus_per_node and resource_type to gres
575605 params ["models" ][model_name ]["gres" ] = (
0 commit comments