diff --git a/.vscode/settings.json b/.vscode/settings.json index a01d1d6418f..f29dec766e2 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -8,6 +8,7 @@ "**/__pycache__/": true, "**/node_modules/": true, "**/*.egg-info": true, + "mlos_*/build/": true, "doc/source/autoapi/": true, "doc/build/doctrees/": true, "doc/build/html/": true, diff --git a/mlos_bench/mlos_bench/config/schemas/cli/globals-schema.json b/mlos_bench/mlos_bench/config/schemas/cli/globals-schema.json index 015b4a6e62c..39e60e3249b 100644 --- a/mlos_bench/mlos_bench/config/schemas/cli/globals-schema.json +++ b/mlos_bench/mlos_bench/config/schemas/cli/globals-schema.json @@ -24,6 +24,9 @@ }, "optimization_targets": { "$ref": "./common-defs-subschemas.json#/$defs/optimization_targets" + }, + "mock_trial_data": { + "$ref": "../environments/mock-env-subschema.json#/$defs/mock_trial_data" } }, "additionalProperties": { diff --git a/mlos_bench/mlos_bench/config/schemas/environments/mock-env-subschema.json b/mlos_bench/mlos_bench/config/schemas/environments/mock-env-subschema.json index cb2de6c719f..b453c8573aa 100644 --- a/mlos_bench/mlos_bench/config/schemas/environments/mock-env-subschema.json +++ b/mlos_bench/mlos_bench/config/schemas/environments/mock-env-subschema.json @@ -3,6 +3,136 @@ "$id": "https://raw.githubusercontent.com/microsoft/MLOS/main/mlos_bench/mlos_bench/config/schemas/environments/mock-env-subschema.json", "title": "mlos_bench MockEnv config", "description": "Config instance for a mlos_bench MockEnv", + + "$defs": { + "mock_trial_common_phase_data": { + "type": "object", + "properties": { + "sleep": { + "type": "number", + "description": "Optional time to sleep (in seconds) before returning from this phase of the trial.", + "examples": [0, 0.1, 0.5, 1, 2], + "minimum": 0, + "maximum": 60 + }, + "exception": { + "type": "string", + "description": "Optional exception message to raise during phase." + } + } + }, + "mock_trial_status_run_phase_data": { + "type": "object", + "properties": { + "status": { + "description": "The status to report for this phase of the trial. Default is phase dependent.", + "enum": [ + "UNKNOWN", + "PENDING", + "READY", + "RUNNING", + "SUCCEEDED", + "CANCELED", + "FAILED", + "TIMED_OUT" + ] + }, + "metrics": { + "type": "object", + "description": "A dictionary of metrics for this phase of the trial.", + "additionalProperties": { + "type": [ + "number", + "string", + "boolean" + ], + "description": "The value of the metric." 
+ }, + "examples": [ + { + "score": 0.95, + "color": "green" + }, + { + "accuracy": 0.85, + "loss": 0.15 + } + ], + "minProperties": 0 + } + } + }, + "mock_trial_data_item": { + "description": "Mock data for a single trial, split by phase", + "type": "object", + "properties": { + "run": { + "description": "A dictionary of trial data for the run phase.", + "type": "object", + "allOf": [ + { + "$ref": "#/$defs/mock_trial_common_phase_data" + }, + { + "$ref": "#/$defs/mock_trial_status_run_phase_data" + } + ], + "minProperties": 1, + "unevaluatedProperties": false + }, + "status": { + "description": "A dictionary of trial data for the status phase.", + "type": "object", + "allOf": [ + { + "$ref": "#/$defs/mock_trial_common_phase_data" + }, + { + "$ref": "#/$defs/mock_trial_status_run_phase_data" + } + ], + "minProperties": 1, + "unevaluatedProperties": false + }, + "setup": { + "description": "A dictionary of trial data for the setup phase.", + "type": "object", + "allOf": [ + { + "$ref": "#/$defs/mock_trial_common_phase_data" + } + ], + "minProperties": 1, + "unevaluatedProperties": false + }, + "teardown": { + "description": "A dictionary of trial data for the teardown phase.", + "type": "object", + "allOf": [ + { + "$ref": "#/$defs/mock_trial_common_phase_data" + } + ], + "minProperties": 1, + "unevaluatedProperties": false + } + }, + "unevaluatedProperties": false, + "minProperties": 1 + }, + "mock_trial_data": { + "description": "A set of mock trial data to use for testing, keyed by trial id. Used by MockEnv.", + "type": "object", + "patternProperties": { + "^[1-9][0-9]*$": { + "$ref": "#/$defs/mock_trial_data_item" + } + }, + "unevaluatedProperties": false, + "minProperties": 1 + } + }, + "type": "object", "properties": { "class": { @@ -42,6 +172,9 @@ }, "minItems": 1, "uniqueItems": true + }, + "mock_trial_data": { + "$ref": "#/$defs/mock_trial_data" } } } diff --git a/mlos_bench/mlos_bench/environments/base_environment.py b/mlos_bench/mlos_bench/environments/base_environment.py index 094085c78b5..fe40025f95d 100644 --- a/mlos_bench/mlos_bench/environments/base_environment.py +++ b/mlos_bench/mlos_bench/environments/base_environment.py @@ -363,6 +363,92 @@ def parameters(self) -> dict[str, TunableValue]: """ return self._params.copy() + @property + def current_trial_id(self) -> int: + """ + Get the current trial ID. + + This value can be used in scripts or environment variables to help + identify the Trial this Environment is currently running. + + Returns + ------- + trial_id : int + The current trial ID. + + Notes + ----- + This method is used to identify the current trial ID for the environment. + It is expected to be called *after* the base + :py:meth:`Environment.setup` method has been called and parameters have + been assigned. + """ + val = self._params["trial_id"] + assert isinstance(val, int), ( + "Expected trial_id to be an int, but got %s (type %s): %s", + val, + type(val), + self._params, + ) + return val + + @property + def trial_runner_id(self) -> int: + """ + Get the ID of the :py:class:`~.mlos_bench.schedulers.trial_runner.TrialRunner` + for this Environment. + + This value can be used in scripts or environment variables to help + identify the TrialRunner for this Environment. + + Returns + ------- + trial_runner_id : int + The trial runner ID. + + Notes + ----- + This shouldn't change during the lifetime of the Environment since each + Environment is assigned to a single TrialRunner. 
+ """ + val = self._params["trial_runner_id"] + assert isinstance(val, int), ( + "Expected trial_runner_id to be an int, but got %s (type %s)", + val, + type(val), + ) + return val + + @property + def experiment_id(self) -> str: + """ + Get the ID of the experiment. + + This value can be used in scripts or environment variables to help + identify the TrialRunner for this Environment. + + Returns + ------- + experiment_id : str + The ID of the experiment. + + Notes + ----- + This value comes from the globals config or ``mlos_bench`` CLI arguments + in the experiment setup. + + See Also + -------- + mlos_bench.config : documentation on the configuration system + """ + val = self._params["experiment_id"] + assert isinstance(val, str), ( + "Expected experiment_id to be an int, but got %s (type %s)", + val, + type(val), + ) + return val + def setup(self, tunables: TunableGroups, global_config: dict | None = None) -> bool: """ Set up a new benchmark environment, if necessary. This method must be diff --git a/mlos_bench/mlos_bench/environments/mock_env.py b/mlos_bench/mlos_bench/environments/mock_env.py index ac6d9b7f001..c565b1adc66 100644 --- a/mlos_bench/mlos_bench/environments/mock_env.py +++ b/mlos_bench/mlos_bench/environments/mock_env.py @@ -6,6 +6,9 @@ import logging import random +import time +from copy import deepcopy +from dataclasses import dataclass from datetime import datetime from typing import Any @@ -21,6 +24,131 @@ _LOG = logging.getLogger(__name__) +@dataclass +class MockTrialPhaseData: + """Mock trial data for a specific phase of a trial.""" + + phase: str + """Phase of the trial data (e.g., setup, run, status, teardown).""" + + status: Status + """Status response for the phase.""" + + metrics: dict[str, TunableValue] | None = None + """Metrics response for the phase.""" + + sleep: float | None = 0.0 + """Optional sleep time in seconds to simulate phase execution time.""" + + exception: str | None = None + """Message of an exception to raise for the phase.""" + + @staticmethod + def from_dict(phase: str, data: dict | None) -> "MockTrialPhaseData": + """ + Create a MockTrialPhaseData instance from a dictionary. + + Parameters + ---------- + phase : str + Phase of the trial data. + data : dict | None + Dictionary containing the phase data. + + Returns + ------- + MockTrialPhaseData + Instance of MockTrialPhaseData. + """ + data = data or {} + assert phase in {"setup", "run", "status", "teardown"}, f"Invalid phase: {phase}" + if phase in {"teardown", "status"}: + # setup/teardown phase is not expected to have metrics or status. 
+ assert "metrics" not in data, f"Unexpected metrics data in {phase} phase: {data}" + assert "status" not in data, f"Unexpected status data in {phase} phase: {data}" + if "sleep" in data: + assert isinstance( + data["sleep"], (int, float) + ), f"Invalid sleep in {phase} phase: {data}" + assert 60 >= data["sleep"] >= 0, f"Invalid sleep time in {phase} phase: {data}" + if "metrics" in data: + assert isinstance(data["metrics"], dict), f"Invalid metrics in {phase} phase: {data}" + default_phases = { + "run": Status.SUCCEEDED, + # FIXME: this causes issues if we report RUNNING instead of READY + "status": Status.READY, + } + status = Status.parse(data.get("status", default_phases.get(phase, Status.UNKNOWN))) + return MockTrialPhaseData( + phase=phase, + status=status, + metrics=data.get("metrics"), + sleep=data.get("sleep"), + exception=data.get("exception"), + ) + + +@dataclass +class MockTrialData: + """Mock trial data for a specific trial ID.""" + + trial_id: int + """Trial ID for the mock trial data.""" + setup: MockTrialPhaseData + """Setup phase data for the trial.""" + run: MockTrialPhaseData + """Run phase data for the trial.""" + status: MockTrialPhaseData + """Status phase data for the trial.""" + teardown: MockTrialPhaseData + """Teardown phase data for the trial.""" + + @staticmethod + def from_dict(trial_id: int, data: dict) -> "MockTrialData": + """ + Create a MockTrialData instance from a dictionary. + + Parameters + ---------- + trial_id : int + Trial ID for the mock trial data. + data : dict + Dictionary containing the trial data. + + Returns + ------- + MockTrialData + Instance of MockTrialData. + """ + return MockTrialData( + trial_id=trial_id, + setup=MockTrialPhaseData.from_dict("setup", data.get("setup")), + run=MockTrialPhaseData.from_dict("run", data.get("run")), + status=MockTrialPhaseData.from_dict("status", data.get("status")), + teardown=MockTrialPhaseData.from_dict("teardown", data.get("teardown")), + ) + + @staticmethod + def load_mock_trial_data(mock_trial_data: dict) -> dict[int, "MockTrialData"]: + """ + Load mock trial data from a dictionary. + + Parameters + ---------- + mock_trial_data : dict + Dictionary containing mock trial data. + + Returns + ------- + dict[int, MockTrialData] + Dictionary of mock trial data keyed by trial ID. + """ + return { + int(trial_id): MockTrialData.from_dict(trial_id=int(trial_id), data=trial_data) + for trial_id, trial_data in mock_trial_data.items() + } + + class MockEnv(Environment): """Scheduler-side environment to mock the benchmark results.""" @@ -55,6 +183,19 @@ def __init__( # pylint: disable=too-many-arguments service: Service An optional service object. Not used by this class. """ + # First allow merging mock_trial_data from the global_config into the + # config so we can check it against the JSON schema for expected data + # types. + if global_config and "mock_trial_data" in global_config: + mock_trial_data = global_config["mock_trial_data"] + if not isinstance(mock_trial_data, dict): + raise ValueError(f"Invalid mock_trial_data in global_config: {mock_trial_data}") + # Merge the mock trial data into the config. 
+ config["mock_trial_data"] = { + **config.get("mock_trial_data", {}), + **mock_trial_data, + } + super().__init__( name=name, config=config, @@ -62,6 +203,9 @@ def __init__( # pylint: disable=too-many-arguments tunables=tunables, service=service, ) + self._mock_trial_data = MockTrialData.load_mock_trial_data( + self.config.get("mock_trial_data", {}) + ) seed = int(self.config.get("mock_env_seed", -1)) self._run_random = random.Random(seed or None) if seed >= 0 else None self._status_random = random.Random(seed or None) if seed >= 0 else None @@ -83,6 +227,67 @@ def _produce_metrics(self, rand: random.Random | None) -> dict[str, TunableValue return {metric: float(score) for metric in self._metrics or []} + @property + def mock_trial_data(self) -> dict[int, MockTrialData]: + """ + Get the mock trial data for all trials. + + Returns + ------- + dict[int, MockTrialData] + Dictionary of mock trial data keyed by trial ID. + """ + return deepcopy(self._mock_trial_data) + + def get_current_mock_trial_data(self) -> MockTrialData: + """ + Gets mock trial data for the current trial ID. + + If no (or missing) mock trial data is found, a new instance of + MockTrialData is created and later filled with random data. + + Note + ---- + This method must be called after the base :py:meth:`Environment.setup` + method is called to ensure the current ``trial_id`` is set. + """ + trial_id = self.current_trial_id + mock_trial_data = self._mock_trial_data.get(trial_id) + if not mock_trial_data: + mock_trial_data = MockTrialData( + trial_id=trial_id, + setup=MockTrialPhaseData.from_dict(phase="setup", data=None), + run=MockTrialPhaseData.from_dict(phase="run", data=None), + status=MockTrialPhaseData.from_dict(phase="status", data=None), + teardown=MockTrialPhaseData.from_dict(phase="teardown", data=None), + ) + # Save the generated data for later. + self._mock_trial_data[trial_id] = mock_trial_data + return mock_trial_data + + def setup(self, tunables: TunableGroups, global_config: dict | None = None) -> bool: + is_success = super().setup(tunables, global_config) + mock_trial_data = self.get_current_mock_trial_data() + if mock_trial_data.setup.sleep: + _LOG.debug("Sleeping for %s seconds", mock_trial_data.setup.sleep) + time.sleep(mock_trial_data.setup.sleep) + if mock_trial_data.setup.exception: + raise RuntimeError( + f"Mock trial data setup exception: {mock_trial_data.setup.exception}" + ) + return is_success + + def teardown(self) -> None: + mock_trial_data = self.get_current_mock_trial_data() + if mock_trial_data.teardown.sleep: + _LOG.debug("Sleeping for %s seconds", mock_trial_data.teardown.sleep) + time.sleep(mock_trial_data.teardown.sleep) + if mock_trial_data.teardown.exception: + raise RuntimeError( + f"Mock trial data teardown exception: {mock_trial_data.teardown.exception}" + ) + super().teardown() + def run(self) -> tuple[Status, datetime, dict[str, TunableValue] | None]: """ Produce mock benchmark data for one experiment. 
@@ -99,8 +304,16 @@ def run(self) -> tuple[Status, datetime, dict[str, TunableValue] | None]: (status, timestamp, _) = result = super().run() if not status.is_ready(): return result - metrics = self._produce_metrics(self._run_random) - return (Status.SUCCEEDED, timestamp, metrics) + mock_trial_data = self.get_current_mock_trial_data() + if mock_trial_data.run.sleep: + _LOG.debug("Sleeping for %s seconds", mock_trial_data.run.sleep) + time.sleep(mock_trial_data.run.sleep) + if mock_trial_data.run.exception: + raise RuntimeError(f"Mock trial data run exception: {mock_trial_data.run.exception}") + if mock_trial_data.run.metrics is None: + # If no metrics are provided, generate them. + mock_trial_data.run.metrics = self._produce_metrics(self._run_random) + return (mock_trial_data.run.status, timestamp, mock_trial_data.run.metrics) def status(self) -> tuple[Status, datetime, list[tuple[datetime, str, Any]]]: """ @@ -116,10 +329,26 @@ def status(self) -> tuple[Status, datetime, list[tuple[datetime, str, Any]]]: (status, timestamp, _) = result = super().status() if not status.is_ready(): return result - metrics = self._produce_metrics(self._status_random) + mock_trial_data = self.get_current_mock_trial_data() + if mock_trial_data.status.sleep: + _LOG.debug("Sleeping for %s seconds", mock_trial_data.status.sleep) + time.sleep(mock_trial_data.status.sleep) + if mock_trial_data.status.exception: + raise RuntimeError( + f"Mock trial data status exception: {mock_trial_data.status.exception}" + ) + if mock_trial_data.status.metrics is None: + # If no metrics are provided, generate them. + # Note: we don't save these in the mock trial data as they may need + # to change to preserve backwards compatibility with previous tests. + metrics = self._produce_metrics(self._status_random) + else: + # If metrics are provided, use them. + # Note: current implementation uses the same metrics for all status + # calls of this trial. + metrics = mock_trial_data.status.metrics return ( - # FIXME: this causes issues if we report RUNNING instead of READY - Status.READY, + mock_trial_data.status.status, timestamp, [(timestamp, metric, score) for (metric, score) in metrics.items()], ) diff --git a/mlos_bench/mlos_bench/environments/script_env.py b/mlos_bench/mlos_bench/environments/script_env.py index 6ac4674cfe1..d71eb661834 100644 --- a/mlos_bench/mlos_bench/environments/script_env.py +++ b/mlos_bench/mlos_bench/environments/script_env.py @@ -5,7 +5,7 @@ """ Base scriptable benchmark environment. -TODO: Document how variable propogation works in the script environments using +TODO: Document how variable propagation works in the script environments using shell_env_params, required_args, const_args, etc. """ diff --git a/mlos_bench/mlos_bench/environments/status.py b/mlos_bench/mlos_bench/environments/status.py index 6d76d7206c8..aa3b3e99c16 100644 --- a/mlos_bench/mlos_bench/environments/status.py +++ b/mlos_bench/mlos_bench/environments/status.py @@ -24,21 +24,37 @@ class Status(enum.Enum): TIMED_OUT = 7 @staticmethod - def from_str(status_str: Any) -> "Status": - """Convert a string to a Status enum.""" - if not isinstance(status_str, str): - _LOG.warning("Expected type %s for status: %s", type(status_str), status_str) - status_str = str(status_str) - if status_str.isdigit(): + def parse(status: Any) -> "Status": + """ + Convert the input to a Status enum. + + Parameters + ---------- + status : Any + The status to parse. This can be a string (or string convertible), + int, or Status enum. 
+ + Returns + ------- + Status + The corresponding Status enum value or else UNKNOWN if the input is not + recognized. + """ + if isinstance(status, Status): + return status + if not isinstance(status, str): + _LOG.warning("Expected type %s for status: %s", type(status), status) + status = str(status) + if status.isdigit(): try: - return Status(int(status_str)) + return Status(int(status)) except ValueError: - _LOG.warning("Unknown status: %d", int(status_str)) + _LOG.warning("Unknown status: %d", int(status)) try: - status_str = status_str.upper().strip() - return Status[status_str] + status = status.upper().strip() + return Status[status] except KeyError: - _LOG.warning("Unknown status: %s", status_str) + _LOG.warning("Unknown status: %s", status) return Status.UNKNOWN def is_good(self) -> bool: @@ -113,4 +129,15 @@ def is_timed_out(self) -> bool: Status.TIMED_OUT, } ) -"""The set of completed statuses.""" +""" +The set of completed statuses. + +Includes all statuses that indicate the trial or experiment has finished, either +successfully or not. +This set is used to determine if a trial or experiment has reached a final state. +This includes: +- :py:attr:`.Status.SUCCEEDED`: The trial or experiment completed successfully. +- :py:attr:`.Status.CANCELED`: The trial or experiment was canceled. +- :py:attr:`.Status.FAILED`: The trial or experiment failed. +- :py:attr:`.Status.TIMED_OUT`: The trial or experiment timed out. +""" diff --git a/mlos_bench/mlos_bench/launcher.py b/mlos_bench/mlos_bench/launcher.py index c728ed7fb20..353ace23f0e 100644 --- a/mlos_bench/mlos_bench/launcher.py +++ b/mlos_bench/mlos_bench/launcher.py @@ -55,8 +55,9 @@ def __init__(self, description: str, long_text: str = "", argv: list[str] | None Other required_args values can also be pulled from shell environment variables. - For additional details, please see the website or the README.md files in - the source tree: + For additional details, please see the documentation website or the + README.md files in the source tree: + """ parser = argparse.ArgumentParser(description=f"{description} : {long_text}", epilog=epilog) diff --git a/mlos_bench/mlos_bench/optimizers/base_optimizer.py b/mlos_bench/mlos_bench/optimizers/base_optimizer.py index 44aa9a035e2..72b437b320e 100644 --- a/mlos_bench/mlos_bench/optimizers/base_optimizer.py +++ b/mlos_bench/mlos_bench/optimizers/base_optimizer.py @@ -356,8 +356,10 @@ def _get_scores( assert scores is not None target_metrics: dict[str, float] = {} for opt_target, opt_dir in self._opt_targets.items(): + if opt_target not in scores: + raise ValueError(f"Score for {opt_target} not found in {scores}.") val = scores[opt_target] - assert val is not None + assert val is not None, f"Score for {opt_target} is None." 
target_metrics[opt_target] = float(val) * opt_dir return target_metrics diff --git a/mlos_bench/mlos_bench/optimizers/mock_optimizer.py b/mlos_bench/mlos_bench/optimizers/mock_optimizer.py index 947e34a7da4..a1311b6f953 100644 --- a/mlos_bench/mlos_bench/optimizers/mock_optimizer.py +++ b/mlos_bench/mlos_bench/optimizers/mock_optimizer.py @@ -14,6 +14,7 @@ import logging import random from collections.abc import Callable, Sequence +from dataclasses import dataclass from mlos_bench.environments.status import Status from mlos_bench.optimizers.track_best_optimizer import TrackBestOptimizer @@ -25,6 +26,15 @@ _LOG = logging.getLogger(__name__) +@dataclass +class RegisteredScore: + """A registered score for a trial.""" + + config: TunableGroups + score: dict[str, TunableValue] | None + status: Status + + class MockOptimizer(TrackBestOptimizer): """Mock optimizer to test the Environment API.""" @@ -42,6 +52,39 @@ def __init__( "float": lambda tunable: rnd.uniform(*tunable.range), "int": lambda tunable: rnd.randint(*(int(x) for x in tunable.range)), } + self._registered_scores: list[RegisteredScore] = [] + + @property + def registered_scores(self) -> list[RegisteredScore]: + """ + Return the list of registered scores. + + Notes + ----- + Used for testing and validation. + """ + return self._registered_scores + + def register( + self, + tunables: TunableGroups, + status: Status, + score: dict[str, TunableValue] | None = None, + ) -> dict[str, float] | None: + # Track the registered scores for testing and validation. + score = score or {} + # Almost the same as _get_scores, but we don't adjust the direction here. + scores: dict[str, TunableValue] = { + k: float(v) for k, v in score.items() if k in self._opt_targets and v is not None + } + self._registered_scores.append( + RegisteredScore( + config=tunables.copy(), + score=scores, + status=status, + ) + ) + return super().register(tunables, status, score) def bulk_register( self, diff --git a/mlos_bench/mlos_bench/schedulers/base_scheduler.py b/mlos_bench/mlos_bench/schedulers/base_scheduler.py index 1cd88fd5859..ee0d00757ae 100644 --- a/mlos_bench/mlos_bench/schedulers/base_scheduler.py +++ b/mlos_bench/mlos_bench/schedulers/base_scheduler.py @@ -100,8 +100,9 @@ def __init__( # pylint: disable=too-many-arguments self._optimizer = optimizer self._storage = storage self._root_env_config = root_env_config - self._last_trial_id = -1 self._ran_trials: list[Storage.Trial] = [] + self._longest_finished_trial_sequence_id = -1 + self._registered_trial_ids: set[int] = set() _LOG.debug("Scheduler instantiated: %s :: %s", self, config) @@ -242,8 +243,15 @@ def __exit__( self._in_context = False return False # Do not suppress exceptions - def start(self) -> None: - """Start the scheduling loop.""" + def _prepare_start(self) -> bool: + """ + Prepare the scheduler for starting. + + Notes + ----- + This method is called by the :py:meth:`Scheduler.start` method. + It is split out mostly to allow for easier unit testing/mocking. 
+ """ assert self.experiment is not None _LOG.info( "START: Experiment: %s Env: %s Optimizer: %s", @@ -262,21 +270,43 @@ def start(self) -> None: is_warm_up: bool = self.optimizer.supports_preload if not is_warm_up: _LOG.warning("Skip pending trials and warm-up: %s", self.optimizer) + return is_warm_up + def start(self) -> None: + """Start the scheduling loop.""" + assert self.experiment is not None + is_warm_up = self._prepare_start() not_done: bool = True while not_done: - _LOG.info("Optimization loop: Last trial ID: %d", self._last_trial_id) - self.run_schedule(is_warm_up) - not_done = self.add_new_optimizer_suggestions() - self.assign_trial_runners( - self.experiment.pending_trials( - datetime.now(UTC), - running=False, - trial_runner_assigned=False, - ) - ) + not_done = self._execute_scheduling_step(is_warm_up) is_warm_up = False + def _execute_scheduling_step(self, is_warm_up: bool) -> bool: + """ + Perform a single scheduling step. + + Notes + ----- + This method is called by the :py:meth:`Scheduler.start` method. + It is split out mostly to allow for easier unit testing/mocking. + """ + assert self.experiment is not None + _LOG.info( + "Optimization loop: Longest finished trial sequence ID: %d", + self._longest_finished_trial_sequence_id, + ) + self.run_schedule(is_warm_up) + self.bulk_register_completed_trials() + not_done = self.add_new_optimizer_suggestions() + self.assign_trial_runners( + self.experiment.pending_trials( + datetime.now(UTC), + running=False, + trial_runner_assigned=False, + ) + ) + return not_done + def teardown(self) -> None: """ Tear down the TrialRunners/Environment(s). @@ -309,10 +339,65 @@ def load_tunable_config(self, config_id: int) -> TunableGroups: _LOG.debug("Config %d ::\n%s", config_id, json.dumps(tunable_values, indent=2)) return tunables.copy() + def bulk_register_completed_trials(self) -> None: + """ + Bulk register the most recent completed Trials in the Storage. + + Notes + ----- + This method is called after the Trials have been run (or started) and + the results have been recorded in the Storage by the TrialRunner(s). + + It has logic to handle straggler Trials that finish out of order so + should be usable by both + :py:class:`~mlos_bench.schedulers.SyncScheduler` and async Schedulers. + + See Also + -------- + Scheduler.start + The main loop of the Scheduler. + Storage.Experiment.load + Load the results of the Trials based on some filtering criteria. + Optimizer.bulk_register + Register the results of the Trials in the Optimizer. + """ + assert self.experiment is not None + # Load the results of the trials that have been run since the last time + # we queried the Optimizer. + # Note: We need to handle the case of straggler trials that finish out of order. + (trial_ids, configs, scores, status) = self.experiment.load( + last_trial_id=self._longest_finished_trial_sequence_id, + omit_registered_trial_ids=self._registered_trial_ids, + ) + _LOG.info("QUEUE: Update the optimizer with trial results: %s", trial_ids) + self.optimizer.bulk_register(configs, scores, status) + # Mark those trials as registered so we don't load them again. + self._registered_trial_ids.update(trial_ids) + # Update the longest finished trial sequence ID. + self._longest_finished_trial_sequence_id = max( + [ + self.experiment.get_longest_prefix_finished_trial_id(), + self._longest_finished_trial_sequence_id, + ], + default=self._longest_finished_trial_sequence_id, + ) + # Remove trial ids that are older than the longest finished trial sequence ID. 
+ # This is an optimization to avoid a long list of trial ids to omit from + # the load() operation or a long list of trial ids to maintain in memory. + self._registered_trial_ids = { + trial_id + for trial_id in self._registered_trial_ids + if trial_id > self._longest_finished_trial_sequence_id + } + def add_new_optimizer_suggestions(self) -> bool: """ Optimizer part of the loop. + Asks the :py:class:`~.Optimizer` for new suggestions and adds them to + the queue. This method is called after the trials have been run and the + results have been loaded into the optimizer. + Load the results of the executed trials into the :py:class:`~.Optimizer`, suggest new configurations, and add them to the queue. @@ -323,16 +408,24 @@ def add_new_optimizer_suggestions(self) -> bool: The return value indicates whether the optimization process should continue to get suggestions from the Optimizer or not. See Also: :py:meth:`~.Scheduler.not_done`. - """ - assert self.experiment is not None - (trial_ids, configs, scores, status) = self.experiment.load(self._last_trial_id) - _LOG.info("QUEUE: Update the optimizer with trial results: %s", trial_ids) - self.optimizer.bulk_register(configs, scores, status) - self._last_trial_id = max(trial_ids, default=self._last_trial_id) + Notes + ----- + Subclasses can override this method to implement a more sophisticated + scheduling policy using the information obtained from the Optimizer. + + See Also + -------- + Scheduler.not_done + The stopping conditions for the optimization process. + + Scheduler.bulk_register_completed_trials + Bulk register the most recent completed trials in the storage. + """ # Check if the optimizer has converged or not. not_done = self.not_done() if not_done: + # TODO: Allow scheduling multiple configs at once (e.g., in the case of idle workers). tunables = self.optimizer.suggest() self.add_trial_to_queue(tunables) return not_done diff --git a/mlos_bench/mlos_bench/storage/base_storage.py b/mlos_bench/mlos_bench/storage/base_storage.py index 75a84bf0b2e..81e8148d26e 100644 --- a/mlos_bench/mlos_bench/storage/base_storage.py +++ b/mlos_bench/mlos_bench/storage/base_storage.py @@ -26,7 +26,7 @@ import logging from abc import ABCMeta, abstractmethod -from collections.abc import Iterator, Mapping +from collections.abc import Iterable, Iterator, Mapping from contextlib import AbstractContextManager as ContextManager from datetime import datetime from types import TracebackType @@ -324,27 +324,55 @@ def load_telemetry(self, trial_id: int) -> list[tuple[datetime, str, Any]]: Telemetry data. """ + @abstractmethod + def get_longest_prefix_finished_trial_id(self) -> int: + """ + Calculate the last trial ID for the experiment. + + This is used to determine the last trial ID that finished (failed or + successful) such that all Trials before it are also finished. + """ + @abstractmethod def load( self, last_trial_id: int = -1, + omit_registered_trial_ids: Iterable[int] | None = None, ) -> tuple[list[int], list[dict], list[dict[str, Any] | None], list[Status]]: """ Load (tunable values, benchmark scores, status) to warm-up the optimizer. - If `last_trial_id` is present, load only the data from the (completed) trials - that were scheduled *after* the given trial ID. Otherwise, return data from ALL - merged-in experiments and attempt to impute the missing tunable values. + If `last_trial_id` is present, load only the data from the + (:py:meth:`completed `) trials that were + added *after* the given trial ID. 
Otherwise, return data from + ALL merged-in experiments and attempt to impute the missing tunable + values. + + Additionally, if ``omit_registered_trial_ids`` is provided, omit the + trials matching those ids. + + The parameters together allow us to efficiently load data from + finished trials that we haven't registered with the Optimizer yet + for bulk registering. Parameters ---------- last_trial_id : int (Optional) Trial ID to start from. + omit_registered_trial_ids : Iterable[int] | None = None, + (Optional) List of trial IDs to omit. If None, load all trials + after ``last_trial_id``. Returns ------- (trial_ids, configs, scores, status) : ([int], [dict], [dict] | None, [Status]) Trial ids, Tunable values, benchmark scores, and status of the trials. + + See Also + -------- + Storage.Experiment.get_longest_prefix_finished_trial_id : + Get the last (registered) trial ID for the experiment. + Scheduler.add_new_optimizer_suggestions """ @abstractmethod diff --git a/mlos_bench/mlos_bench/storage/sql/common.py b/mlos_bench/mlos_bench/storage/sql/common.py index 97eb270c9d9..032cf9259d8 100644 --- a/mlos_bench/mlos_bench/storage/sql/common.py +++ b/mlos_bench/mlos_bench/storage/sql/common.py @@ -95,7 +95,7 @@ def get_trials( config_id=trial.config_id, ts_start=utcify_timestamp(trial.ts_start, origin="utc"), ts_end=utcify_nullable_timestamp(trial.ts_end, origin="utc"), - status=Status.from_str(trial.status), + status=Status.parse(trial.status), trial_runner_id=trial.trial_runner_id, ) for trial in trials.fetchall() diff --git a/mlos_bench/mlos_bench/storage/sql/experiment.py b/mlos_bench/mlos_bench/storage/sql/experiment.py index acc2a497b48..2fc31700242 100644 --- a/mlos_bench/mlos_bench/storage/sql/experiment.py +++ b/mlos_bench/mlos_bench/storage/sql/experiment.py @@ -8,7 +8,7 @@ import hashlib import logging -from collections.abc import Iterator +from collections.abc import Iterable, Iterator from datetime import datetime from typing import Any, Literal @@ -153,13 +153,59 @@ def load_telemetry(self, trial_id: int) -> list[tuple[datetime, str, Any]]: for row in cur_telemetry.fetchall() ] + # TODO: Add a test for this method. + def get_longest_prefix_finished_trial_id(self) -> int: + with self._engine.connect() as conn: + # TODO: Do this in a single query? + + # Get the first (minimum) trial ID with an unfinished status. + first_unfinished_trial_id_stmt = ( + self._schema.trial.select() + .with_only_columns( + func.min(self._schema.trial.c.trial_id), + ) + .where( + self._schema.trial.c.exp_id == self._experiment_id, + func.not_( + self._schema.trial.c.status.in_( + [status.name for status in Status.completed_statuses()] + ) + ), + ) + ) + max_trial_id = conn.execute(first_unfinished_trial_id_stmt).scalar() + if max_trial_id is not None: + # Return one less than the first unfinished trial ID - it should be + # finished (or not exist, which is fine as a limit). + return int(max_trial_id) - 1 + + # No unfinished trials, so *all* trials are completed - get the + # largest completed trial ID. + last_finished_trial_id = ( + self._schema.trial.select() + .with_only_columns( + func.max(self._schema.trial.c.trial_id), + ) + .where( + self._schema.trial.c.exp_id == self._experiment_id, + self._schema.trial.c.status.in_( + [status.name for status in Status.completed_statuses()] + ), + ) + ) + max_trial_id = conn.execute(last_finished_trial_id).scalar() + if max_trial_id is not None: + return int(max_trial_id) + # Else no trials yet exist for this experiment. 
+            return -1
+
     def load(
         self,
         last_trial_id: int = -1,
+        omit_registered_trial_ids: Iterable[int] | None = None,
     ) -> tuple[list[int], list[dict], list[dict[str, Any] | None], list[Status]]:
         with self._engine.connect() as conn:
-            cur_trials = conn.execute(
+            stmt = (
                 self._schema.trial.select()
                 .with_only_columns(
                     self._schema.trial.c.trial_id,
@@ -170,11 +216,7 @@ def load(
                     self._schema.trial.c.exp_id == self._experiment_id,
                     self._schema.trial.c.trial_id > last_trial_id,
                     self._schema.trial.c.status.in_(
-                        [
-                            Status.SUCCEEDED.name,
-                            Status.FAILED.name,
-                            Status.TIMED_OUT.name,
-                        ]
+                        [status.name for status in Status.completed_statuses()]
                     ),
                 )
                 .order_by(
                 )
             )
+
+            # TODO: Add a test for this parameter.
+
+            # Note: if we have a very large number of trials, this may encounter
+            # SQL text length limits, so we may need to chunk this.
+            if omit_registered_trial_ids is not None:
+                stmt = stmt.where(self._schema.trial.c.trial_id.notin_(omit_registered_trial_ids))
+
+            cur_trials = conn.execute(stmt)
+
             trial_ids: list[int] = []
             configs: list[dict[str, Any]] = []
             scores: list[dict[str, Any] | None] = []
             status: list[Status] = []
 
             for trial in cur_trials.fetchall():
-                stat = Status.from_str(trial.status)
+                stat = Status.parse(trial.status)
                 status.append(stat)
                 trial_ids.append(trial.trial_id)
                 configs.append(
@@ -272,7 +323,7 @@ def get_trial_by_id(
             config_id=trial.config_id,
             trial_runner_id=trial.trial_runner_id,
             opt_targets=self._opt_targets,
-            status=Status.from_str(trial.status),
+            status=Status.parse(trial.status),
             restoring=True,
             config=config,
         )
@@ -330,7 +381,7 @@ def pending_trials(
             config_id=trial.config_id,
             trial_runner_id=trial.trial_runner_id,
             opt_targets=self._opt_targets,
-            status=Status.from_str(trial.status),
+            status=Status.parse(trial.status),
             restoring=True,
             config=config,
         )
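
A condensed sketch of the calling pattern these new Storage.Experiment APIs are meant to support (it mirrors Scheduler.bulk_register_completed_trials above); the helper name and the import paths are assumptions for illustration only:

    from mlos_bench.optimizers.base_optimizer import Optimizer
    from mlos_bench.storage.base_storage import Storage


    def register_new_results(
        experiment: Storage.Experiment,
        optimizer: Optimizer,
        watermark: int,
        registered_trial_ids: set[int],
    ) -> int:
        """Push newly finished trial results to the optimizer; return the new watermark."""
        # Load only finished trials that we have not already registered.
        (trial_ids, configs, scores, status) = experiment.load(
            last_trial_id=watermark,
            omit_registered_trial_ids=registered_trial_ids,
        )
        optimizer.bulk_register(configs, scores, status)
        registered_trial_ids.update(trial_ids)  # updated in place for the caller
        # Advance the watermark past the longest fully finished prefix and prune
        # ids at or below it - load() already filters those via last_trial_id.
        watermark = max(experiment.get_longest_prefix_finished_trial_id(), watermark)
        registered_trial_ids -= {tid for tid in registered_trial_ids if tid <= watermark}
        return watermark
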
CONFIG_TYPE = "schedulers" @@ -43,6 +55,8 @@ def filter_configs(configs_to_filter: list[str]) -> list[str]: ) assert configs +<<<<<<< HEAD +======= test_configs = locate_config_examples( BUILTIN_TEST_CONFIG_PATH, CONFIG_TYPE, @@ -51,6 +65,7 @@ def filter_configs(configs_to_filter: list[str]) -> list[str]: # assert test_configs configs.extend(test_configs) +>>>>>>> refactor/mock-scheduler-and-tests @pytest.mark.parametrize("config_path", configs) def test_load_scheduler_config_examples( diff --git a/mlos_bench/mlos_bench/tests/config/schemas/environments/test-cases/bad/invalid/mock_env-bad-trial-data-fields.jsonc b/mlos_bench/mlos_bench/tests/config/schemas/environments/test-cases/bad/invalid/mock_env-bad-trial-data-fields.jsonc new file mode 100644 index 00000000000..d36559cf334 --- /dev/null +++ b/mlos_bench/mlos_bench/tests/config/schemas/environments/test-cases/bad/invalid/mock_env-bad-trial-data-fields.jsonc @@ -0,0 +1,24 @@ +{ + "class": "mlos_bench.environments.mock_env.MockEnv", + "config": { + "mock_trial_data": { + "1": { + "run": { + // bad types + "status": null, + "metrics": [], + "exception": null, + "sleep": "1", + }, + // missing fields + "setup": {}, + "teardown": { + "status": "UNKNOWN", + "metrics": { + "unexpected": "field" + } + } + } + } + } +} diff --git a/mlos_bench/mlos_bench/tests/config/schemas/environments/test-cases/bad/invalid/mock_env-bad-trial-data-ids.jsonc b/mlos_bench/mlos_bench/tests/config/schemas/environments/test-cases/bad/invalid/mock_env-bad-trial-data-ids.jsonc new file mode 100644 index 00000000000..400e557d0fa --- /dev/null +++ b/mlos_bench/mlos_bench/tests/config/schemas/environments/test-cases/bad/invalid/mock_env-bad-trial-data-ids.jsonc @@ -0,0 +1,13 @@ +{ + "class": "mlos_bench.environments.mock_env.MockEnv", + "config": { + "mock_trial_data": { + // invalid trial id + "trial_id_1": { + "run": { + "status": "FAILED" + } + } + } + } +} diff --git a/mlos_bench/mlos_bench/tests/config/schemas/environments/test-cases/bad/unhandled/mock_env-trial-data-extras.jsonc b/mlos_bench/mlos_bench/tests/config/schemas/environments/test-cases/bad/unhandled/mock_env-trial-data-extras.jsonc new file mode 100644 index 00000000000..ecdf4cd0f51 --- /dev/null +++ b/mlos_bench/mlos_bench/tests/config/schemas/environments/test-cases/bad/unhandled/mock_env-trial-data-extras.jsonc @@ -0,0 +1,15 @@ +{ + "class": "mlos_bench.environments.mock_env.MockEnv", + "config": { + "mock_trial_data": { + "1": { + "new_phase": { + "status": "FAILED" + }, + "run": { + "expected": "property" + } + } + } + } +} diff --git a/mlos_bench/mlos_bench/tests/config/schemas/environments/test-cases/good/full/mock_env-full.jsonc b/mlos_bench/mlos_bench/tests/config/schemas/environments/test-cases/good/full/mock_env-full.jsonc index a00f8ca60c0..a23971f0362 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/environments/test-cases/good/full/mock_env-full.jsonc +++ b/mlos_bench/mlos_bench/tests/config/schemas/environments/test-cases/good/full/mock_env-full.jsonc @@ -25,6 +25,38 @@ "mock_env_metrics": [ "latency", "cost" - ] + ], + "mock_trial_data": { + "1": { + "setup": { + "sleep": 0.1 + }, + "status": { + "metrics": { + "latency": 0.2, + "cost": 0.3 + } + }, + "run": { + "sleep": 0.2, + "status": "SUCCEEDED", + "metrics": { + "latency": 0.1, + "cost": 0.2 + } + }, + "teardown": { + "sleep": 0.1 + } + }, + "2": { + "setup": { + "exception": "Some exception" + }, + "teardown": { + "exception": "Some other exception" + } + } + } } } diff --git 
a/mlos_bench/mlos_bench/tests/config/schemas/globals/test-cases/good/full/globals-with-schema.jsonc b/mlos_bench/mlos_bench/tests/config/schemas/globals/test-cases/good/full/globals-with-schema.jsonc index 58a0a31bb36..4ed580e09a9 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/globals/test-cases/good/full/globals-with-schema.jsonc +++ b/mlos_bench/mlos_bench/tests/config/schemas/globals/test-cases/good/full/globals-with-schema.jsonc @@ -10,5 +10,20 @@ "mysql": ["mysql-innodb", "mysql-myisam", "mysql-binlog", "mysql-hugepages"] }, "experiment_id": "ExperimentName", - "trial_id": 1 + "trial_id": 1, + + "mock_trial_data": { + "1": { + "setup": { + "sleep": 1 + }, + "run": { + "status": "SUCCEEDED", + "metrics": { + "score": 0.9, + "color": "green" + } + } + } + } } diff --git a/mlos_bench/mlos_bench/tests/environments/test_status.py b/mlos_bench/mlos_bench/tests/environments/test_status.py index 3c0a9bccf3c..785275825c0 100644 --- a/mlos_bench/mlos_bench/tests/environments/test_status.py +++ b/mlos_bench/mlos_bench/tests/environments/test_status.py @@ -51,16 +51,19 @@ def test_status_from_str_valid(input_str: str, expected_status: Status) -> None: Expected Status enum value. """ assert ( - Status.from_str(input_str) == expected_status + Status.parse(input_str) == expected_status ), f"Expected {expected_status} for input: {input_str}" # Check lowercase representation assert ( - Status.from_str(input_str.lower()) == expected_status + Status.parse(input_str.lower()) == expected_status ), f"Expected {expected_status} for input: {input_str.lower()}" + assert ( + Status.parse(expected_status) == expected_status + ), f"Expected {expected_status} for input: {expected_status}" if input_str.isdigit(): # Also test the numeric representation assert ( - Status.from_str(int(input_str)) == expected_status + Status.parse(int(input_str)) == expected_status ), f"Expected {expected_status} for input: {int(input_str)}" @@ -83,7 +86,7 @@ def test_status_from_str_invalid(invalid_input: Any) -> None: input. """ assert ( - Status.from_str(invalid_input) == Status.UNKNOWN + Status.parse(invalid_input) == Status.UNKNOWN ), f"Expected Status.UNKNOWN for invalid input: {invalid_input}" diff --git a/mlos_bench/mlos_bench/tests/schedulers/__init__.py b/mlos_bench/mlos_bench/tests/schedulers/__init__.py new file mode 100644 index 00000000000..4bc0076079f --- /dev/null +++ b/mlos_bench/mlos_bench/tests/schedulers/__init__.py @@ -0,0 +1,5 @@ +# +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# +"""mlos_bench.tests.schedulers.""" diff --git a/mlos_bench/mlos_bench/tests/schedulers/conftest.py b/mlos_bench/mlos_bench/tests/schedulers/conftest.py new file mode 100644 index 00000000000..df6bd2776fd --- /dev/null +++ b/mlos_bench/mlos_bench/tests/schedulers/conftest.py @@ -0,0 +1,123 @@ +# +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+# +"""Pytest fixtures for mlos_bench.schedulers tests.""" +# pylint: disable=redefined-outer-name + +import json +import re + +import pytest +from pytest import FixtureRequest + +from mlos_bench.environments.mock_env import MockEnv +from mlos_bench.schedulers.trial_runner import TrialRunner +from mlos_bench.services.config_persistence import ConfigPersistenceService +from mlos_bench.tunables.tunable_groups import TunableGroups + +NUM_TRIAL_RUNNERS = 4 + + +@pytest.fixture +def mock_env_config() -> dict: + """A config for a MockEnv with mock_trial_data.""" + return { + "name": "Test MockEnv With Explicit Mock Trial Data", + "class": "mlos_bench.environments.mock_env.MockEnv", + "config": { + # Reference the covariant groups from the `tunable_groups` fixture. + # See Also: + # - mlos_bench/tests/conftest.py + # - mlos_bench/tests/tunable_groups_fixtures.py + "tunable_params": ["provision", "boot", "kernel"], + "mock_env_seed": -1, + "mock_env_range": [0, 10], + "mock_env_metrics": ["score"], + # TODO: Add more mock trial data here: + "mock_trial_data": { + "1": { + "run": { + "sleep": 0.15, + "status": "SUCCEEDED", + "metrics": { + "score": 1.0, + }, + }, + }, + "2": { + "run": { + "sleep": 0.2, + "status": "SUCCEEDED", + "metrics": { + "score": 2.0, + }, + }, + }, + "3": { + "run": { + "sleep": 0.1, + "status": "SUCCEEDED", + "metrics": { + "score": 3.0, + }, + }, + }, + }, + }, + } + + +@pytest.fixture +def global_config(request: FixtureRequest) -> dict: + """A global config for a MockEnv.""" + test_name = request.node.name + test_name = re.sub(r"[^a-zA-Z0-9]", "_", test_name) + experiment_id = f"TestExperiment-{test_name}" + return { + "experiment_id": experiment_id, + "trial_id": 1, + } + + +@pytest.fixture +def mock_env_json_config(mock_env_config: dict) -> str: + """A JSON string of the mock_env_config.""" + return json.dumps(mock_env_config) + + +@pytest.fixture +def mock_env( + mock_env_json_config: str, + tunable_groups: TunableGroups, + global_config: dict, +) -> MockEnv: + """A fixture to create a MockEnv instance using the mock_env_json_config.""" + config_loader_service = ConfigPersistenceService() + mock_env = config_loader_service.load_environment( + mock_env_json_config, + tunable_groups, + service=config_loader_service, + global_config=global_config, + ) + assert isinstance(mock_env, MockEnv) + return mock_env + + +@pytest.fixture +def trial_runners( + mock_env_json_config: str, + tunable_groups: TunableGroups, + global_config: dict, +) -> list[TrialRunner]: + """A fixture to create a list of TrialRunner instances using the + mock_env_json_config. + """ + config_loader_service = ConfigPersistenceService(global_config=global_config) + return TrialRunner.create_from_json( + config_loader=config_loader_service, + env_json=mock_env_json_config, + tunable_groups=tunable_groups, + num_trial_runners=NUM_TRIAL_RUNNERS, + global_config=global_config, + ) diff --git a/mlos_bench/mlos_bench/tests/schedulers/test_scheduler.py b/mlos_bench/mlos_bench/tests/schedulers/test_scheduler.py new file mode 100644 index 00000000000..9a2cc1dbaec --- /dev/null +++ b/mlos_bench/mlos_bench/tests/schedulers/test_scheduler.py @@ -0,0 +1,213 @@ +# +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+# +"""Unit tests for :py:class:`mlos_bench.schedulers` and their internals.""" + +import sys +from logging import warning + +import pytest + +import mlos_bench.tests.optimizers.fixtures as optimizers_fixtures +import mlos_bench.tests.storage.sql.fixtures as sql_storage_fixtures +from mlos_bench.environments.mock_env import MockEnv +from mlos_bench.optimizers.mock_optimizer import MockOptimizer +from mlos_bench.schedulers.base_scheduler import Scheduler +from mlos_bench.schedulers.trial_runner import TrialRunner +from mlos_bench.storage.base_trial_data import TrialData +from mlos_bench.storage.sql.storage import SqlStorage +from mlos_core.tests import get_all_concrete_subclasses + +mock_opt = optimizers_fixtures.mock_opt +sqlite_storage = sql_storage_fixtures.sqlite_storage + +DEBUG_WARNINGS_ENABLED = False + +# pylint: disable=redefined-outer-name + + +def create_scheduler( + scheduler_type: type[Scheduler], + trial_runners: list[TrialRunner], + mock_opt: MockOptimizer, + sqlite_storage: SqlStorage, + global_config: dict, +) -> Scheduler: + """Create a Scheduler instance using trial_runners, mock_opt, and sqlite_storage + fixtures. + """ + env = trial_runners[0].environment + assert isinstance(env, MockEnv), "Environment is not a MockEnv instance." + max_trials = max(trial_id for trial_id in env.mock_trial_data.keys()) + max_trials = min(max_trials, mock_opt.max_suggestions) + + return scheduler_type( + config={ + "max_trials": max_trials, + }, + global_config=global_config, + trial_runners=trial_runners, + optimizer=mock_opt, + storage=sqlite_storage, + root_env_config="", + ) + + +def debug_warn(*args: object) -> None: + """Optionally issue warnings for debugging.""" + if DEBUG_WARNINGS_ENABLED: + warning(*args) + + +def mock_opt_has_registered_trial_score( + mock_opt: MockOptimizer, + trial_data: TrialData, +) -> bool: + """Check that the MockOptimizer has registered a given MockTrialData.""" + # pylint: disable=consider-using-any-or-all + # Split out for easier debugging. + for registered_score in mock_opt.registered_scores: + match = True + if registered_score.status != trial_data.status: + match = False + debug_warn( + "Registered status %s does not match trial data %s.", + registered_score.status, + trial_data.results_dict, + ) + elif registered_score.score != trial_data.results_dict: + debug_warn( + "Registered score %s does not match trial data %s.", + registered_score.score, + trial_data.results_dict, + ) + match = False + elif registered_score.config.get_param_values() != trial_data.tunable_config.config_dict: + debug_warn( + "Registered config %s does not match trial data %s.", + registered_score.config.get_param_values(), + trial_data.results_dict, + ) + match = False + if match: + return True + return False + + +scheduler_classes = get_all_concrete_subclasses( + Scheduler, # type: ignore[type-abstract] + pkg_name="mlos_bench", +) +assert scheduler_classes, "No Scheduler classes found in mlos_bench." + + +@pytest.mark.parametrize( + "scheduler_class", + scheduler_classes, +) +@pytest.mark.skipif( + sys.platform == "win32", + reason="Skipping test on Windows - SQLite storage is not accessible in parallel tests there.", +) +def test_scheduler_with_mock_trial_data( + scheduler_class: type[Scheduler], + trial_runners: list[TrialRunner], + mock_opt: MockOptimizer, + sqlite_storage: SqlStorage, + global_config: dict, +) -> None: + """ + Full integration test for Scheduler: runs trials, checks storage, optimizer + registration, and internal bookkeeping. 
+ """ + # pylint: disable=too-many-locals + + # Create the scheduler. + scheduler = create_scheduler( + scheduler_class, + trial_runners, + mock_opt, + sqlite_storage, + global_config, + ) + + root_env = scheduler.root_environment + experiment_id = root_env.experiment_id + assert isinstance(root_env, MockEnv), f"Root environment {root_env} is not a MockEnv." + assert root_env.mock_trial_data, "No mock trial data found in root environment." + + # Run the scheduler + with scheduler: + scheduler.start() + scheduler.teardown() + + # Now check the overall results. + ran_trials = {trial.trial_id for trial in scheduler.ran_trials} + assert ( + experiment_id in sqlite_storage.experiments + ), f"Experiment {experiment_id} not found in storage." + exp_data = sqlite_storage.experiments[experiment_id] + + for mock_trial_data in root_env.mock_trial_data.values(): + trial_id = mock_trial_data.trial_id + + # Check the bookkeeping for ran_trials. + assert trial_id in ran_trials, f"Trial {trial_id} not found in Scheduler.ran_trials." + + # Check the results in storage. + assert trial_id in exp_data.trials, f"Trial {trial_id} not found in storage." + trial_data = exp_data.trials[trial_id] + + # Check the results. + metrics = mock_trial_data.run.metrics + if metrics: + for result_key, result_value in metrics.items(): + assert ( + result_key in trial_data.results_dict + ), f"Result {result_key} not found in storage for trial {trial_data}." + assert ( + trial_data.results_dict[result_key] == result_value + ), f"Result value for {result_key} does not match expected value." + # TODO: Should we check the reverse - no extra metrics were registered? + # else: metrics weren't explicit in the mock trial data, so we only + # check that a score was stored for the optimization target, but that's + # good to do regardless + for opt_target in mock_opt.targets: + assert ( + opt_target in trial_data.results_dict + ), f"Result column {opt_target} not found in storage." + assert ( + trial_data.results_dict[opt_target] is not None + ), f"Result value for {opt_target} is None." + + # Check that the appropriate sleeps occurred. + trial_time_lb = 0.0 + trial_time_lb += mock_trial_data.setup.sleep or 0 + trial_time_lb += mock_trial_data.run.sleep or 0 + trial_time_lb += mock_trial_data.status.sleep or 0 + trial_time_lb += mock_trial_data.teardown.sleep or 0 + assert trial_data.ts_end is not None, f"Trial {trial_id} has no end time." + trial_duration = trial_data.ts_end - trial_data.ts_start + trial_dur_secs = trial_duration.total_seconds() + assert ( + trial_dur_secs >= trial_time_lb + ), f"Trial {trial_id} took less time ({trial_dur_secs}) than expected ({trial_time_lb}). " + + # Check that the trial status matches what we expected. + assert ( + trial_data.status == mock_trial_data.run.status + ), f"Trial {trial_id} status {trial_data.status} was not {mock_trial_data.run.status}." + + # TODO: Check the trial status telemetry. + + # Check the optimizer registration. + assert mock_opt_has_registered_trial_score( + mock_opt, + trial_data, + ), f"Trial {trial_id} was not registered in the optimizer." + + # TODO: And check the intermediary results. + # 4. Check the bookkeeping for add_new_optimizer_suggestions and _last_trial_id. + # This last part may require patching and intercepting during the start() + # loop to validate in-progress book keeping instead of just overall. 
diff --git a/mlos_bench/mlos_bench/tests/storage/exp_load_test.py b/mlos_bench/mlos_bench/tests/storage/exp_load_test.py index e07cf80c70a..f82833dd64f 100644 --- a/mlos_bench/mlos_bench/tests/storage/exp_load_test.py +++ b/mlos_bench/mlos_bench/tests/storage/exp_load_test.py @@ -3,7 +3,8 @@ # Licensed under the MIT License. # """Unit tests for the storage subsystem.""" -from datetime import datetime, tzinfo +from datetime import datetime, timedelta, tzinfo +from random import random import pytest from pytz import UTC @@ -157,3 +158,265 @@ def test_exp_trial_pending_3( assert status == [Status.FAILED, Status.SUCCEEDED] assert tunable_groups.copy().assign(configs[0]).reset() == trial_fail.tunables assert tunable_groups.copy().assign(configs[1]).reset() == trial_succ.tunables + + +def test_empty_get_longest_prefix_finished_trial_id( + storage: Storage, + exp_storage: Storage.Experiment, +) -> None: + """Test that the longest prefix of finished trials is empty when no trials are + present. + """ + assert not storage.experiments[ + exp_storage.experiment_id + ].trials, "Expected no trials in the experiment." + + # Retrieve the longest prefix of finished trials when no trials are present + longest_prefix_id = exp_storage.get_longest_prefix_finished_trial_id() + + # Assert that the longest prefix is empty + assert ( + longest_prefix_id == -1 + ), f"Expected longest prefix to be -1, but got {longest_prefix_id}" + + +def test_sync_success_get_longest_prefix_finished_trial_id( + exp_storage: Storage.Experiment, + tunable_groups: TunableGroups, +) -> None: + """Test that the longest prefix of finished trials is returned correctly when all + trial are finished. + """ + timestamp = datetime.now(UTC) + config = {} + metrics = {metric: random() for metric in exp_storage.opt_targets} + + # Create several trials + trials = [exp_storage.new_trial(tunable_groups, config=config) for _ in range(0, 4)] + + # Mark some trials at the beginning and end as finished + trials[0].update(Status.SUCCEEDED, timestamp + timedelta(minutes=1), metrics=metrics) + trials[1].update(Status.FAILED, timestamp + timedelta(minutes=2), metrics=metrics) + trials[2].update(Status.TIMED_OUT, timestamp + timedelta(minutes=3), metrics=metrics) + trials[3].update(Status.CANCELED, timestamp + timedelta(minutes=4), metrics=metrics) + + # Retrieve the longest prefix of finished trials starting from trial_id 1 + longest_prefix_id = exp_storage.get_longest_prefix_finished_trial_id() + + # Assert that the longest prefix includes only the first three trials + assert longest_prefix_id == trials[3].trial_id, ( + f"Expected longest prefix to end at trial_id {trials[3].trial_id}, " + f"but got {longest_prefix_id}" + ) + + +def test_async_get_longest_prefix_finished_trial_id( + exp_storage: Storage.Experiment, + tunable_groups: TunableGroups, +) -> None: + """Test that the longest prefix of finished trials is returned correctly when trial + finish out of order. 
+ """
+ timestamp = datetime.now(UTC)
+ config = {}
+ metrics = {metric: random() for metric in exp_storage.opt_targets}
+
+ # Create several trials
+ trials = [exp_storage.new_trial(tunable_groups, config=config) for _ in range(0, 10)]
+
+ # Mark some trials at the beginning and end as finished
+ trials[0].update(Status.SUCCEEDED, timestamp + timedelta(minutes=1), metrics=metrics)
+ trials[1].update(Status.FAILED, timestamp + timedelta(minutes=2), metrics=metrics)
+ trials[2].update(Status.TIMED_OUT, timestamp + timedelta(minutes=3), metrics=metrics)
+ trials[3].update(Status.CANCELED, timestamp + timedelta(minutes=4), metrics=metrics)
+ # Leave trials[4] to trials[8] as PENDING
+ trials[9].update(Status.SUCCEEDED, timestamp + timedelta(minutes=5), metrics=metrics)
+
+ # Retrieve the longest prefix of finished trials
+ longest_prefix_id = exp_storage.get_longest_prefix_finished_trial_id()
+
+ # Assert that the longest prefix includes only the first four trials (trials[4]-[8] are still PENDING)
+ assert longest_prefix_id == trials[3].trial_id, (
+ f"Expected longest prefix to end at trial_id {trials[3].trial_id}, "
+ f"but got {longest_prefix_id}"
+ )
+
+
+# TODO: Can we simplify this to use something like SyncScheduler and
+# bulk_register_completed_trials?
+# TODO: Update to use MockScheduler
+def test_exp_load_async(
+ exp_storage: Storage.Experiment,
+ tunable_groups: TunableGroups,
+) -> None:
+ """
+ Test the ``omit_registered_trial_ids`` argument of the ``Experiment.load()`` method.
+
+ Create several trials with mixed statuses (PENDING and completed).
+ Verify that completed trials already included in a local set of registered trial IDs
+ are omitted from the ``load()`` operation.
+ """
+ # pylint: disable=too-many-locals,too-many-statements
+
+ last_trial_id = exp_storage.get_longest_prefix_finished_trial_id()
+ assert last_trial_id == -1, "Expected no trials in the experiment."
+ registered_trial_ids: set[int] = set()
+
+ # Load trials, omitting registered ones
+ trial_ids, configs, scores, status = exp_storage.load(
+ last_trial_id=last_trial_id,
+ omit_registered_trial_ids=registered_trial_ids,
+ )
+
+ assert trial_ids == []
+ assert configs == []
+ assert scores == []
+ assert status == []
+
+ # Create trials with mixed statuses
+ trial_1_success = exp_storage.new_trial(tunable_groups)
+ trial_2_failed = exp_storage.new_trial(tunable_groups)
+ trial_3_pending = exp_storage.new_trial(tunable_groups)
+ trial_4_timedout = exp_storage.new_trial(tunable_groups)
+ trial_5_pending = exp_storage.new_trial(tunable_groups)
+
+ # Update statuses for completed trials
+ trial_1_success.update(Status.SUCCEEDED, datetime.now(UTC), {"score": 95.0})
+ trial_2_failed.update(Status.FAILED, datetime.now(UTC), {"score": -1})
+ trial_4_timedout.update(Status.TIMED_OUT, datetime.now(UTC), {"score": -1})
+
+ # Now evaluate some different sequences of loading trials by simulating what
+ # we expect a Scheduler to do.
+ # See Also: Scheduler.add_new_optimizer_suggestions()
+
+ trial_ids, configs, scores, status = exp_storage.load(
+ last_trial_id=last_trial_id,
+ omit_registered_trial_ids=registered_trial_ids,
+ )
+
+ # Verify that all completed trials are returned.
+ completed_trials = [
+ trial_1_success,
+ trial_2_failed,
+ trial_4_timedout,
+ ]
+ assert trial_ids == [trial.trial_id for trial in completed_trials]
+ assert len(configs) == len(completed_trials)
+ assert status == [trial.status for trial in completed_trials]
+
+ last_trial_id = exp_storage.get_longest_prefix_finished_trial_id()
+ assert last_trial_id == trial_2_failed.trial_id, (
+ f"Expected longest prefix to end at trial_id {trial_2_failed.trial_id}, "
+ f"but got {last_trial_id}"
+ )
+ registered_trial_ids |= {completed_trial.trial_id for completed_trial in completed_trials}
+ registered_trial_ids = {i for i in registered_trial_ids if i > last_trial_id}
+
+ # Create some more trials and update their statuses.
+ # Note: we are leaving some trials in the middle in the PENDING state.
+ trial_6_canceled = exp_storage.new_trial(tunable_groups)
+ trial_7_success2 = exp_storage.new_trial(tunable_groups)
+ trial_6_canceled.update(Status.CANCELED, datetime.now(UTC), {"score": -1})
+ trial_7_success2.update(Status.SUCCEEDED, datetime.now(UTC), {"score": 90.0})
+
+ # Load trials, omitting registered ones
+ trial_ids, configs, scores, status = exp_storage.load(
+ last_trial_id=last_trial_id,
+ omit_registered_trial_ids=registered_trial_ids,
+ )
+ # Verify that only unregistered completed trials are returned
+ completed_trials = [
+ trial_6_canceled,
+ trial_7_success2,
+ ]
+ assert trial_ids == [trial.trial_id for trial in completed_trials]
+ assert len(configs) == len(completed_trials)
+ assert status == [trial.status for trial in completed_trials]
+
+ # Update our tracking of registered trials
+ last_trial_id = exp_storage.get_longest_prefix_finished_trial_id()
+ # Should still be the same as before, since trial_3_pending (earlier in
+ # the sequence) is still in the PENDING state.
+ assert last_trial_id == trial_2_failed.trial_id, (
+ f"Expected longest prefix to end at trial_id {trial_2_failed.trial_id}, "
+ f"but got {last_trial_id}"
+ )
+ registered_trial_ids |= {completed_trial.trial_id for completed_trial in completed_trials}
+ registered_trial_ids = {i for i in registered_trial_ids if i > last_trial_id}
+
+ trial_ids, configs, scores, status = exp_storage.load(
+ last_trial_id=last_trial_id,
+ omit_registered_trial_ids=registered_trial_ids,
+ )
+
+ # Verify that no new trials are returned, since all completed trials are already registered
+ completed_trials = []
+ assert trial_ids == [trial.trial_id for trial in completed_trials]
+ assert len(configs) == len(completed_trials)
+ assert status == [trial.status for trial in completed_trials]
+
+ # Now update the earlier PENDING trial (trial_3_pending) to be TIMED_OUT.
+ trial_3_pending.update(Status.TIMED_OUT, datetime.now(UTC), {"score": -1})
+
+ # Load trials, omitting registered ones
+ trial_ids, configs, scores, status = exp_storage.load(
+ last_trial_id=last_trial_id,
+ omit_registered_trial_ids=registered_trial_ids,
+ )
+
+ # Verify that only unregistered completed trials are returned
+ completed_trials = [
+ trial_3_pending,
+ ]
+ assert trial_ids == [trial.trial_id for trial in completed_trials]
+ assert len(configs) == len(completed_trials)
+ assert status == [trial.status for trial in completed_trials]
+
+ # Update our tracking of registered trials
+ last_trial_id = exp_storage.get_longest_prefix_finished_trial_id()
+ assert last_trial_id == trial_4_timedout.trial_id, (
+ f"Expected longest prefix to end at trial_id {trial_4_timedout.trial_id}, "
+ f"but got {last_trial_id}"
+ )
+ registered_trial_ids |= {completed_trial.trial_id for completed_trial in completed_trials}
+ registered_trial_ids = {i for i in registered_trial_ids if i > last_trial_id}
+
+ # Load trials, omitting registered ones
+ trial_ids, configs, scores, status = exp_storage.load(
+ last_trial_id=last_trial_id,
+ omit_registered_trial_ids=registered_trial_ids,
+ )
+ # Verify that no new trials are returned, since all completed trials are already registered
+ completed_trials = []
+ assert trial_ids == [trial.trial_id for trial in completed_trials]
+ assert len(configs) == len(completed_trials)
+ assert status == [trial.status for trial in completed_trials]
+ # And that the longest prefix is still the same.
+ assert last_trial_id == trial_4_timedout.trial_id, (
+ f"Expected longest prefix to end at trial_id {trial_4_timedout.trial_id}, "
+ f"but got {last_trial_id}"
+ )
+
+ # Mark the remaining PENDING trial (trial_5_pending) as finished.
+ trial_5_pending.update(Status.SUCCEEDED, datetime.now(UTC), {"score": 95.0})
+ # Load trials, omitting registered ones
+ trial_ids, configs, scores, status = exp_storage.load(
+ last_trial_id=last_trial_id,
+ omit_registered_trial_ids=registered_trial_ids,
+ )
+ # Verify that only unregistered completed trials are returned
+ completed_trials = [
+ trial_5_pending,
+ ]
+ assert trial_ids == [trial.trial_id for trial in completed_trials]
+ assert len(configs) == len(completed_trials)
+ assert status == [trial.status for trial in completed_trials]
+ # And that the longest prefix is now the last trial.
+ last_trial_id = exp_storage.get_longest_prefix_finished_trial_id()
+ assert last_trial_id == trial_7_success2.trial_id, (
+ f"Expected longest prefix to end at trial_id {trial_7_success2.trial_id}, "
+ f"but got {last_trial_id}"
+ )
+ registered_trial_ids |= {completed_trial.trial_id for completed_trial in completed_trials}
+ registered_trial_ids = {i for i in registered_trial_ids if i > last_trial_id}
+ assert registered_trial_ids == set()
diff --git a/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py b/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py
index 0bebeeff824..db6dc5fa2e3 100644
--- a/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py
+++ b/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py
@@ -30,7 +30,7 @@
 # pylint: disable=redefined-outer-name
 
 
-@pytest.fixture
+@pytest.fixture(scope="function")
 def sqlite_storage() -> Generator[SqlStorage]:
 """
 Fixture for file based SQLite storage in a temporary directory.
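The exp_load_test.py additions above exercise Experiment.get_longest_prefix_finished_trial_id() through the storage layer. The invariant they check can be sketched in plain Python over an in-memory {trial_id: Status} mapping, treating SUCCEEDED, FAILED, CANCELED, and TIMED_OUT as finished, just as the tests do. This is an illustrative sketch only, not the SQL-backed implementation; it assumes the Status enum is importable from mlos_bench.environments.status, and the trial ids in the example are made up:

from mlos_bench.environments.status import Status

FINISHED_STATUSES = {Status.SUCCEEDED, Status.FAILED, Status.CANCELED, Status.TIMED_OUT}


def longest_prefix_finished_trial_id(statuses: dict[int, Status]) -> int:
    """Largest trial id such that it and every earlier trial are finished; -1 if none."""
    last_finished = -1
    for trial_id in sorted(statuses):
        if statuses[trial_id] not in FINISHED_STATUSES:
            break  # the first unfinished trial ends the prefix
        last_finished = trial_id
    return last_finished


# Mirrors test_async_get_longest_prefix_finished_trial_id with made-up ids 1..10:
# trials 1-4 finished, 5-9 still PENDING, 10 finished, so the prefix ends at 4.
statuses = {i: Status.PENDING for i in range(1, 11)}
statuses.update({1: Status.SUCCEEDED, 2: Status.FAILED, 3: Status.TIMED_OUT, 4: Status.CANCELED})
statuses[10] = Status.SUCCEEDED
assert longest_prefix_finished_trial_id(statuses) == 4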
diff --git a/mlos_bench/mlos_bench/tests/storage/trial_schedule_test.py b/mlos_bench/mlos_bench/tests/storage/trial_schedule_test.py index e22b045f1af..90bf84e7bb4 100644 --- a/mlos_bench/mlos_bench/tests/storage/trial_schedule_test.py +++ b/mlos_bench/mlos_bench/tests/storage/trial_schedule_test.py @@ -24,14 +24,19 @@ def _trial_ids(trials: Iterator[Storage.Trial]) -> set[int]: return {t.trial_id for t in trials} -def test_schedule_trial( +def test_storage_schedule( storage: Storage, exp_storage: Storage.Experiment, tunable_groups: TunableGroups, ) -> None: # pylint: disable=too-many-locals,too-many-statements - """Schedule several trials for future execution and retrieve them later at certain - timestamps. + """ + Test some storage functions that schedule several trials for future execution and + retrieve them later at certain timestamps. + + Notes + ----- + This doesn't actually test the Scheduler. """ timestamp = datetime.now(UTC) timedelta_1min = timedelta(minutes=1)