Skip to content

Commit 1e8b61a

Browse files
[pre-commit.ci] Add auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 70da2db commit 1e8b61a

File tree

3 files changed

+125
-44
lines changed

3 files changed

+125
-44
lines changed

tests/vec_inf/client/test_api.py

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from vec_inf.client._exceptions import (
99
ModelConfigurationError,
1010
ServerError,
11-
SlurmJobError
11+
SlurmJobError,
1212
)
1313

1414

@@ -324,8 +324,7 @@ def test_batch_launch_models_with_config():
324324
client.batch_launch_models = lambda model_names, batch_config=None: mock_response
325325

326326
result = client.batch_launch_models(
327-
["model1", "model2"],
328-
batch_config="custom_config.yaml"
327+
["model1", "model2"], batch_config="custom_config.yaml"
329328
)
330329

331330
assert result.slurm_job_id == "12345678"
@@ -420,11 +419,16 @@ def test_batch_launch_models_configuration_error():
420419

421420
# Mock the batch launch method to raise a configuration error
422421
def mock_batch_launch(model_names, batch_config=None):
423-
raise ModelConfigurationError("Model 'nonexistent-model' not found in configuration")
422+
raise ModelConfigurationError(
423+
"Model 'nonexistent-model' not found in configuration"
424+
)
424425

425426
client.batch_launch_models = mock_batch_launch
426427

427-
with pytest.raises(ModelConfigurationError, match="Model 'nonexistent-model' not found in configuration"):
428+
with pytest.raises(
429+
ModelConfigurationError,
430+
match="Model 'nonexistent-model' not found in configuration",
431+
):
428432
client.batch_launch_models(["model1", "nonexistent-model"])
429433

430434

@@ -438,7 +442,9 @@ def mock_batch_launch(model_names, batch_config=None):
438442

439443
client.batch_launch_models = mock_batch_launch
440444

441-
with pytest.raises(SlurmJobError, match="sbatch: error: Invalid partition specified"):
445+
with pytest.raises(
446+
SlurmJobError, match="sbatch: error: Invalid partition specified"
447+
):
442448
client.batch_launch_models(["model1", "model2"])
443449

444450

@@ -448,15 +454,18 @@ def test_batch_launch_models_integration():
448454

449455
with (
450456
patch("vec_inf.client.api.BatchModelLauncher") as mock_launcher_class,
451-
patch("vec_inf.client.api.run_bash_command", return_value=("Submitted batch job 12345678", ""))
457+
patch(
458+
"vec_inf.client.api.run_bash_command",
459+
return_value=("Submitted batch job 12345678", ""),
460+
),
452461
):
453462
# Mock the BatchModelLauncher instance
454463
mock_launcher = MagicMock()
455464
mock_launcher.launch.return_value = MagicMock(
456465
slurm_job_id="12345678",
457466
slurm_job_name="BATCH-model1-model2",
458467
model_names=["model1", "model2"],
459-
config={"slurm_job_id": "12345678"}
468+
config={"slurm_job_id": "12345678"},
460469
)
461470
mock_launcher_class.return_value = mock_launcher
462471

@@ -478,25 +487,29 @@ def test_batch_launch_models_with_custom_config_integration():
478487

479488
with (
480489
patch("vec_inf.client.api.BatchModelLauncher") as mock_launcher_class,
481-
patch("vec_inf.client.api.run_bash_command", return_value=("Submitted batch job 12345678", ""))
490+
patch(
491+
"vec_inf.client.api.run_bash_command",
492+
return_value=("Submitted batch job 12345678", ""),
493+
),
482494
):
483495
# Mock the BatchModelLauncher instance
484496
mock_launcher = MagicMock()
485497
mock_launcher.launch.return_value = MagicMock(
486498
slurm_job_id="12345678",
487499
slurm_job_name="BATCH-model1-model2",
488500
model_names=["model1", "model2"],
489-
config={"slurm_job_id": "12345678"}
501+
config={"slurm_job_id": "12345678"},
490502
)
491503
mock_launcher_class.return_value = mock_launcher
492504

