Commit 07bf6c7: Ruff and mypy fixes
Parent: 2baf7ed

4 files changed: +67, -37 lines

vec_inf/cli/_helper.py

Lines changed: 3 additions & 2 deletions
@@ -188,7 +188,8 @@ def format_table_output(self) -> Table:
             "Memory/Node", f" {self.params['models'][model_name]['mem_per_node']}"
         )
         table.add_row(
-            "Inference Engine", f" {ENGINE_NAME_MAP[self.params['models'][model_name]['engine']]}"
+            "Inference Engine",
+            f" {ENGINE_NAME_MAP[self.params['models'][model_name]['engine']]}",
         )

         return table
@@ -483,7 +484,7 @@ def _format_single_model_output(self, config: ModelConfig) -> Union[str, Table]:
                 config_dict["model_weights_parent_dir"]
             )
             return json.dumps(config_dict, indent=4)
-
+
         excluded_list = ["venv", "log_dir"]

         table = create_table(key_title="Model Config", value_title="Value")

vec_inf/client/_helper.py

Lines changed: 62 additions & 32 deletions
@@ -203,17 +203,20 @@ def _process_env_vars(self, env_arg: str) -> dict[str, str]:
             else:
                 print(f"WARNING: Could not parse env var: {line}")
         return env_vars
-
+
     def _engine_check_override(self, params: dict[str, Any]) -> None:
         """Check for engine override in CLI args and warn user.

         Parameters
         ----------
         params : dict[str, Any]
             Dictionary of launch parameters to check
-        """
+        """
+
         def overwrite_engine_args(params: dict[str, Any]) -> None:
-            engine_args = self._process_engine_args(self.kwargs[f"{self.engine}_args"], self.engine)
+            engine_args = self._process_engine_args(
+                self.kwargs[f"{self.engine}_args"], self.engine
+            )
             for key, value in engine_args.items():
                 params["engine_args"][key] = value
             del self.kwargs[f"{self.engine}_args"]
@@ -236,10 +239,9 @@ def overwrite_engine_args(params: dict[str, Any]) -> None:
                 raise ValueError(
                     f"Mismatch between provided engine '{input_engine}' and engine-specific args '{extracted_engine}'"
                 )
-            else:
-                self.engine = input_engine
-                params["engine_args"] = params[f"{self.engine}_args"]
-                overwrite_engine_args(params)
+            self.engine = input_engine
+            params["engine_args"] = params[f"{self.engine}_args"]
+            overwrite_engine_args(params)
         elif input_engine:
             # Only engine arg in CLI, use default engine args from config
             self.engine = input_engine
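The dropped else branch above is the kind of cleanup ruff's flake8-return rules flag (most likely RET506, unnecessary else after raise): once the if branch raises, the else only adds nesting. A minimal standalone illustration of the pattern, using hypothetical names rather than vec-inf code:

def resolve_engine(cli_engine: str, args_engine: str) -> str:
    """Pick an engine name, rejecting conflicting inputs (illustrative only)."""
    if cli_engine and args_engine and cli_engine != args_engine:
        raise ValueError(f"Engine mismatch: {cli_engine!r} vs {args_engine!r}")
    # No `else:` needed; this line is only reached when nothing was raised.
    return cli_engine or args_engine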
@@ -255,8 +257,7 @@ def overwrite_engine_args(params: dict[str, Any]) -> None:
             self.engine = params.get("engine", "vllm")
             params["engine_args"] = params[f"{self.engine}_args"]

-        # Remove $ENGINE_NAME_args from params as we no longer need them, and they don't get
-        # populated to the job json.
+        # Remove $ENGINE_NAME_args from params as they won't get populated to sjob json.
         for engine in SUPPORTED_ENGINES:
             del params[f"{engine}_args"]

@@ -267,9 +268,9 @@ def _apply_cli_overrides(self, params: dict[str, Any]) -> None:
         ----------
         params : dict[str, Any]
             Dictionary of launch parameters to override
-        """
+        """
         self._engine_check_override(params)
-
+
         if self.kwargs.get("env"):
             env_vars = self._process_env_vars(self.kwargs["env"])
             for key, value in env_vars.items():
@@ -513,6 +514,53 @@ def _get_model_configurations(self) -> dict[str, ModelConfig]:

         return model_configs_dict

+    def _validate_resource_and_parallel_settings(
+        self,
+        config: ModelConfig,
+        model_engine_args: dict[str, Any] | None,
+        model_name: str,
+    ) -> None:
+        """Validate resource allocation and parallelization settings for each model.
+
+        Parameters
+        ----------
+        config : ModelConfig
+            Configuration of the model to validate
+        model_engine_args : dict[str, Any] | None
+            Inference engine arguments of the model to validate
+        model_name : str
+            Name of the model to validate
+
+        Raises
+        ------
+        MissingRequiredFieldsError
+            If tensor parallel size is not specified when using multiple GPUs
+        ValueError
+            If total # of GPUs requested is not a power of two
+            If mismatch between total # of GPUs requested and parallelization settings
+        """
+        if (
+            int(config.gpus_per_node) > 1
+            and (model_engine_args or {}).get("--tensor-parallel-size") is None
+        ):
+            raise MissingRequiredFieldsError(
+                f"--tensor-parallel-size is required when gpus_per_node > 1, check your configuration for {model_name}"
+            )
+
+        total_gpus_requested = int(config.gpus_per_node) * int(config.num_nodes)
+        if not utils.is_power_of_two(total_gpus_requested):
+            raise ValueError(
+                f"Total number of GPUs requested must be a power of two, check your configuration for {model_name}"
+            )
+
+        total_parallel_sizes = int(
+            (model_engine_args or {}).get("--tensor-parallel-size", "1")
+        ) * int((model_engine_args or {}).get("--pipeline-parallel-size", "1"))
+        if total_gpus_requested != total_parallel_sizes:
+            raise ValueError(
+                f"Mismatch between total number of GPUs requested and parallelization settings, check your configuration for {model_name}"
+            )
+
     def _get_launch_params(
         self, account: Optional[str] = None, work_dir: Optional[str] = None
     ) -> dict[str, Any]:
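The new validator relies on utils.is_power_of_two, which is outside this diff. Below is a minimal sketch of the usual bit-trick implementation together with the arithmetic the check enforces; the helper body and the sample numbers are assumptions for illustration, not vec-inf's actual code.

# Sketch only: assumes vec_inf's utils.is_power_of_two behaves like the
# standard bit trick (a power of two has exactly one bit set).
def is_power_of_two(n: int) -> bool:
    return n > 0 and (n & (n - 1)) == 0


# Hypothetical model config: 4 GPUs per node on 2 nodes, TP=4, PP=2.
gpus_per_node, num_nodes = 4, 2
tensor_parallel_size, pipeline_parallel_size = 4, 2

total_gpus_requested = gpus_per_node * num_nodes  # 8
assert is_power_of_two(total_gpus_requested)
# The validator also requires total GPUs == tensor parallel * pipeline parallel.
assert total_gpus_requested == tensor_parallel_size * pipeline_parallel_size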
@@ -549,27 +597,9 @@ def _get_launch_params(
             del params["models"][model_name][f"{engine}_args"]

             # Validate resource allocation and parallelization settings
-            if (
-                int(config.gpus_per_node) > 1
-                and (model_engine_args or {}).get("--tensor-parallel-size") is None
-            ):
-                raise MissingRequiredFieldsError(
-                    f"--tensor-parallel-size is required when gpus_per_node > 1, check your configuration for {model_name}"
-                )
-
-            total_gpus_requested = int(config.gpus_per_node) * int(config.num_nodes)
-            if not utils.is_power_of_two(total_gpus_requested):
-                raise ValueError(
-                    f"Total number of GPUs requested must be a power of two, check your configuration for {model_name}"
-                )
-
-            total_parallel_sizes = int(
-                (model_engine_args or {}).get("--tensor-parallel-size", "1")
-            ) * int((model_engine_args or {}).get("--pipeline-parallel-size", "1"))
-            if total_gpus_requested != total_parallel_sizes:
-                raise ValueError(
-                    f"Mismatch between total number of GPUs requested and parallelization settings, check your configuration for {model_name}"
-                )
+            self._validate_resource_and_parallel_settings(
+                config, model_engine_args, model_name
+            )

             # Convert gpus_per_node and resource_type to gres
             params["models"][model_name]["gres"] = (

vec_inf/client/_slurm_templates.py

Lines changed: 0 additions & 1 deletion
@@ -175,7 +175,6 @@ class SlurmScriptTemplate(TypedDict):
     'nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST")',
     "nodes_array=($nodes)",
     "head_node=${nodes_array[0]}",
-    # 'head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)',
     "NCCL_PORT=$(find_available_port $head_node 8000 65535)",
     'NCCL_INIT_ADDR="${head_node}:${NCCL_PORT}"',
     'echo "[INFO] NCCL_INIT_ADDR: $NCCL_INIT_ADDR"',

vec_inf/client/_utils.py

Lines changed: 2 additions & 2 deletions
@@ -77,7 +77,7 @@ def read_slurm_log(
             json_content: dict[str, str] = json.load(file)
             return json_content
         else:
-            with file_path.open("r", errors='replace') as file:
+            with file_path.open("r", errors="replace") as file:
                 return file.readlines()
     except FileNotFoundError:
         return f"LOG FILE NOT FOUND: {file_path}"
@@ -249,7 +249,7 @@ def load_config(config_path: Optional[str] = None) -> list[ModelConfig]:
     -----
     Configuration is loaded from:
     1. User path: specified by config_path
-    2. Default path: package's config/models.yaml or CACHED_MODEL_CONFIG_PATH if it exists
+    2. Default path: package's config/models.yaml or CACHED_MODEL_CONFIG_PATH if exists
     3. Environment variable: specified by VEC_INF_CONFIG environment variable
        and merged with default config

