Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from mlos_bench.environments.base_environment import Environment
from mlos_bench.environments.composite_env import CompositeEnv
from mlos_bench.services.config_persistence import ConfigPersistenceService
from mlos_bench.tests.config import locate_config_examples
from mlos_bench.tests.config import BUILTIN_TEST_CONFIG_PATH, locate_config_examples
from mlos_bench.tunables.tunable_groups import TunableGroups

_LOG = logging.getLogger(__name__)
Expand Down Expand Up @@ -39,6 +39,14 @@ def filter_configs(configs_to_filter: list[str]) -> list[str]:
)
assert configs

test_configs = locate_config_examples(
BUILTIN_TEST_CONFIG_PATH,
CONFIG_TYPE,
filter_configs,
)
assert test_configs
configs.extend(test_configs)


@pytest.mark.parametrize("config_path", configs)
def test_load_environment_config_examples(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@
"resourceGroup": "mlos-autotuning-test-rg",
"location": "eastus",
"vmName": "vmTestName",
"ssh_username": "testuser",
"ssh_priv_key_path": "/home/testuser/.ssh/id_rsa",
"ssh_hostname": "${vmName}",
"ssh_port": 22,
"tunable_params_map": {
"linux-runtime": ["linux-scheduler", "linux-swap"],
"linux-boot": ["linux-kernel-boot"],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from mlos_bench.config.schemas import ConfigSchema
from mlos_bench.optimizers.base_optimizer import Optimizer
from mlos_bench.services.config_persistence import ConfigPersistenceService
from mlos_bench.tests.config import locate_config_examples
from mlos_bench.tests.config import BUILTIN_TEST_CONFIG_PATH, locate_config_examples
from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.util import get_class_from_name

Expand All @@ -34,6 +34,14 @@ def filter_configs(configs_to_filter: list[str]) -> list[str]:
)
assert configs

test_configs = locate_config_examples(
BUILTIN_TEST_CONFIG_PATH,
CONFIG_TYPE,
filter_configs,
)
# assert test_configs
configs.extend(test_configs)


@pytest.mark.parametrize("config_path", configs)
def test_load_optimizer_config_examples(
Expand Down
5 changes: 5 additions & 0 deletions mlos_bench/mlos_bench/tests/config/schedulers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""Unit tests for the mlos_bench Scheduler configs."""
57 changes: 57 additions & 0 deletions mlos_bench/mlos_bench/tests/config/schedulers/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Pytest fixtures for Scheduler config tests.

Provides fixtures for creating multiple TrialRunner instances using the mock environment
config.
"""

from importlib.resources import files

import pytest

from mlos_bench.schedulers.trial_runner import TrialRunner
from mlos_bench.services.config_persistence import ConfigPersistenceService
from mlos_bench.util import path_join

# pylint: disable=redefined-outer-name

TRIAL_RUNNERS_COUNT = 4


@pytest.fixture
def mock_env_config_path() -> str:
"""
Returns the absolute path to the mock environment configuration file.

This file is used to create TrialRunner instances for testing.
"""
# Use the files() routine to locate the file relative to this directory
return path_join(
str(files("mlos_bench.config").joinpath("environments", "mock", "mock_env.jsonc")),
abs_path=True,
)


@pytest.fixture
def trial_runners(
config_loader_service: ConfigPersistenceService,
mock_env_config_path: str,
) -> list[TrialRunner]:
"""
Fixture that returns a list of TrialRunner instances using the mock environment
config.

Returns
-------
list[TrialRunner]
List of TrialRunner instances created from the mock environment config.
"""
return TrialRunner.create_from_json(
config_loader=config_loader_service,
env_json=mock_env_config_path,
num_trial_runners=TRIAL_RUNNERS_COUNT,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""Tests for loading scheduler config examples."""
import logging

import pytest

import mlos_bench.tests.optimizers.fixtures
import mlos_bench.tests.storage.sql.fixtures
from mlos_bench.config.schemas.config_schemas import ConfigSchema
from mlos_bench.optimizers.mock_optimizer import MockOptimizer
from mlos_bench.schedulers.base_scheduler import Scheduler
from mlos_bench.schedulers.trial_runner import TrialRunner
from mlos_bench.services.config_persistence import ConfigPersistenceService
from mlos_bench.storage.sql.storage import SqlStorage
from mlos_bench.tests.config import BUILTIN_TEST_CONFIG_PATH, locate_config_examples
from mlos_bench.util import get_class_from_name

mock_opt = mlos_bench.tests.optimizers.fixtures.mock_opt
sqlite_storage = mlos_bench.tests.storage.sql.fixtures.sqlite_storage


_LOG = logging.getLogger(__name__)
_LOG.setLevel(logging.DEBUG)

# pylint: disable=redefined-outer-name

# Get the set of configs to test.
CONFIG_TYPE = "schedulers"


def filter_configs(configs_to_filter: list[str]) -> list[str]:
"""If necessary, filter out json files that aren't for the module we're testing."""
return configs_to_filter


configs = locate_config_examples(
ConfigPersistenceService.BUILTIN_CONFIG_PATH,
CONFIG_TYPE,
filter_configs,
)
assert configs

test_configs = locate_config_examples(
BUILTIN_TEST_CONFIG_PATH,
CONFIG_TYPE,
filter_configs,
)
# assert test_configs
configs.extend(test_configs)


@pytest.mark.parametrize("config_path", configs)
def test_load_scheduler_config_examples(
config_loader_service: ConfigPersistenceService,
config_path: str,
mock_env_config_path: str,
trial_runners: list[TrialRunner],
sqlite_storage: SqlStorage,
mock_opt: MockOptimizer,
) -> None:
"""Tests loading a config example."""
# pylint: disable=too-many-arguments,too-many-positional-arguments
config = config_loader_service.load_config(config_path, ConfigSchema.SCHEDULER)
assert isinstance(config, dict)
cls = get_class_from_name(config["class"])
assert issubclass(cls, Scheduler)
global_config = {
# Required configs generally provided by the Launcher.
"experiment_id": f"test_experiment_{__name__}",
"trial_id": 1,
}
# Make an instance of the class based on the config.
scheduler_inst = config_loader_service.build_scheduler(
config=config,
global_config=global_config,
trial_runners=trial_runners,
optimizer=mock_opt,
storage=sqlite_storage,
root_env_config=mock_env_config_path,
)
assert scheduler_inst is not None
assert isinstance(scheduler_inst, cls)
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from mlos_bench.config.schemas.config_schemas import ConfigSchema
from mlos_bench.services.base_service import Service
from mlos_bench.services.config_persistence import ConfigPersistenceService
from mlos_bench.tests.config import locate_config_examples
from mlos_bench.tests.config import BUILTIN_TEST_CONFIG_PATH, locate_config_examples

_LOG = logging.getLogger(__name__)
_LOG.setLevel(logging.DEBUG)
Expand Down Expand Up @@ -40,6 +40,14 @@ def predicate(config_path: str) -> bool:
)
assert configs

test_configs = locate_config_examples(
BUILTIN_TEST_CONFIG_PATH,
CONFIG_TYPE,
filter_configs,
)
assert test_configs
configs.extend(test_configs)


@pytest.mark.parametrize("config_path", configs)
def test_load_service_config_examples(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from mlos_bench.config.schemas.config_schemas import ConfigSchema
from mlos_bench.services.config_persistence import ConfigPersistenceService
from mlos_bench.storage.base_storage import Storage
from mlos_bench.tests.config import locate_config_examples
from mlos_bench.tests.config import BUILTIN_TEST_CONFIG_PATH, locate_config_examples
from mlos_bench.util import get_class_from_name

_LOG = logging.getLogger(__name__)
Expand All @@ -33,6 +33,14 @@ def filter_configs(configs_to_filter: list[str]) -> list[str]:
)
assert configs

test_configs = locate_config_examples(
BUILTIN_TEST_CONFIG_PATH,
CONFIG_TYPE,
filter_configs,
)
# assert test_configs
configs.extend(test_configs)


@pytest.mark.parametrize("config_path", configs)
def test_load_storage_config_examples(
Expand Down
Loading
Loading