493505
result = client.batch_launch_models(
494-
["model1", "model2"],
495-
batch_config="custom_config.yaml"
506+
["model1", "model2"], batch_config="custom_config.yaml"
496507
)
497508

498509
# Verify BatchModelLauncher was called with custom config
499-
mock_launcher_class.assert_called_once_with(["model1", "model2"], "custom_config.yaml")
510+
mock_launcher_class.assert_called_once_with(
511+
["model1", "model2"], "custom_config.yaml"
512+
)
500513
mock_launcher.launch.assert_called_once()
501514

502515
# Verify the response

tests/vec_inf/client/test_helper.py

Lines changed: 50 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,26 +7,26 @@
77
import pytest
88
import requests
99

10+
from vec_inf.client._client_vars import SRC_DIR
1011
from vec_inf.client._exceptions import (
1112
MissingRequiredFieldsError,
1213
ModelConfigurationError,
1314
ModelNotFoundError,
1415
SlurmJobError,
1516
)
1617
from vec_inf.client._helper import (
18+
BatchModelLauncher,
1719
ModelLauncher,
1820
ModelRegistry,
1921
ModelStatusMonitor,
2022
PerformanceMetricsCollector,
21-
BatchModelLauncher,
2223
)
2324
from vec_inf.client.config import ModelConfig
2425
from vec_inf.client.models import (
2526
ModelStatus,
2627
ModelType,
2728
StatusResponse,
2829
)
29-
from vec_inf.client._client_vars import SRC_DIR
3030

3131

