77import pytest
88import requests
99
10+ from vec_inf .client ._client_vars import SRC_DIR
1011from vec_inf .client ._exceptions import (
1112 MissingRequiredFieldsError ,
1213 ModelConfigurationError ,
1314 ModelNotFoundError ,
1415 SlurmJobError ,
1516)
1617from vec_inf .client ._helper import (
18+ BatchModelLauncher ,
1719 ModelLauncher ,
1820 ModelRegistry ,
1921 ModelStatusMonitor ,
2022 PerformanceMetricsCollector ,
21- BatchModelLauncher ,
2223)
2324from vec_inf .client .config import ModelConfig
2425from vec_inf .client .models import (
2526 ModelStatus ,
2627 ModelType ,
2728 StatusResponse ,
2829)
29- from vec_inf .client ._client_vars import SRC_DIR
3030
3131
3232class TestModelLauncher :
@@ -197,6 +197,7 @@ def test_launch_with_slurm_error(
197197 with pytest .raises (SlurmJobError ):
198198 launcher .launch ()
199199
200+
200201class TestBatchModelLauncher :
201202 """Tests for the BatchModelLauncher class."""
202203
@@ -252,7 +253,9 @@ def test_init_with_valid_configs(self, mock_load_config, batch_model_configs):
252253 assert "family2-variant1" in launcher .model_configs
253254
254255 @patch ("vec_inf.client._helper.utils.load_config" )
255- def test_init_with_missing_model_config (self , mock_load_config , batch_model_configs ):
256+ def test_init_with_missing_model_config (
257+ self , mock_load_config , batch_model_configs
258+ ):
256259 """Test error is raised when one of the models is missing from config."""
257260 mock_load_config .return_value = batch_model_configs
258261
@@ -266,9 +269,14 @@ def test_init_with_missing_model_config(self, mock_load_config, batch_model_conf
266269 def test_get_slurm_job_name (self , mock_load_config , batch_model_configs ):
267270 """Test SLURM job name is constructed correctly from model names."""
268271 mock_load_config .return_value = batch_model_configs
269- launcher = BatchModelLauncher (["family1-variant1" , "family2-variant1" , "family1-variant2" ])
272+ launcher = BatchModelLauncher (
273+ ["family1-variant1" , "family2-variant1" , "family1-variant2" ]
274+ )
270275
271- assert launcher .slurm_job_name == "BATCH-family1-variant1-family2-variant1-family1-variant2"
276+ assert (
277+ launcher .slurm_job_name
278+ == "BATCH-family1-variant1-family2-variant1-family1-variant2"
279+ )
272280
273281 @patch ("vec_inf.client._helper.utils.load_config" )
274282 @patch ("pathlib.Path.mkdir" )
@@ -278,14 +286,19 @@ def test_get_launch_params_creates_log_dirs(
278286 """Test launch parameters preparation creates log directories."""
279287 mock_load_config .return_value = batch_model_configs
280288
281- launcher = BatchModelLauncher (["family1-variant1" , "family2-variant1" , "family1-variant2" ])
289+ launcher = BatchModelLauncher (
290+ ["family1-variant1" , "family2-variant1" , "family1-variant2" ]
291+ )
282292 params = launcher .params
283293
284294 assert "models" in params
285295 assert "family1-variant1" in params ["models" ]
286296 assert "family2-variant1" in params ["models" ]
287297 assert "family1-variant2" in params ["models" ]
288- assert params ["slurm_job_name" ] == "BATCH-family1-variant1-family2-variant1-family1-variant2"
298+ assert (
299+ params ["slurm_job_name" ]
300+ == "BATCH-family1-variant1-family2-variant1-family1-variant2"
301+ )
289302 assert params ["src_dir" ] == str (SRC_DIR )
290303
291304 # Check that log directories are created
@@ -318,7 +331,7 @@ def test_get_launch_params_with_non_power_of_two_gpus(
318331 batch_model_configs [0 ].model_copy (
319332 update = {
320333 "gpus_per_node" : 3 ,
321- "vllm_args" : {"--tensor-parallel-size" : "3" }
334+ "vllm_args" : {"--tensor-parallel-size" : "3" },
322335 }
323336 ),
324337 batch_model_configs [1 ],
@@ -345,7 +358,7 @@ def test_get_launch_params_with_mismatched_batch_args(
345358 update = {
346359 "gpus_per_node" : 1 ,
347360 "num_nodes" : 2 , # This will cause the mismatch
348- "vllm_args" : {"--tensor-parallel-size" : "1" }
361+ "vllm_args" : {"--tensor-parallel-size" : "1" },
349362 }
350363 ),
351364 ]
@@ -450,13 +463,31 @@ def test_launch_params_log_file_paths(self, mock_load_config, batch_model_config
450463 params = launcher .params
451464
452465 # Check individual model log files
453- assert "family1-variant1.%j.out" in params ["models" ]["family1-variant1" ]["out_file" ]
454- assert "family1-variant1.%j.err" in params ["models" ]["family1-variant1" ]["err_file" ]
455- assert "family1-variant1.$SLURM_JOB_ID.json" in params ["models" ]["family1-variant1" ]["json_file" ]
466+ assert (
467+ "family1-variant1.%j.out"
468+ in params ["models" ]["family1-variant1" ]["out_file" ]
469+ )
470+ assert (
471+ "family1-variant1.%j.err"
472+ in params ["models" ]["family1-variant1" ]["err_file" ]
473+ )
474+ assert (
475+ "family1-variant1.$SLURM_JOB_ID.json"
476+ in params ["models" ]["family1-variant1" ]["json_file" ]
477+ )
456478
457- assert "family2-variant1.%j.out" in params ["models" ]["family2-variant1" ]["out_file" ]
458- assert "family2-variant1.%j.err" in params ["models" ]["family2-variant1" ]["err_file" ]
459- assert "family2-variant1.$SLURM_JOB_ID.json" in params ["models" ]["family2-variant1" ]["json_file" ]
479+ assert (
480+ "family2-variant1.%j.out"
481+ in params ["models" ]["family2-variant1" ]["out_file" ]
482+ )
483+ assert (
484+ "family2-variant1.%j.err"
485+ in params ["models" ]["family2-variant1" ]["err_file" ]
486+ )
487+ assert (
488+ "family2-variant1.$SLURM_JOB_ID.json"
489+ in params ["models" ]["family2-variant1" ]["json_file" ]
490+ )
460491
461492 # Check batch-level log files
462493 assert "BATCH-family1-variant1-family2-variant1.%j.out" in params ["out_file" ]
@@ -467,7 +498,9 @@ def test_init_with_batch_config(self, mock_load_config, batch_model_configs):
467498 """Test launcher initializes correctly with custom batch config."""
468499 mock_load_config .return_value = batch_model_configs
469500
470- launcher = BatchModelLauncher (["family1-variant1" , "family2-variant1" ], batch_config = "custom_config.yaml" )
501+ launcher = BatchModelLauncher (
502+ ["family1-variant1" , "family2-variant1" ], batch_config = "custom_config.yaml"
503+ )
471504
472505 assert launcher .batch_config == "custom_config.yaml"
473506 # Verify load_config was called with the custom config
@@ -855,4 +888,4 @@ def test_get_single_model_config_not_found(self, mock_load_config, mock_configs)
855888 registry = ModelRegistry ()
856889
857890 with pytest .raises (ModelNotFoundError ):
858- registry .get_single_model_config ("nonexistent_model" )
891+ registry .get_single_model_config ("nonexistent_model" )
0 commit comments