diff --git a/mlos_bench/mlos_bench/tests/config/environments/test_load_environment_config_examples.py b/mlos_bench/mlos_bench/tests/config/environments/test_load_environment_config_examples.py index 889462f024d..4b35282e53c 100644 --- a/mlos_bench/mlos_bench/tests/config/environments/test_load_environment_config_examples.py +++ b/mlos_bench/mlos_bench/tests/config/environments/test_load_environment_config_examples.py @@ -11,7 +11,7 @@ from mlos_bench.environments.base_environment import Environment from mlos_bench.environments.composite_env import CompositeEnv from mlos_bench.services.config_persistence import ConfigPersistenceService -from mlos_bench.tests.config import locate_config_examples +from mlos_bench.tests.config import BUILTIN_TEST_CONFIG_PATH, locate_config_examples from mlos_bench.tunables.tunable_groups import TunableGroups _LOG = logging.getLogger(__name__) @@ -39,6 +39,14 @@ def filter_configs(configs_to_filter: list[str]) -> list[str]: ) assert configs +test_configs = locate_config_examples( + BUILTIN_TEST_CONFIG_PATH, + CONFIG_TYPE, + filter_configs, +) +assert test_configs +configs.extend(test_configs) + @pytest.mark.parametrize("config_path", configs) def test_load_environment_config_examples( diff --git a/mlos_bench/mlos_bench/tests/config/experiments/experiment_test_config.jsonc b/mlos_bench/mlos_bench/tests/config/experiments/experiment_test_config.jsonc index 2ca87c6f215..c6f98c4963a 100644 --- a/mlos_bench/mlos_bench/tests/config/experiments/experiment_test_config.jsonc +++ b/mlos_bench/mlos_bench/tests/config/experiments/experiment_test_config.jsonc @@ -15,6 +15,10 @@ "resourceGroup": "mlos-autotuning-test-rg", "location": "eastus", "vmName": "vmTestName", + "ssh_username": "testuser", + "ssh_priv_key_path": "/home/testuser/.ssh/id_rsa", + "ssh_hostname": "${vmName}", + "ssh_port": 22, "tunable_params_map": { "linux-runtime": ["linux-scheduler", "linux-swap"], "linux-boot": ["linux-kernel-boot"], diff --git a/mlos_bench/mlos_bench/tests/config/optimizers/test_load_optimizer_config_examples.py b/mlos_bench/mlos_bench/tests/config/optimizers/test_load_optimizer_config_examples.py index fceecd89f0d..a407275438b 100644 --- a/mlos_bench/mlos_bench/tests/config/optimizers/test_load_optimizer_config_examples.py +++ b/mlos_bench/mlos_bench/tests/config/optimizers/test_load_optimizer_config_examples.py @@ -10,7 +10,7 @@ from mlos_bench.config.schemas import ConfigSchema from mlos_bench.optimizers.base_optimizer import Optimizer from mlos_bench.services.config_persistence import ConfigPersistenceService -from mlos_bench.tests.config import locate_config_examples +from mlos_bench.tests.config import BUILTIN_TEST_CONFIG_PATH, locate_config_examples from mlos_bench.tunables.tunable_groups import TunableGroups from mlos_bench.util import get_class_from_name @@ -34,6 +34,14 @@ def filter_configs(configs_to_filter: list[str]) -> list[str]: ) assert configs +test_configs = locate_config_examples( + BUILTIN_TEST_CONFIG_PATH, + CONFIG_TYPE, + filter_configs, +) +# assert test_configs +configs.extend(test_configs) + @pytest.mark.parametrize("config_path", configs) def test_load_optimizer_config_examples( diff --git a/mlos_bench/mlos_bench/tests/config/schedulers/__init__.py b/mlos_bench/mlos_bench/tests/config/schedulers/__init__.py new file mode 100644 index 00000000000..111238e6ac9 --- /dev/null +++ b/mlos_bench/mlos_bench/tests/config/schedulers/__init__.py @@ -0,0 +1,5 @@ +# +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+# +"""Unit tests for the mlos_bench Scheduler configs.""" diff --git a/mlos_bench/mlos_bench/tests/config/schedulers/conftest.py b/mlos_bench/mlos_bench/tests/config/schedulers/conftest.py new file mode 100644 index 00000000000..71368400561 --- /dev/null +++ b/mlos_bench/mlos_bench/tests/config/schedulers/conftest.py @@ -0,0 +1,57 @@ +# +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# +""" +Pytest fixtures for Scheduler config tests. + +Provides fixtures for creating multiple TrialRunner instances using the mock environment +config. +""" + +from importlib.resources import files + +import pytest + +from mlos_bench.schedulers.trial_runner import TrialRunner +from mlos_bench.services.config_persistence import ConfigPersistenceService +from mlos_bench.util import path_join + +# pylint: disable=redefined-outer-name + +TRIAL_RUNNERS_COUNT = 4 + + +@pytest.fixture +def mock_env_config_path() -> str: + """ + Returns the absolute path to the mock environment configuration file. + + This file is used to create TrialRunner instances for testing. + """ + # Use the files() routine to locate the file relative to this directory + return path_join( + str(files("mlos_bench.config").joinpath("environments", "mock", "mock_env.jsonc")), + abs_path=True, + ) + + +@pytest.fixture +def trial_runners( + config_loader_service: ConfigPersistenceService, + mock_env_config_path: str, +) -> list[TrialRunner]: + """ + Fixture that returns a list of TrialRunner instances using the mock environment + config. + + Returns + ------- + list[TrialRunner] + List of TrialRunner instances created from the mock environment config. + """ + return TrialRunner.create_from_json( + config_loader=config_loader_service, + env_json=mock_env_config_path, + num_trial_runners=TRIAL_RUNNERS_COUNT, + ) diff --git a/mlos_bench/mlos_bench/tests/config/schedulers/test_load_scheduler_config_examples.py b/mlos_bench/mlos_bench/tests/config/schedulers/test_load_scheduler_config_examples.py new file mode 100644 index 00000000000..ad8f9248acd --- /dev/null +++ b/mlos_bench/mlos_bench/tests/config/schedulers/test_load_scheduler_config_examples.py @@ -0,0 +1,85 @@ +# +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# +"""Tests for loading scheduler config examples.""" +import logging + +import pytest + +import mlos_bench.tests.optimizers.fixtures +import mlos_bench.tests.storage.sql.fixtures +from mlos_bench.config.schemas.config_schemas import ConfigSchema +from mlos_bench.optimizers.mock_optimizer import MockOptimizer +from mlos_bench.schedulers.base_scheduler import Scheduler +from mlos_bench.schedulers.trial_runner import TrialRunner +from mlos_bench.services.config_persistence import ConfigPersistenceService +from mlos_bench.storage.sql.storage import SqlStorage +from mlos_bench.tests.config import BUILTIN_TEST_CONFIG_PATH, locate_config_examples +from mlos_bench.util import get_class_from_name + +mock_opt = mlos_bench.tests.optimizers.fixtures.mock_opt +storage = mlos_bench.tests.storage.sql.fixtures.storage + + +_LOG = logging.getLogger(__name__) +_LOG.setLevel(logging.DEBUG) + +# pylint: disable=redefined-outer-name + +# Get the set of configs to test. 
+CONFIG_TYPE = "schedulers" + + +def filter_configs(configs_to_filter: list[str]) -> list[str]: + """If necessary, filter out json files that aren't for the module we're testing.""" + return configs_to_filter + + +configs = locate_config_examples( + ConfigPersistenceService.BUILTIN_CONFIG_PATH, + CONFIG_TYPE, + filter_configs, +) +assert configs + +test_configs = locate_config_examples( + BUILTIN_TEST_CONFIG_PATH, + CONFIG_TYPE, + filter_configs, +) +# assert test_configs +configs.extend(test_configs) + + +@pytest.mark.parametrize("config_path", configs) +def test_load_scheduler_config_examples( + config_loader_service: ConfigPersistenceService, + config_path: str, + mock_env_config_path: str, + trial_runners: list[TrialRunner], + storage: SqlStorage, + mock_opt: MockOptimizer, +) -> None: + """Tests loading a config example.""" + # pylint: disable=too-many-arguments,too-many-positional-arguments + config = config_loader_service.load_config(config_path, ConfigSchema.SCHEDULER) + assert isinstance(config, dict) + cls = get_class_from_name(config["class"]) + assert issubclass(cls, Scheduler) + global_config = { + # Required configs generally provided by the Launcher. + "experiment_id": f"test_experiment_{__name__}", + "trial_id": 1, + } + # Make an instance of the class based on the config. + scheduler_inst = config_loader_service.build_scheduler( + config=config, + global_config=global_config, + trial_runners=trial_runners, + optimizer=mock_opt, + storage=storage, + root_env_config=mock_env_config_path, + ) + assert scheduler_inst is not None + assert isinstance(scheduler_inst, cls) diff --git a/mlos_bench/mlos_bench/tests/config/services/test_load_service_config_examples.py b/mlos_bench/mlos_bench/tests/config/services/test_load_service_config_examples.py index beb0b1d018e..96df98b29d2 100644 --- a/mlos_bench/mlos_bench/tests/config/services/test_load_service_config_examples.py +++ b/mlos_bench/mlos_bench/tests/config/services/test_load_service_config_examples.py @@ -10,7 +10,7 @@ from mlos_bench.config.schemas.config_schemas import ConfigSchema from mlos_bench.services.base_service import Service from mlos_bench.services.config_persistence import ConfigPersistenceService -from mlos_bench.tests.config import locate_config_examples +from mlos_bench.tests.config import BUILTIN_TEST_CONFIG_PATH, locate_config_examples _LOG = logging.getLogger(__name__) _LOG.setLevel(logging.DEBUG) @@ -40,6 +40,14 @@ def predicate(config_path: str) -> bool: ) assert configs +test_configs = locate_config_examples( + BUILTIN_TEST_CONFIG_PATH, + CONFIG_TYPE, + filter_configs, +) +assert test_configs +configs.extend(test_configs) + @pytest.mark.parametrize("config_path", configs) def test_load_service_config_examples( diff --git a/mlos_bench/mlos_bench/tests/config/storage/test_load_storage_config_examples.py b/mlos_bench/mlos_bench/tests/config/storage/test_load_storage_config_examples.py index e3696a85fad..680b3bacf1f 100644 --- a/mlos_bench/mlos_bench/tests/config/storage/test_load_storage_config_examples.py +++ b/mlos_bench/mlos_bench/tests/config/storage/test_load_storage_config_examples.py @@ -10,7 +10,7 @@ from mlos_bench.config.schemas.config_schemas import ConfigSchema from mlos_bench.services.config_persistence import ConfigPersistenceService from mlos_bench.storage.base_storage import Storage -from mlos_bench.tests.config import locate_config_examples +from mlos_bench.tests.config import BUILTIN_TEST_CONFIG_PATH, locate_config_examples from mlos_bench.util import get_class_from_name _LOG = 
logging.getLogger(__name__) @@ -33,6 +33,14 @@ def filter_configs(configs_to_filter: list[str]) -> list[str]: ) assert configs +test_configs = locate_config_examples( + BUILTIN_TEST_CONFIG_PATH, + CONFIG_TYPE, + filter_configs, +) +# assert test_configs +configs.extend(test_configs) + @pytest.mark.parametrize("config_path", configs) def test_load_storage_config_examples( diff --git a/mlos_bench/mlos_bench/tests/optimizers/conftest.py b/mlos_bench/mlos_bench/tests/optimizers/conftest.py index aaa6b14753a..f1c758bca6d 100644 --- a/mlos_bench/mlos_bench/tests/optimizers/conftest.py +++ b/mlos_bench/mlos_bench/tests/optimizers/conftest.py @@ -2,170 +2,16 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -"""Test fixtures for mlos_bench optimizers.""" - - -import pytest - -from mlos_bench.optimizers.manual_optimizer import ManualOptimizer -from mlos_bench.optimizers.mlos_core_optimizer import MlosCoreOptimizer -from mlos_bench.optimizers.mock_optimizer import MockOptimizer -from mlos_bench.tests import SEED -from mlos_bench.tunables.tunable_groups import TunableGroups - -# pylint: disable=redefined-outer-name - - -@pytest.fixture -def mock_configs() -> list[dict]: - """Mock configurations of earlier experiments.""" - return [ - { - "vmSize": "Standard_B4ms", - "idle": "halt", - "kernel_sched_migration_cost_ns": 50000, - "kernel_sched_latency_ns": 1000000, - }, - { - "vmSize": "Standard_B4ms", - "idle": "halt", - "kernel_sched_migration_cost_ns": 40000, - "kernel_sched_latency_ns": 2000000, - }, - { - "vmSize": "Standard_B4ms", - "idle": "mwait", - "kernel_sched_migration_cost_ns": -1, # Special value - "kernel_sched_latency_ns": 3000000, - }, - { - "vmSize": "Standard_B2s", - "idle": "mwait", - "kernel_sched_migration_cost_ns": 200000, - "kernel_sched_latency_ns": 4000000, - }, - ] - - -@pytest.fixture -def mock_opt_no_defaults(tunable_groups: TunableGroups) -> MockOptimizer: - """Test fixture for MockOptimizer that ignores the initial configuration.""" - return MockOptimizer( - tunables=tunable_groups, - service=None, - config={ - "optimization_targets": {"score": "min"}, - "max_suggestions": 5, - "start_with_defaults": False, - "seed": SEED, - }, - ) - - -@pytest.fixture -def mock_opt(tunable_groups: TunableGroups) -> MockOptimizer: - """Test fixture for MockOptimizer.""" - return MockOptimizer( - tunables=tunable_groups, - service=None, - config={"optimization_targets": {"score": "min"}, "max_suggestions": 5, "seed": SEED}, - ) - - -@pytest.fixture -def mock_opt_max(tunable_groups: TunableGroups) -> MockOptimizer: - """Test fixture for MockOptimizer.""" - return MockOptimizer( - tunables=tunable_groups, - service=None, - config={"optimization_targets": {"score": "max"}, "max_suggestions": 10, "seed": SEED}, - ) - - -@pytest.fixture -def flaml_opt(tunable_groups: TunableGroups) -> MlosCoreOptimizer: - """Test fixture for mlos_core FLAML optimizer.""" - return MlosCoreOptimizer( - tunables=tunable_groups, - service=None, - config={ - "optimization_targets": {"score": "min"}, - "max_suggestions": 15, - "optimizer_type": "FLAML", - "seed": SEED, - }, - ) - - -@pytest.fixture -def flaml_opt_max(tunable_groups: TunableGroups) -> MlosCoreOptimizer: - """Test fixture for mlos_core FLAML optimizer.""" - return MlosCoreOptimizer( - tunables=tunable_groups, - service=None, - config={ - "optimization_targets": {"score": "max"}, - "max_suggestions": 15, - "optimizer_type": "FLAML", - "seed": SEED, - }, - ) - - -# FIXME: SMAC's RF model can be non-deterministic at low 
iterations, which are -# normally calculated as a percentage of the max_suggestions and number of -# tunable dimensions, so for now we set the initial random samples equal to the -# number of iterations and control them with a seed. - -SMAC_ITERATIONS = 10 - - -@pytest.fixture -def smac_opt(tunable_groups: TunableGroups) -> MlosCoreOptimizer: - """Test fixture for mlos_core SMAC optimizer.""" - return MlosCoreOptimizer( - tunables=tunable_groups, - service=None, - config={ - "optimization_targets": {"score": "min"}, - "max_suggestions": SMAC_ITERATIONS, - "optimizer_type": "SMAC", - "seed": SEED, - "output_directory": None, - # See Above - "n_random_init": SMAC_ITERATIONS, - "max_ratio": 1.0, - }, - ) - - -@pytest.fixture -def smac_opt_max(tunable_groups: TunableGroups) -> MlosCoreOptimizer: - """Test fixture for mlos_core SMAC optimizer.""" - return MlosCoreOptimizer( - tunables=tunable_groups, - service=None, - config={ - "optimization_targets": {"score": "max"}, - "max_suggestions": SMAC_ITERATIONS, - "optimizer_type": "SMAC", - "seed": SEED, - "output_directory": None, - # See Above - "n_random_init": SMAC_ITERATIONS, - "max_ratio": 1.0, - }, - ) - - -@pytest.fixture -def manual_opt(tunable_groups: TunableGroups, mock_configs: list[dict]) -> ManualOptimizer: - """Test fixture for ManualOptimizer.""" - return ManualOptimizer( - tunables=tunable_groups, - service=None, - config={ - "max_cycles": 2, - "tunable_values_cycle": mock_configs, - }, - ) +"""Export test fixtures for mlos_bench optimizers.""" + +import mlos_bench.tests.optimizers.fixtures + +mock_configs = mlos_bench.tests.optimizers.fixtures.mock_configs +mock_opt_no_defaults = mlos_bench.tests.optimizers.fixtures.mock_opt_no_defaults +mock_opt = mlos_bench.tests.optimizers.fixtures.mock_opt +mock_opt_max = mlos_bench.tests.optimizers.fixtures.mock_opt_max +flaml_opt = mlos_bench.tests.optimizers.fixtures.flaml_opt +flaml_opt_max = mlos_bench.tests.optimizers.fixtures.flaml_opt_max +smac_opt = mlos_bench.tests.optimizers.fixtures.smac_opt +smac_opt_max = mlos_bench.tests.optimizers.fixtures.smac_opt_max +manual_opt = mlos_bench.tests.optimizers.fixtures.manual_opt diff --git a/mlos_bench/mlos_bench/tests/optimizers/fixtures.py b/mlos_bench/mlos_bench/tests/optimizers/fixtures.py new file mode 100644 index 00000000000..aaa6b14753a --- /dev/null +++ b/mlos_bench/mlos_bench/tests/optimizers/fixtures.py @@ -0,0 +1,171 @@ +# +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+# +"""Test fixtures for mlos_bench optimizers.""" + + +import pytest + +from mlos_bench.optimizers.manual_optimizer import ManualOptimizer +from mlos_bench.optimizers.mlos_core_optimizer import MlosCoreOptimizer +from mlos_bench.optimizers.mock_optimizer import MockOptimizer +from mlos_bench.tests import SEED +from mlos_bench.tunables.tunable_groups import TunableGroups + +# pylint: disable=redefined-outer-name + + +@pytest.fixture +def mock_configs() -> list[dict]: + """Mock configurations of earlier experiments.""" + return [ + { + "vmSize": "Standard_B4ms", + "idle": "halt", + "kernel_sched_migration_cost_ns": 50000, + "kernel_sched_latency_ns": 1000000, + }, + { + "vmSize": "Standard_B4ms", + "idle": "halt", + "kernel_sched_migration_cost_ns": 40000, + "kernel_sched_latency_ns": 2000000, + }, + { + "vmSize": "Standard_B4ms", + "idle": "mwait", + "kernel_sched_migration_cost_ns": -1, # Special value + "kernel_sched_latency_ns": 3000000, + }, + { + "vmSize": "Standard_B2s", + "idle": "mwait", + "kernel_sched_migration_cost_ns": 200000, + "kernel_sched_latency_ns": 4000000, + }, + ] + + +@pytest.fixture +def mock_opt_no_defaults(tunable_groups: TunableGroups) -> MockOptimizer: + """Test fixture for MockOptimizer that ignores the initial configuration.""" + return MockOptimizer( + tunables=tunable_groups, + service=None, + config={ + "optimization_targets": {"score": "min"}, + "max_suggestions": 5, + "start_with_defaults": False, + "seed": SEED, + }, + ) + + +@pytest.fixture +def mock_opt(tunable_groups: TunableGroups) -> MockOptimizer: + """Test fixture for MockOptimizer.""" + return MockOptimizer( + tunables=tunable_groups, + service=None, + config={"optimization_targets": {"score": "min"}, "max_suggestions": 5, "seed": SEED}, + ) + + +@pytest.fixture +def mock_opt_max(tunable_groups: TunableGroups) -> MockOptimizer: + """Test fixture for MockOptimizer.""" + return MockOptimizer( + tunables=tunable_groups, + service=None, + config={"optimization_targets": {"score": "max"}, "max_suggestions": 10, "seed": SEED}, + ) + + +@pytest.fixture +def flaml_opt(tunable_groups: TunableGroups) -> MlosCoreOptimizer: + """Test fixture for mlos_core FLAML optimizer.""" + return MlosCoreOptimizer( + tunables=tunable_groups, + service=None, + config={ + "optimization_targets": {"score": "min"}, + "max_suggestions": 15, + "optimizer_type": "FLAML", + "seed": SEED, + }, + ) + + +@pytest.fixture +def flaml_opt_max(tunable_groups: TunableGroups) -> MlosCoreOptimizer: + """Test fixture for mlos_core FLAML optimizer.""" + return MlosCoreOptimizer( + tunables=tunable_groups, + service=None, + config={ + "optimization_targets": {"score": "max"}, + "max_suggestions": 15, + "optimizer_type": "FLAML", + "seed": SEED, + }, + ) + + +# FIXME: SMAC's RF model can be non-deterministic at low iterations, which are +# normally calculated as a percentage of the max_suggestions and number of +# tunable dimensions, so for now we set the initial random samples equal to the +# number of iterations and control them with a seed. 
+ +SMAC_ITERATIONS = 10 + + +@pytest.fixture +def smac_opt(tunable_groups: TunableGroups) -> MlosCoreOptimizer: + """Test fixture for mlos_core SMAC optimizer.""" + return MlosCoreOptimizer( + tunables=tunable_groups, + service=None, + config={ + "optimization_targets": {"score": "min"}, + "max_suggestions": SMAC_ITERATIONS, + "optimizer_type": "SMAC", + "seed": SEED, + "output_directory": None, + # See Above + "n_random_init": SMAC_ITERATIONS, + "max_ratio": 1.0, + }, + ) + + +@pytest.fixture +def smac_opt_max(tunable_groups: TunableGroups) -> MlosCoreOptimizer: + """Test fixture for mlos_core SMAC optimizer.""" + return MlosCoreOptimizer( + tunables=tunable_groups, + service=None, + config={ + "optimization_targets": {"score": "max"}, + "max_suggestions": SMAC_ITERATIONS, + "optimizer_type": "SMAC", + "seed": SEED, + "output_directory": None, + # See Above + "n_random_init": SMAC_ITERATIONS, + "max_ratio": 1.0, + }, + ) + + +@pytest.fixture +def manual_opt(tunable_groups: TunableGroups, mock_configs: list[dict]) -> ManualOptimizer: + """Test fixture for ManualOptimizer.""" + return ManualOptimizer( + tunables=tunable_groups, + service=None, + config={ + "max_cycles": 2, + "tunable_values_cycle": mock_configs, + }, + ) diff --git a/mlos_bench/mlos_bench/tests/storage/conftest.py b/mlos_bench/mlos_bench/tests/storage/conftest.py index a1437052823..c510793fac1 100644 --- a/mlos_bench/mlos_bench/tests/storage/conftest.py +++ b/mlos_bench/mlos_bench/tests/storage/conftest.py @@ -12,6 +12,7 @@ # Expose some of those as local names so they can be picked up as fixtures by pytest. storage = sql_storage_fixtures.storage +sqlite_storage = sql_storage_fixtures.sqlite_storage exp_storage = sql_storage_fixtures.exp_storage exp_no_tunables_storage = sql_storage_fixtures.exp_no_tunables_storage mixed_numerics_exp_storage = sql_storage_fixtures.mixed_numerics_exp_storage diff --git a/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py b/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py index cb83bffd4ff..0bebeeff824 100644 --- a/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py +++ b/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py @@ -4,6 +4,9 @@ # """Test fixtures for mlos_bench storage.""" +import json +import os +import tempfile from collections.abc import Generator from random import seed as rand_seed @@ -15,6 +18,7 @@ from mlos_bench.services.config_persistence import ConfigPersistenceService from mlos_bench.storage.base_experiment_data import ExperimentData from mlos_bench.storage.sql.storage import SqlStorage +from mlos_bench.storage.storage_factory import from_config from mlos_bench.tests import SEED from mlos_bench.tests.storage import ( CONFIG_TRIAL_REPEAT_COUNT, @@ -26,6 +30,38 @@ # pylint: disable=redefined-outer-name +@pytest.fixture +def sqlite_storage() -> Generator[SqlStorage]: + """ + Fixture for file based SQLite storage in a temporary directory. + + Yields + ------ + Generator[SqlStorage] + + Notes + ----- + Can't be used in parallel tests on Windows. 
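+    (Windows doesn't support multiple processes accessing the same file.)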
+ """ + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "mlos_bench.sqlite") + config_str = json.dumps( + { + "class": "mlos_bench.storage.sql.storage.SqlStorage", + "config": { + "drivername": "sqlite", + "database": db_path, + "lazy_schema_create": False, + }, + } + ) + + storage = from_config(config_str) + assert isinstance(storage, SqlStorage) + storage.update_schema() + yield storage + + @pytest.fixture def storage() -> SqlStorage: """Test fixture for in-memory SQLite3 storage.""" diff --git a/mlos_bench/mlos_bench/tests/storage/test_storage_pickling.py b/mlos_bench/mlos_bench/tests/storage/test_storage_pickling.py index 3d5053837be..7871e7f68ca 100644 --- a/mlos_bench/mlos_bench/tests/storage/test_storage_pickling.py +++ b/mlos_bench/mlos_bench/tests/storage/test_storage_pickling.py @@ -3,11 +3,8 @@ # Licensed under the MIT License. # """Test pickling and unpickling of Storage, and restoring Experiment and Trial by id.""" -import json -import os import pickle import sys -import tempfile from datetime import datetime from typing import Literal @@ -16,7 +13,6 @@ from mlos_bench.environments.status import Status from mlos_bench.storage.sql.storage import SqlStorage -from mlos_bench.storage.storage_factory import from_config from mlos_bench.tunables.tunable_groups import TunableGroups @@ -26,72 +22,59 @@ sys.platform == "win32", reason="Windows doesn't support multiple processes accessing the same file.", ) -def test_storage_pickle_restore_experiment_and_trial(tunable_groups: TunableGroups) -> None: +def test_storage_pickle_restore_experiment_and_trial( + sqlite_storage: SqlStorage, + tunable_groups: TunableGroups, +) -> None: """Check that we can pickle and unpickle the Storage object, and restore Experiment and Trial by id. 
""" - # pylint: disable=too-many-locals - with tempfile.TemporaryDirectory() as tmpdir: - db_path = os.path.join(tmpdir, "mlos_bench.sqlite") - config_str = json.dumps( - { - "class": "mlos_bench.storage.sql.storage.SqlStorage", - "config": { - "drivername": "sqlite", - "database": db_path, - "lazy_schema_create": False, - }, - } - ) + storage = sqlite_storage + # Create an Experiment and a Trial + opt_targets: dict[str, Literal["min", "max"]] = {"metric": "min"} + experiment = storage.experiment( + experiment_id="experiment_id", + trial_id=0, + root_env_config="dummy_env.json", + description="Pickle test experiment", + tunables=tunable_groups, + opt_targets=opt_targets, + ) + with experiment: + trial = experiment.new_trial(tunable_groups) + trial_id_created = trial.trial_id + trial.set_trial_runner(1) + trial.update(Status.RUNNING, datetime.now(UTC)) - storage = from_config(config_str) - storage.update_schema() + # Pickle and unpickle the Storage object + pickled = pickle.dumps(storage) + restored_storage = pickle.loads(pickled) + assert isinstance(restored_storage, SqlStorage) - # Create an Experiment and a Trial - opt_targets: dict[str, Literal["min", "max"]] = {"metric": "min"} - experiment = storage.experiment( - experiment_id="experiment_id", - trial_id=0, - root_env_config="dummy_env.json", - description="Pickle test experiment", - tunables=tunable_groups, - opt_targets=opt_targets, - ) - with experiment: - trial = experiment.new_trial(tunable_groups) - trial_id_created = trial.trial_id - trial.set_trial_runner(1) - trial.update(Status.RUNNING, datetime.now(UTC)) + # Restore the Experiment from storage by id and check that it matches the original + restored_experiment = restored_storage.get_experiment_by_id( + experiment_id=experiment.experiment_id, + tunables=tunable_groups, + opt_targets=opt_targets, + ) + assert restored_experiment is not None + assert restored_experiment is not experiment + assert restored_experiment.experiment_id == experiment.experiment_id + assert restored_experiment.description == experiment.description + assert restored_experiment.root_env_config == experiment.root_env_config + assert restored_experiment.tunables == experiment.tunables + assert restored_experiment.opt_targets == experiment.opt_targets + with restored_experiment: + # trial_id should have been restored during __enter__ + assert restored_experiment.trial_id == experiment.trial_id - # Pickle and unpickle the Storage object - pickled = pickle.dumps(storage) - restored_storage = pickle.loads(pickled) - assert isinstance(restored_storage, SqlStorage) - - # Restore the Experiment from storage by id and check that it matches the original - restored_experiment = restored_storage.get_experiment_by_id( - experiment_id=experiment.experiment_id, - tunables=tunable_groups, - opt_targets=opt_targets, - ) - assert restored_experiment is not None - assert restored_experiment is not experiment - assert restored_experiment.experiment_id == experiment.experiment_id - assert restored_experiment.description == experiment.description - assert restored_experiment.root_env_config == experiment.root_env_config - assert restored_experiment.tunables == experiment.tunables - assert restored_experiment.opt_targets == experiment.opt_targets - with restored_experiment: - # trial_id should have been restored during __enter__ - assert restored_experiment.trial_id == experiment.trial_id - - # Restore the Trial from storage by id and check that it matches the original - restored_trial = 
restored_experiment.get_trial_by_id(trial_id_created) - assert restored_trial is not None - assert restored_trial is not trial - assert restored_trial.trial_id == trial.trial_id - assert restored_trial.experiment_id == trial.experiment_id - assert restored_trial.tunables == trial.tunables - assert restored_trial.status == trial.status - assert restored_trial.config() == trial.config() - assert restored_trial.trial_runner_id == trial.trial_runner_id + # Restore the Trial from storage by id and check that it matches the original + restored_trial = restored_experiment.get_trial_by_id(trial_id_created) + assert restored_trial is not None + assert restored_trial is not trial + assert restored_trial.trial_id == trial.trial_id + assert restored_trial.experiment_id == trial.experiment_id + assert restored_trial.tunables == trial.tunables + assert restored_trial.status == trial.status + assert restored_trial.config() == trial.config() + assert restored_trial.trial_runner_id == trial.trial_runner_id
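
Note: below is a minimal, hypothetical sketch (not part of the change) of the fixture-sharing pattern this diff applies in conftest.py and fixtures.py: fixtures are defined in an importable fixtures.py module rather than directly in a conftest.py, then re-exported as module-level names wherever pytest should discover them. All module and fixture names in the sketch are invented for illustration.

# sketch_fixture_reuse.py -- hypothetical, self-contained illustration of the
# pattern: define fixtures in an importable module, re-export for discovery.
import pytest

# --- would live in mypkg/tests/widgets/fixtures.py ---
@pytest.fixture
def widget() -> dict:
    """A reusable fixture that any test package can import."""
    return {"name": "widget", "size": 42}

# --- would live in mypkg/tests/gadgets/test_gadget.py ---
# import mypkg.tests.widgets.fixtures
# widget = mypkg.tests.widgets.fixtures.widget  # re-export so pytest finds it
def test_uses_shared_fixture(widget: dict) -> None:
    assert widget["size"] == 42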