3232
class TestModelLauncher:
@@ -197,6 +197,7 @@ def test_launch_with_slurm_error(
197197
with pytest.raises(SlurmJobError):
198198
launcher.launch()
199199

200+
200201
class TestBatchModelLauncher:
201202
"""Tests for the BatchModelLauncher class."""
202203

@@ -252,7 +253,9 @@ def test_init_with_valid_configs(self, mock_load_config, batch_model_configs):
252253
assert "family2-variant1" in launcher.model_configs
253254

254255
@patch("vec_inf.client._helper.utils.load_config")
255-
def test_init_with_missing_model_config(self, mock_load_config, batch_model_configs):
256+
def test_init_with_missing_model_config(
257+
self, mock_load_config, batch_model_configs
258+
):
256259
"""Test error is raised when one of the models is missing from config."""
257260
mock_load_config.return_value = batch_model_configs
258261

@@ -266,9 +269,14 @@ def test_init_with_missing_model_config(self, mock_load_config, batch_model_conf
266269
def test_get_slurm_job_name(self, mock_load_config, batch_model_configs):
267270
"""Test SLURM job name is constructed correctly from model names."""
268271
mock_load_config.return_value = batch_model_configs
269-
launcher = BatchModelLauncher(["family1-variant1", "family2-variant1", "family1-variant2"])
272+
launcher = BatchModelLauncher(
273+
["family1-variant1", "family2-variant1", "family1-variant2"]
274+
)
270275

271-
assert launcher.slurm_job_name == "BATCH-family1-variant1-family2-variant1-family1-variant2"
276+
assert (
277+
launcher.slurm_job_name
278+
== "BATCH-family1-variant1-family2-variant1-family1-variant2"
279+
)
272280

273281
@patch("vec_inf.client._helper.utils.load_config")
274282
@patch("pathlib.Path.mkdir")
@@ -278,14 +286,19 @@ def test_get_launch_params_creates_log_dirs(
278286
"""Test launch parameters preparation creates log directories."""
279287
mock_load_config.return_value = batch_model_configs
280288

281-
launcher = BatchModelLauncher(["family1-variant1", "family2-variant1", "family1-variant2"])
289+
launcher = BatchModelLauncher(
290+
["family1-variant1", "family2-variant1", "family1-variant2"]
291+
)
282292
params = launcher.params
283293

284294
assert "models" in params
285295
assert "family1-variant1" in params["models"]
286296
assert "family2-variant1" in params["models"]
287297
assert "family1-variant2" in params["models"]
288-
assert params["slurm_job_name"] == "BATCH-family1-variant1-family2-variant1-family1-variant2"
298+
assert (
299+
params["slurm_job_name"]
300+
== "BATCH-family1-variant1-family2-variant1-family1-variant2"
301+
)
289302
assert params["src_dir"] == str(SRC_DIR)
290303

291304
# Check that log directories are created
@@ -318,7 +331,7 @@ def test_get_launch_params_with_non_power_of_two_gpus(
318331
batch_model_configs[0].model_copy(
319332
update={
320333
"gpus_per_node": 3,
321-
"vllm_args": {"--tensor-parallel-size": "3"}
334+
"vllm_args": {"--tensor-parallel-size": "3"},
322335
}
323336
),
324337
batch_model_configs[1],
@@ -345,7 +358,7 @@ def test_get_launch_params_with_mismatched_batch_args(
345358
update={
346359
"gpus_per_node": 1,
347360
"num_nodes": 2, # This will cause the mismatch
348-
"vllm_args": {"--tensor-parallel-size": "1"}
361+
"vllm_args": {"--tensor-parallel-size": "1"},
349362
}
350363
),
351364
]
@@ -450,13 +463,31 @@ def test_launch_params_log_file_paths(self, mock_load_config, batch_model_config
450463
params = launcher.params
451464

452465
# Check individual model log files
453-
assert "family1-variant1.%j.out" in params["models"]["family1-variant1"]["out_file"]
454-
assert "family1-variant1.%j.err" in params["models"]["family1-variant1"]["err_file"]
455-
assert "family1-variant1.$SLURM_JOB_ID.json" in params["models"]["family1-variant1"]["json_file"]
466+
assert (
467+
"family1-variant1.%j.out"
468+
in params["models"]["family1-variant1"]["out_file"]
469+
)
470+
assert (
471+
"family1-variant1.%j.err"
472+
in params["models"]["family1-variant1"]["err_file"]
473+
)
474+
assert (
475+
"family1-variant1.$SLURM_JOB_ID.json"
476+
in params["models"]["family1-variant1"]["json_file"]
477+
)
456478

457-
assert "family2-variant1.%j.out" in params["models"]["family2-variant1"]["out_file"]
458-
assert "family2-variant1.%j.err" in params["models"]["family2-variant1"]["err_file"]
459-
assert "family2-variant1.$SLURM_JOB_ID.json" in params["models"]["family2-variant1"]["json_file"]
479+
assert (
480+
"family2-variant1.%j.out"
481+
in params["models"]["family2-variant1"]["out_file"]
482+
)
483+
assert (
484+
"family2-variant1.%j.err"
485+
in params["models"]["family2-variant1"]["err_file"]
486+
)
487+
assert (
488+
"family2-variant1.$SLURM_JOB_ID.json"
489+
in params["models"]["family2-variant1"]["json_file"]
490+
)
460491

461492
# Check batch-level log files
462493
assert "BATCH-family1-variant1-family2-variant1.%j.out" in params["out_file"]
@@ -467,7 +498,9 @@ def test_init_with_batch_config(self, mock_load_config, batch_model_configs):
467498
"""Test launcher initializes correctly with custom batch config."""
468499
mock_load_config.return_value = batch_model_configs
469500

470-
launcher = BatchModelLauncher(["family1-variant1", "family2-variant1"], batch_config="custom_config.yaml")
501+
launcher = BatchModelLauncher(
502+
["family1-variant1", "family2-variant1"], batch_config="custom_config.yaml"
503+
)
471504

472505
assert launcher.batch_config == "custom_config.yaml"
473506
# Verify load_config was called with the custom config
@@ -855,4 +888,4 @@ def test_get_single_model_config_not_found(self, mock_load_config, mock_configs)
855888
registry = ModelRegistry()
856889

857890
with pytest.raises(ModelNotFoundError):
858-
registry.get_single_model_config("nonexistent_model")
891+
registry.get_single_model_config("nonexistent_model")

0 commit comments

Comments
 (0)