diff --git a/.vscode/settings.json b/.vscode/settings.json
index a01d1d6418f..f29dec766e2 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -8,6 +8,7 @@
"**/__pycache__/": true,
"**/node_modules/": true,
"**/*.egg-info": true,
+ "mlos_*/build/": true,
"doc/source/autoapi/": true,
"doc/build/doctrees/": true,
"doc/build/html/": true,
diff --git a/mlos_bench/mlos_bench/config/schemas/cli/globals-schema.json b/mlos_bench/mlos_bench/config/schemas/cli/globals-schema.json
index 015b4a6e62c..39e60e3249b 100644
--- a/mlos_bench/mlos_bench/config/schemas/cli/globals-schema.json
+++ b/mlos_bench/mlos_bench/config/schemas/cli/globals-schema.json
@@ -24,6 +24,9 @@
},
"optimization_targets": {
"$ref": "./common-defs-subschemas.json#/$defs/optimization_targets"
+ },
+ "mock_trial_data": {
+ "$ref": "../environments/mock-env-subschema.json#/$defs/mock_trial_data"
}
},
"additionalProperties": {
diff --git a/mlos_bench/mlos_bench/config/schemas/environments/mock-env-subschema.json b/mlos_bench/mlos_bench/config/schemas/environments/mock-env-subschema.json
index cb2de6c719f..b453c8573aa 100644
--- a/mlos_bench/mlos_bench/config/schemas/environments/mock-env-subschema.json
+++ b/mlos_bench/mlos_bench/config/schemas/environments/mock-env-subschema.json
@@ -3,6 +3,136 @@
"$id": "https://raw.githubusercontent.com/microsoft/MLOS/main/mlos_bench/mlos_bench/config/schemas/environments/mock-env-subschema.json",
"title": "mlos_bench MockEnv config",
"description": "Config instance for a mlos_bench MockEnv",
+
+ "$defs": {
+ "mock_trial_common_phase_data": {
+ "type": "object",
+ "properties": {
+ "sleep": {
+ "type": "number",
+ "description": "Optional time to sleep (in seconds) before returning from this phase of the trial.",
+ "examples": [0, 0.1, 0.5, 1, 2],
+ "minimum": 0,
+ "maximum": 60
+ },
+ "exception": {
+ "type": "string",
+ "description": "Optional exception message to raise during phase."
+ }
+ }
+ },
+ "mock_trial_status_run_phase_data": {
+ "type": "object",
+ "properties": {
+ "status": {
+ "description": "The status to report for this phase of the trial. Default is phase dependent.",
+ "enum": [
+ "UNKNOWN",
+ "PENDING",
+ "READY",
+ "RUNNING",
+ "SUCCEEDED",
+ "CANCELED",
+ "FAILED",
+ "TIMED_OUT"
+ ]
+ },
+ "metrics": {
+ "type": "object",
+ "description": "A dictionary of metrics for this phase of the trial.",
+ "additionalProperties": {
+ "type": [
+ "number",
+ "string",
+ "boolean"
+ ],
+ "description": "The value of the metric."
+ },
+ "examples": [
+ {
+ "score": 0.95,
+ "color": "green"
+ },
+ {
+ "accuracy": 0.85,
+ "loss": 0.15
+ }
+ ],
+ "minProperties": 0
+ }
+ }
+ },
+ "mock_trial_data_item": {
+ "description": "Mock data for a single trial, split by phase",
+ "type": "object",
+ "properties": {
+ "run": {
+ "description": "A dictionary of trial data for the run phase.",
+ "type": "object",
+ "allOf": [
+ {
+ "$ref": "#/$defs/mock_trial_common_phase_data"
+ },
+ {
+ "$ref": "#/$defs/mock_trial_status_run_phase_data"
+ }
+ ],
+ "minProperties": 1,
+ "unevaluatedProperties": false
+ },
+ "status": {
+ "description": "A dictionary of trial data for the status phase.",
+ "type": "object",
+ "allOf": [
+ {
+ "$ref": "#/$defs/mock_trial_common_phase_data"
+ },
+ {
+ "$ref": "#/$defs/mock_trial_status_run_phase_data"
+ }
+ ],
+ "minProperties": 1,
+ "unevaluatedProperties": false
+ },
+ "setup": {
+ "description": "A dictionary of trial data for the setup phase.",
+ "type": "object",
+ "allOf": [
+ {
+ "$ref": "#/$defs/mock_trial_common_phase_data"
+ }
+ ],
+ "minProperties": 1,
+ "unevaluatedProperties": false
+ },
+ "teardown": {
+ "description": "A dictionary of trial data for the teardown phase.",
+ "type": "object",
+ "allOf": [
+ {
+ "$ref": "#/$defs/mock_trial_common_phase_data"
+ }
+ ],
+ "minProperties": 1,
+ "unevaluatedProperties": false
+ }
+ },
+ "unevaluatedProperties": false,
+ "minProperties": 1
+ },
+ "mock_trial_data": {
+ "description": "A set of mock trial data to use for testing, keyed by trial id. Used by MockEnv.",
+ "type": "object",
+ "patternProperties": {
+ "^[1-9][0-9]*$": {
+ "$ref": "#/$defs/mock_trial_data_item"
+ }
+ },
+ "unevaluatedProperties": false,
+ "minProperties": 1
+ }
+ },
+
"type": "object",
"properties": {
"class": {
@@ -42,6 +172,9 @@
},
"minItems": 1,
"uniqueItems": true
+ },
+ "mock_trial_data": {
+ "$ref": "#/$defs/mock_trial_data"
}
}
}
diff --git a/mlos_bench/mlos_bench/environments/base_environment.py b/mlos_bench/mlos_bench/environments/base_environment.py
index 094085c78b5..fe40025f95d 100644
--- a/mlos_bench/mlos_bench/environments/base_environment.py
+++ b/mlos_bench/mlos_bench/environments/base_environment.py
@@ -363,6 +363,92 @@ def parameters(self) -> dict[str, TunableValue]:
"""
return self._params.copy()
+ @property
+ def current_trial_id(self) -> int:
+ """
+ Get the current trial ID.
+
+ This value can be used in scripts or environment variables to help
+ identify the Trial this Environment is currently running.
+
+ Returns
+ -------
+ trial_id : int
+ The current trial ID.
+
+ Notes
+ -----
+ This method is used to identify the current trial ID for the environment.
+ It is expected to be called *after* the base
+ :py:meth:`Environment.setup` method has been called and parameters have
+ been assigned.
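+
+ For example (illustrative, not prescriptive), a script-based environment
+ might propagate ``trial_id`` to its commands via ``shell_env_params`` or
+ ``required_args`` so that per-trial output files do not collide.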
+ """
+ val = self._params["trial_id"]
+ assert isinstance(val, int), (
+ "Expected trial_id to be an int, but got %s (type %s): %s",
+ val,
+ type(val),
+ self._params,
+ )
+ return val
+
+ @property
+ def trial_runner_id(self) -> int:
+ """
+ Get the ID of the :py:class:`~.mlos_bench.schedulers.trial_runner.TrialRunner`
+ for this Environment.
+
+ This value can be used in scripts or environment variables to help
+ identify the TrialRunner for this Environment.
+
+ Returns
+ -------
+ trial_runner_id : int
+ The trial runner ID.
+
+ Notes
+ -----
+ This shouldn't change during the lifetime of the Environment since each
+ Environment is assigned to a single TrialRunner.
+ """
+ val = self._params["trial_runner_id"]
+ assert isinstance(val, int), (
+ "Expected trial_runner_id to be an int, but got %s (type %s)",
+ val,
+ type(val),
+ )
+ return val
+
+ @property
+ def experiment_id(self) -> str:
+ """
+ Get the ID of the experiment.
+
+ This value can be used in scripts or environment variables to help
+ identify the Experiment this Environment is a part of.
+
+ Returns
+ -------
+ experiment_id : str
+ The ID of the experiment.
+
+ Notes
+ -----
+ This value comes from the globals config or ``mlos_bench`` CLI arguments
+ in the experiment setup.
+
+ See Also
+ --------
+ mlos_bench.config : documentation on the configuration system
+ """
+ val = self._params["experiment_id"]
+ assert isinstance(val, str), (
+ "Expected experiment_id to be an int, but got %s (type %s)",
+ val,
+ type(val),
+ )
+ return val
+
def setup(self, tunables: TunableGroups, global_config: dict | None = None) -> bool:
"""
Set up a new benchmark environment, if necessary. This method must be
diff --git a/mlos_bench/mlos_bench/environments/mock_env.py b/mlos_bench/mlos_bench/environments/mock_env.py
index ac6d9b7f001..c565b1adc66 100644
--- a/mlos_bench/mlos_bench/environments/mock_env.py
+++ b/mlos_bench/mlos_bench/environments/mock_env.py
@@ -6,6 +6,9 @@
import logging
import random
+import time
+from copy import deepcopy
+from dataclasses import dataclass
from datetime import datetime
from typing import Any
@@ -21,6 +24,131 @@
_LOG = logging.getLogger(__name__)
+@dataclass
+class MockTrialPhaseData:
+ """Mock trial data for a specific phase of a trial."""
+
+ phase: str
+ """Phase of the trial data (e.g., setup, run, status, teardown)."""
+
+ status: Status
+ """Status response for the phase."""
+
+ metrics: dict[str, TunableValue] | None = None
+ """Metrics response for the phase."""
+
+ sleep: float | None = 0.0
+ """Optional sleep time in seconds to simulate phase execution time."""
+
+ exception: str | None = None
+ """Message of an exception to raise for the phase."""
+
+ @staticmethod
+ def from_dict(phase: str, data: dict | None) -> "MockTrialPhaseData":
+ """
+ Create a MockTrialPhaseData instance from a dictionary.
+
+ Parameters
+ ----------
+ phase : str
+ Phase of the trial data.
+ data : dict | None
+ Dictionary containing the phase data.
+
+ Returns
+ -------
+ MockTrialPhaseData
+ Instance of MockTrialPhaseData.
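+
+ Examples
+ --------
+ A minimal illustration (the ``run`` phase defaults to ``SUCCEEDED``, per
+ the defaults in the implementation below):
+
+ >>> phase_data = MockTrialPhaseData.from_dict("run", {"metrics": {"score": 1.0}})
+ >>> phase_data.status == Status.SUCCEEDED
+ True
+ >>> phase_data.metrics
+ {'score': 1.0}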
+ """
+ data = data or {}
+ assert phase in {"setup", "run", "status", "teardown"}, f"Invalid phase: {phase}"
+ if phase in {"teardown", "status"}:
+ # setup/teardown phase is not expected to have metrics or status.
+ assert "metrics" not in data, f"Unexpected metrics data in {phase} phase: {data}"
+ assert "status" not in data, f"Unexpected status data in {phase} phase: {data}"
+ if "sleep" in data:
+ assert isinstance(
+ data["sleep"], (int, float)
+ ), f"Invalid sleep in {phase} phase: {data}"
+ assert 60 >= data["sleep"] >= 0, f"Invalid sleep time in {phase} phase: {data}"
+ if "metrics" in data:
+ assert isinstance(data["metrics"], dict), f"Invalid metrics in {phase} phase: {data}"
+ default_phases = {
+ "run": Status.SUCCEEDED,
+ # FIXME: this causes issues if we report RUNNING instead of READY
+ "status": Status.READY,
+ }
+ status = Status.parse(data.get("status", default_phases.get(phase, Status.UNKNOWN)))
+ return MockTrialPhaseData(
+ phase=phase,
+ status=status,
+ metrics=data.get("metrics"),
+ sleep=data.get("sleep"),
+ exception=data.get("exception"),
+ )
+
+
+@dataclass
+class MockTrialData:
+ """Mock trial data for a specific trial ID."""
+
+ trial_id: int
+ """Trial ID for the mock trial data."""
+ setup: MockTrialPhaseData
+ """Setup phase data for the trial."""
+ run: MockTrialPhaseData
+ """Run phase data for the trial."""
+ status: MockTrialPhaseData
+ """Status phase data for the trial."""
+ teardown: MockTrialPhaseData
+ """Teardown phase data for the trial."""
+
+ @staticmethod
+ def from_dict(trial_id: int, data: dict) -> "MockTrialData":
+ """
+ Create a MockTrialData instance from a dictionary.
+
+ Parameters
+ ----------
+ trial_id : int
+ Trial ID for the mock trial data.
+ data : dict
+ Dictionary containing the trial data.
+
+ Returns
+ -------
+ MockTrialData
+ Instance of MockTrialData.
+ """
+ return MockTrialData(
+ trial_id=trial_id,
+ setup=MockTrialPhaseData.from_dict("setup", data.get("setup")),
+ run=MockTrialPhaseData.from_dict("run", data.get("run")),
+ status=MockTrialPhaseData.from_dict("status", data.get("status")),
+ teardown=MockTrialPhaseData.from_dict("teardown", data.get("teardown")),
+ )
+
+ @staticmethod
+ def load_mock_trial_data(mock_trial_data: dict) -> dict[int, "MockTrialData"]:
+ """
+ Load mock trial data from a dictionary.
+
+ Parameters
+ ----------
+ mock_trial_data : dict
+ Dictionary containing mock trial data.
+
+ Returns
+ -------
+ dict[int, MockTrialData]
+ Dictionary of mock trial data keyed by trial ID.
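+
+ Examples
+ --------
+ Illustrative only; note that the string trial ids become ints:
+
+ >>> data = MockTrialData.load_mock_trial_data({"1": {"run": {"sleep": 0.1}}})
+ >>> data[1].run.sleep
+ 0.1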
+ """
+ return {
+ int(trial_id): MockTrialData.from_dict(trial_id=int(trial_id), data=trial_data)
+ for trial_id, trial_data in mock_trial_data.items()
+ }
+
+
class MockEnv(Environment):
"""Scheduler-side environment to mock the benchmark results."""
@@ -55,6 +183,19 @@ def __init__( # pylint: disable=too-many-arguments
service: Service
An optional service object. Not used by this class.
"""
+ # First allow merging mock_trial_data from the global_config into the
+ # config so we can check it against the JSON schema for expected data
+ # types.
+ if global_config and "mock_trial_data" in global_config:
+ mock_trial_data = global_config["mock_trial_data"]
+ if not isinstance(mock_trial_data, dict):
+ raise ValueError(f"Invalid mock_trial_data in global_config: {mock_trial_data}")
+ # Merge the mock trial data into the config.
+ config["mock_trial_data"] = {
+ **config.get("mock_trial_data", {}),
+ **mock_trial_data,
+ }
+
super().__init__(
name=name,
config=config,
@@ -62,6 +203,9 @@ def __init__( # pylint: disable=too-many-arguments
tunables=tunables,
service=service,
)
+ self._mock_trial_data = MockTrialData.load_mock_trial_data(
+ self.config.get("mock_trial_data", {})
+ )
seed = int(self.config.get("mock_env_seed", -1))
self._run_random = random.Random(seed or None) if seed >= 0 else None
self._status_random = random.Random(seed or None) if seed >= 0 else None
@@ -83,6 +227,67 @@ def _produce_metrics(self, rand: random.Random | None) -> dict[str, TunableValue
return {metric: float(score) for metric in self._metrics or []}
+ @property
+ def mock_trial_data(self) -> dict[int, MockTrialData]:
+ """
+ Get the mock trial data for all trials.
+
+ Returns
+ -------
+ dict[int, MockTrialData]
+ Dictionary of mock trial data keyed by trial ID.
+ """
+ return deepcopy(self._mock_trial_data)
+
+ def get_current_mock_trial_data(self) -> MockTrialData:
+ """
+ Gets mock trial data for the current trial ID.
+
+ If no mock trial data is found for the current trial, a new MockTrialData
+ instance with default phase data is created (its metrics are generated
+ randomly later) and saved for subsequent calls.
+
+ Note
+ ----
+ This method must be called after the base :py:meth:`Environment.setup`
+ method is called to ensure the current ``trial_id`` is set.
+ """
+ trial_id = self.current_trial_id
+ mock_trial_data = self._mock_trial_data.get(trial_id)
+ if not mock_trial_data:
+ mock_trial_data = MockTrialData(
+ trial_id=trial_id,
+ setup=MockTrialPhaseData.from_dict(phase="setup", data=None),
+ run=MockTrialPhaseData.from_dict(phase="run", data=None),
+ status=MockTrialPhaseData.from_dict(phase="status", data=None),
+ teardown=MockTrialPhaseData.from_dict(phase="teardown", data=None),
+ )
+ # Save the generated data for later.
+ self._mock_trial_data[trial_id] = mock_trial_data
+ return mock_trial_data
+
+ def setup(self, tunables: TunableGroups, global_config: dict | None = None) -> bool:
+ is_success = super().setup(tunables, global_config)
+ mock_trial_data = self.get_current_mock_trial_data()
+ if mock_trial_data.setup.sleep:
+ _LOG.debug("Sleeping for %s seconds", mock_trial_data.setup.sleep)
+ time.sleep(mock_trial_data.setup.sleep)
+ if mock_trial_data.setup.exception:
+ raise RuntimeError(
+ f"Mock trial data setup exception: {mock_trial_data.setup.exception}"
+ )
+ return is_success
+
+ def teardown(self) -> None:
+ mock_trial_data = self.get_current_mock_trial_data()
+ if mock_trial_data.teardown.sleep:
+ _LOG.debug("Sleeping for %s seconds", mock_trial_data.teardown.sleep)
+ time.sleep(mock_trial_data.teardown.sleep)
+ if mock_trial_data.teardown.exception:
+ raise RuntimeError(
+ f"Mock trial data teardown exception: {mock_trial_data.teardown.exception}"
+ )
+ super().teardown()
+
def run(self) -> tuple[Status, datetime, dict[str, TunableValue] | None]:
"""
Produce mock benchmark data for one experiment.
@@ -99,8 +304,16 @@ def run(self) -> tuple[Status, datetime, dict[str, TunableValue] | None]:
(status, timestamp, _) = result = super().run()
if not status.is_ready():
return result
- metrics = self._produce_metrics(self._run_random)
- return (Status.SUCCEEDED, timestamp, metrics)
+ mock_trial_data = self.get_current_mock_trial_data()
+ if mock_trial_data.run.sleep:
+ _LOG.debug("Sleeping for %s seconds", mock_trial_data.run.sleep)
+ time.sleep(mock_trial_data.run.sleep)
+ if mock_trial_data.run.exception:
+ raise RuntimeError(f"Mock trial data run exception: {mock_trial_data.run.exception}")
+ if mock_trial_data.run.metrics is None:
+ # If no metrics are provided, generate them.
+ mock_trial_data.run.metrics = self._produce_metrics(self._run_random)
+ return (mock_trial_data.run.status, timestamp, mock_trial_data.run.metrics)
def status(self) -> tuple[Status, datetime, list[tuple[datetime, str, Any]]]:
"""
@@ -116,10 +329,26 @@ def status(self) -> tuple[Status, datetime, list[tuple[datetime, str, Any]]]:
(status, timestamp, _) = result = super().status()
if not status.is_ready():
return result
- metrics = self._produce_metrics(self._status_random)
+ mock_trial_data = self.get_current_mock_trial_data()
+ if mock_trial_data.status.sleep:
+ _LOG.debug("Sleeping for %s seconds", mock_trial_data.status.sleep)
+ time.sleep(mock_trial_data.status.sleep)
+ if mock_trial_data.status.exception:
+ raise RuntimeError(
+ f"Mock trial data status exception: {mock_trial_data.status.exception}"
+ )
+ if mock_trial_data.status.metrics is None:
+ # If no metrics are provided, generate them.
+ # Note: we don't save these in the mock trial data as they may need
+ # to change to preserve backwards compatibility with previous tests.
+ metrics = self._produce_metrics(self._status_random)
+ else:
+ # If metrics are provided, use them.
+ # Note: current implementation uses the same metrics for all status
+ # calls of this trial.
+ metrics = mock_trial_data.status.metrics
return (
- # FIXME: this causes issues if we report RUNNING instead of READY
- Status.READY,
+ mock_trial_data.status.status,
timestamp,
[(timestamp, metric, score) for (metric, score) in metrics.items()],
)
diff --git a/mlos_bench/mlos_bench/environments/script_env.py b/mlos_bench/mlos_bench/environments/script_env.py
index 6ac4674cfe1..d71eb661834 100644
--- a/mlos_bench/mlos_bench/environments/script_env.py
+++ b/mlos_bench/mlos_bench/environments/script_env.py
@@ -5,7 +5,7 @@
"""
Base scriptable benchmark environment.
-TODO: Document how variable propogation works in the script environments using
+TODO: Document how variable propagation works in the script environments using
shell_env_params, required_args, const_args, etc.
"""
diff --git a/mlos_bench/mlos_bench/environments/status.py b/mlos_bench/mlos_bench/environments/status.py
index 6d76d7206c8..aa3b3e99c16 100644
--- a/mlos_bench/mlos_bench/environments/status.py
+++ b/mlos_bench/mlos_bench/environments/status.py
@@ -24,21 +24,37 @@ class Status(enum.Enum):
TIMED_OUT = 7
@staticmethod
- def from_str(status_str: Any) -> "Status":
- """Convert a string to a Status enum."""
- if not isinstance(status_str, str):
- _LOG.warning("Expected type %s for status: %s", type(status_str), status_str)
- status_str = str(status_str)
- if status_str.isdigit():
+ def parse(status: Any) -> "Status":
+ """
+ Convert the input to a Status enum.
+
+ Parameters
+ ----------
+ status : Any
+ The status to parse. This can be a string (or string convertible),
+ int, or Status enum.
+
+ Returns
+ -------
+ Status
+ The corresponding Status enum value or else UNKNOWN if the input is not
+ recognized.
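+
+ Examples
+ --------
+ Illustrative usage of the parsing rules below:
+
+ >>> Status.parse("succeeded") == Status.SUCCEEDED
+ True
+ >>> Status.parse(Status.FAILED) == Status.FAILED
+ True
+ >>> Status.parse("not-a-status") == Status.UNKNOWN
+ True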
+ """
+ if isinstance(status, Status):
+ return status
+ if not isinstance(status, str):
+ _LOG.warning("Expected type %s for status: %s", type(status), status)
+ status = str(status)
+ if status.isdigit():
try:
- return Status(int(status_str))
+ return Status(int(status))
except ValueError:
- _LOG.warning("Unknown status: %d", int(status_str))
+ _LOG.warning("Unknown status: %d", int(status))
try:
- status_str = status_str.upper().strip()
- return Status[status_str]
+ status = status.upper().strip()
+ return Status[status]
except KeyError:
- _LOG.warning("Unknown status: %s", status_str)
+ _LOG.warning("Unknown status: %s", status)
return Status.UNKNOWN
def is_good(self) -> bool:
@@ -113,4 +129,15 @@ def is_timed_out(self) -> bool:
Status.TIMED_OUT,
}
)
-"""The set of completed statuses."""
+"""
+The set of completed statuses.
+
+Includes all statuses that indicate the trial or experiment has finished, either
+successfully or not.
+This set is used to determine if a trial or experiment has reached a final state.
+This includes:
+- :py:attr:`.Status.SUCCEEDED`: The trial or experiment completed successfully.
+- :py:attr:`.Status.CANCELED`: The trial or experiment was canceled.
+- :py:attr:`.Status.FAILED`: The trial or experiment failed.
+- :py:attr:`.Status.TIMED_OUT`: The trial or experiment timed out.
+"""
diff --git a/mlos_bench/mlos_bench/launcher.py b/mlos_bench/mlos_bench/launcher.py
index c728ed7fb20..353ace23f0e 100644
--- a/mlos_bench/mlos_bench/launcher.py
+++ b/mlos_bench/mlos_bench/launcher.py
@@ -55,8 +55,9 @@ def __init__(self, description: str, long_text: str = "", argv: list[str] | None
Other required_args values can also be pulled from shell environment
variables.
- For additional details, please see the website or the README.md files in
- the source tree:
+ For additional details, please see the documentation website or the
+ README.md files in the source tree:
+
"""
parser = argparse.ArgumentParser(description=f"{description} : {long_text}", epilog=epilog)
diff --git a/mlos_bench/mlos_bench/optimizers/base_optimizer.py b/mlos_bench/mlos_bench/optimizers/base_optimizer.py
index 44aa9a035e2..72b437b320e 100644
--- a/mlos_bench/mlos_bench/optimizers/base_optimizer.py
+++ b/mlos_bench/mlos_bench/optimizers/base_optimizer.py
@@ -356,8 +356,10 @@ def _get_scores(
assert scores is not None
target_metrics: dict[str, float] = {}
for opt_target, opt_dir in self._opt_targets.items():
+ if opt_target not in scores:
+ raise ValueError(f"Score for {opt_target} not found in {scores}.")
val = scores[opt_target]
- assert val is not None
+ assert val is not None, f"Score for {opt_target} is None."
target_metrics[opt_target] = float(val) * opt_dir
return target_metrics
diff --git a/mlos_bench/mlos_bench/optimizers/mock_optimizer.py b/mlos_bench/mlos_bench/optimizers/mock_optimizer.py
index 947e34a7da4..a1311b6f953 100644
--- a/mlos_bench/mlos_bench/optimizers/mock_optimizer.py
+++ b/mlos_bench/mlos_bench/optimizers/mock_optimizer.py
@@ -14,6 +14,7 @@
import logging
import random
from collections.abc import Callable, Sequence
+from dataclasses import dataclass
from mlos_bench.environments.status import Status
from mlos_bench.optimizers.track_best_optimizer import TrackBestOptimizer
@@ -25,6 +26,15 @@
_LOG = logging.getLogger(__name__)
+@dataclass
+class RegisteredScore:
+ """A registered score for a trial."""
+
+ config: TunableGroups
+ score: dict[str, TunableValue] | None
+ status: Status
+
+
class MockOptimizer(TrackBestOptimizer):
"""Mock optimizer to test the Environment API."""
@@ -42,6 +52,39 @@ def __init__(
"float": lambda tunable: rnd.uniform(*tunable.range),
"int": lambda tunable: rnd.randint(*(int(x) for x in tunable.range)),
}
+ self._registered_scores: list[RegisteredScore] = []
+
+ @property
+ def registered_scores(self) -> list[RegisteredScore]:
+ """
+ Return the list of registered scores.
+
+ Notes
+ -----
+ Used for testing and validation.
+ """
+ return self._registered_scores
+
+ def register(
+ self,
+ tunables: TunableGroups,
+ status: Status,
+ score: dict[str, TunableValue] | None = None,
+ ) -> dict[str, float] | None:
+ # Track the registered scores for testing and validation.
+ score = score or {}
+ # Almost the same as _get_scores, but we don't adjust the direction here.
+ scores: dict[str, TunableValue] = {
+ k: float(v) for k, v in score.items() if k in self._opt_targets and v is not None
+ }
+ self._registered_scores.append(
+ RegisteredScore(
+ config=tunables.copy(),
+ score=scores,
+ status=status,
+ )
+ )
+ return super().register(tunables, status, score)
def bulk_register(
self,
diff --git a/mlos_bench/mlos_bench/schedulers/base_scheduler.py b/mlos_bench/mlos_bench/schedulers/base_scheduler.py
index 1cd88fd5859..ee0d00757ae 100644
--- a/mlos_bench/mlos_bench/schedulers/base_scheduler.py
+++ b/mlos_bench/mlos_bench/schedulers/base_scheduler.py
@@ -100,8 +100,9 @@ def __init__( # pylint: disable=too-many-arguments
self._optimizer = optimizer
self._storage = storage
self._root_env_config = root_env_config
- self._last_trial_id = -1
self._ran_trials: list[Storage.Trial] = []
+ self._longest_finished_trial_sequence_id = -1
+ self._registered_trial_ids: set[int] = set()
_LOG.debug("Scheduler instantiated: %s :: %s", self, config)
@@ -242,8 +243,15 @@ def __exit__(
self._in_context = False
return False # Do not suppress exceptions
- def start(self) -> None:
- """Start the scheduling loop."""
+ def _prepare_start(self) -> bool:
+ """
+ Prepare the scheduler for starting.
+
+ Notes
+ -----
+ This method is called by the :py:meth:`Scheduler.start` method.
+ It is split out mostly to allow for easier unit testing/mocking.
+ """
assert self.experiment is not None
_LOG.info(
"START: Experiment: %s Env: %s Optimizer: %s",
@@ -262,21 +270,43 @@ def start(self) -> None:
is_warm_up: bool = self.optimizer.supports_preload
if not is_warm_up:
_LOG.warning("Skip pending trials and warm-up: %s", self.optimizer)
+ return is_warm_up
+ def start(self) -> None:
+ """Start the scheduling loop."""
+ assert self.experiment is not None
+ is_warm_up = self._prepare_start()
not_done: bool = True
while not_done:
- _LOG.info("Optimization loop: Last trial ID: %d", self._last_trial_id)
- self.run_schedule(is_warm_up)
- not_done = self.add_new_optimizer_suggestions()
- self.assign_trial_runners(
- self.experiment.pending_trials(
- datetime.now(UTC),
- running=False,
- trial_runner_assigned=False,
- )
- )
+ not_done = self._execute_scheduling_step(is_warm_up)
is_warm_up = False
+ def _execute_scheduling_step(self, is_warm_up: bool) -> bool:
+ """
+ Perform a single scheduling step.
+
+ Notes
+ -----
+ This method is called by the :py:meth:`Scheduler.start` method.
+ It is split out mostly to allow for easier unit testing/mocking.
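+
+ As implemented below, a single step runs the schedule, bulk-registers
+ newly completed trial results with the Optimizer, asks it for new
+ suggestions, and assigns pending trials to TrialRunners.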
+ """
+ assert self.experiment is not None
+ _LOG.info(
+ "Optimization loop: Longest finished trial sequence ID: %d",
+ self._longest_finished_trial_sequence_id,
+ )
+ self.run_schedule(is_warm_up)
+ self.bulk_register_completed_trials()
+ not_done = self.add_new_optimizer_suggestions()
+ self.assign_trial_runners(
+ self.experiment.pending_trials(
+ datetime.now(UTC),
+ running=False,
+ trial_runner_assigned=False,
+ )
+ )
+ return not_done
+
def teardown(self) -> None:
"""
Tear down the TrialRunners/Environment(s).
@@ -309,10 +339,65 @@ def load_tunable_config(self, config_id: int) -> TunableGroups:
_LOG.debug("Config %d ::\n%s", config_id, json.dumps(tunable_values, indent=2))
return tunables.copy()
+ def bulk_register_completed_trials(self) -> None:
+ """
+ Bulk register the results of recently completed Trials from the Storage with the Optimizer.
+
+ Notes
+ -----
+ This method is called after the Trials have been run (or started) and
+ the results have been recorded in the Storage by the TrialRunner(s).
+
+ It has logic to handle straggler Trials that finish out of order so
+ should be usable by both
+ :py:class:`~mlos_bench.schedulers.SyncScheduler` and async Schedulers.
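+
+ For example, if trials 1, 2, and 4 have finished but trial 3 is still
+ running, the longest finished prefix ends at trial 2; trial 4 is
+ remembered in ``_registered_trial_ids`` (so it is not re-registered)
+ until trial 3 finishes and the prefix advances past it.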
+
+ See Also
+ --------
+ Scheduler.start
+ The main loop of the Scheduler.
+ Storage.Experiment.load
+ Load the results of the Trials based on some filtering criteria.
+ Optimizer.bulk_register
+ Register the results of the Trials in the Optimizer.
+ """
+ assert self.experiment is not None
+ # Load the results of the trials that have been run since the last time
+ # we queried the Optimizer.
+ # Note: We need to handle the case of straggler trials that finish out of order.
+ (trial_ids, configs, scores, status) = self.experiment.load(
+ last_trial_id=self._longest_finished_trial_sequence_id,
+ omit_registered_trial_ids=self._registered_trial_ids,
+ )
+ _LOG.info("QUEUE: Update the optimizer with trial results: %s", trial_ids)
+ self.optimizer.bulk_register(configs, scores, status)
+ # Mark those trials as registered so we don't load them again.
+ self._registered_trial_ids.update(trial_ids)
+ # Update the longest finished trial sequence ID.
+ self._longest_finished_trial_sequence_id = max(
+ self.experiment.get_longest_prefix_finished_trial_id(),
+ self._longest_finished_trial_sequence_id,
+ )
+ # Remove trial ids that are older than the longest finished trial sequence ID.
+ # This is an optimization to avoid a long list of trial ids to omit from
+ # the load() operation or a long list of trial ids to maintain in memory.
+ self._registered_trial_ids = {
+ trial_id
+ for trial_id in self._registered_trial_ids
+ if trial_id > self._longest_finished_trial_sequence_id
+ }
+
def add_new_optimizer_suggestions(self) -> bool:
"""
Optimizer part of the loop.
+ Asks the :py:class:`~.Optimizer` for new suggestions and adds them to
+ the queue. This method is called after the trials have been run and
+ their results have been registered with the Optimizer (see
+ :py:meth:`~.Scheduler.bulk_register_completed_trials`).
+
- Load the results of the executed trials into the
- :py:class:`~.Optimizer`, suggest new configurations, and add them to the
- queue.
@@ -323,16 +408,24 @@ def add_new_optimizer_suggestions(self) -> bool:
The return value indicates whether the optimization process should
continue to get suggestions from the Optimizer or not.
See Also: :py:meth:`~.Scheduler.not_done`.
- """
- assert self.experiment is not None
- (trial_ids, configs, scores, status) = self.experiment.load(self._last_trial_id)
- _LOG.info("QUEUE: Update the optimizer with trial results: %s", trial_ids)
- self.optimizer.bulk_register(configs, scores, status)
- self._last_trial_id = max(trial_ids, default=self._last_trial_id)
+ Notes
+ -----
+ Subclasses can override this method to implement a more sophisticated
+ scheduling policy using the information obtained from the Optimizer.
+
+ See Also
+ --------
+ Scheduler.not_done
+ The stopping conditions for the optimization process.
+
+ Scheduler.bulk_register_completed_trials
+ Bulk register the most recent completed trials in the storage.
+ """
# Check if the optimizer has converged or not.
not_done = self.not_done()
if not_done:
+ # TODO: Allow scheduling multiple configs at once (e.g., in the case of idle workers).
tunables = self.optimizer.suggest()
self.add_trial_to_queue(tunables)
return not_done
diff --git a/mlos_bench/mlos_bench/storage/base_storage.py b/mlos_bench/mlos_bench/storage/base_storage.py
index 75a84bf0b2e..81e8148d26e 100644
--- a/mlos_bench/mlos_bench/storage/base_storage.py
+++ b/mlos_bench/mlos_bench/storage/base_storage.py
@@ -26,7 +26,7 @@
import logging
from abc import ABCMeta, abstractmethod
-from collections.abc import Iterator, Mapping
+from collections.abc import Iterable, Iterator, Mapping
from contextlib import AbstractContextManager as ContextManager
from datetime import datetime
from types import TracebackType
@@ -324,27 +324,55 @@ def load_telemetry(self, trial_id: int) -> list[tuple[datetime, str, Any]]:
Telemetry data.
"""
+ @abstractmethod
+ def get_longest_prefix_finished_trial_id(self) -> int:
+ """
+ Calculate the ID of the last trial in the finished prefix of the experiment.
+
+ That is, the largest trial ID that has finished (failed or
+ successful) such that all Trials before it have also finished.
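+
+ Returns
+ -------
+ int
+ The largest trial ID such that it and every Trial before it has
+ finished, or a value below the smallest trial ID (e.g., -1 when
+ the Experiment has no Trials yet) if there is no such Trial.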
+ """
+
@abstractmethod
def load(
self,
last_trial_id: int = -1,
+ omit_registered_trial_ids: Iterable[int] | None = None,
) -> tuple[list[int], list[dict], list[dict[str, Any] | None], list[Status]]:
"""
Load (tunable values, benchmark scores, status) to warm-up the optimizer.
- If `last_trial_id` is present, load only the data from the (completed) trials
- that were scheduled *after* the given trial ID. Otherwise, return data from ALL
- merged-in experiments and attempt to impute the missing tunable values.
+ If ``last_trial_id`` is present, load only the data from the
+ completed trials that were
+ added *after* the given trial ID. Otherwise, return data from
+ ALL merged-in experiments and attempt to impute the missing tunable
+ values.
+
+ Additionally, if ``omit_registered_trial_ids`` is provided, omit the
+ trials matching those ids.
+
+ The parameters together allow us to efficiently load data from
+ finished trials that we haven't registered with the Optimizer yet
+ for bulk registering.
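+
+ For example, a Scheduler might call
+ ``experiment.load(last_trial_id=watermark, omit_registered_trial_ids=seen)``,
+ where ``watermark`` and ``seen`` are its own bookkeeping values, to fetch
+ only results it has not yet registered with the Optimizer.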
Parameters
----------
last_trial_id : int
(Optional) Trial ID to start from.
+ omit_registered_trial_ids : Iterable[int] | None
+ (Optional) List of trial IDs to omit. If None, load all trials
+ after ``last_trial_id``.
Returns
-------
(trial_ids, configs, scores, status) : ([int], [dict], [dict] | None, [Status])
Trial ids, Tunable values, benchmark scores, and status of the trials.
+
+ See Also
+ --------
+ Storage.Experiment.get_longest_prefix_finished_trial_id :
+ Get the largest trial ID of the contiguous prefix of finished trials.
+ Scheduler.add_new_optimizer_suggestions
"""
@abstractmethod
diff --git a/mlos_bench/mlos_bench/storage/sql/common.py b/mlos_bench/mlos_bench/storage/sql/common.py
index 97eb270c9d9..032cf9259d8 100644
--- a/mlos_bench/mlos_bench/storage/sql/common.py
+++ b/mlos_bench/mlos_bench/storage/sql/common.py
@@ -95,7 +95,7 @@ def get_trials(
config_id=trial.config_id,
ts_start=utcify_timestamp(trial.ts_start, origin="utc"),
ts_end=utcify_nullable_timestamp(trial.ts_end, origin="utc"),
- status=Status.from_str(trial.status),
+ status=Status.parse(trial.status),
trial_runner_id=trial.trial_runner_id,
)
for trial in trials.fetchall()
diff --git a/mlos_bench/mlos_bench/storage/sql/experiment.py b/mlos_bench/mlos_bench/storage/sql/experiment.py
index acc2a497b48..2fc31700242 100644
--- a/mlos_bench/mlos_bench/storage/sql/experiment.py
+++ b/mlos_bench/mlos_bench/storage/sql/experiment.py
@@ -8,7 +8,7 @@
import hashlib
import logging
-from collections.abc import Iterator
+from collections.abc import Iterable, Iterator
from datetime import datetime
from typing import Any, Literal
@@ -153,13 +153,59 @@ def load_telemetry(self, trial_id: int) -> list[tuple[datetime, str, Any]]:
for row in cur_telemetry.fetchall()
]
+ # TODO: Add a test for this method.
+ def get_longest_prefix_finished_trial_id(self) -> int:
+ with self._engine.connect() as conn:
+ # TODO: Do this in a single query?
+
+ # Get the first (minimum) trial ID with an unfinished status.
+ first_unfinished_trial_id_stmt = (
+ self._schema.trial.select()
+ .with_only_columns(
+ func.min(self._schema.trial.c.trial_id),
+ )
+ .where(
+ self._schema.trial.c.exp_id == self._experiment_id,
+ func.not_(
+ self._schema.trial.c.status.in_(
+ [status.name for status in Status.completed_statuses()]
+ )
+ ),
+ )
+ )
+ first_unfinished_trial_id = conn.execute(first_unfinished_trial_id_stmt).scalar()
+ if first_unfinished_trial_id is not None:
+ # Return one less than the first unfinished trial ID - it should be
+ # finished (or not exist, which is fine as a limit).
+ return int(first_unfinished_trial_id) - 1
+
+ # No unfinished trials, so *all* trials are completed - get the
+ # largest completed trial ID.
+ last_finished_trial_id = (
+ self._schema.trial.select()
+ .with_only_columns(
+ func.max(self._schema.trial.c.trial_id),
+ )
+ .where(
+ self._schema.trial.c.exp_id == self._experiment_id,
+ self._schema.trial.c.status.in_(
+ [status.name for status in Status.completed_statuses()]
+ ),
+ )
+ )
+ max_trial_id = conn.execute(last_finished_trial_id).scalar()
+ if max_trial_id is not None:
+ return int(max_trial_id)
+ # Else no trials yet exist for this experiment.
+ return -1
+
def load(
self,
last_trial_id: int = -1,
+ omit_registered_trial_ids: Iterable[int] | None = None,
) -> tuple[list[int], list[dict], list[dict[str, Any] | None], list[Status]]:
-
with self._engine.connect() as conn:
- cur_trials = conn.execute(
+ stmt = (
self._schema.trial.select()
.with_only_columns(
self._schema.trial.c.trial_id,
@@ -170,11 +216,7 @@ def load(
self._schema.trial.c.exp_id == self._experiment_id,
self._schema.trial.c.trial_id > last_trial_id,
self._schema.trial.c.status.in_(
- [
- Status.SUCCEEDED.name,
- Status.FAILED.name,
- Status.TIMED_OUT.name,
- ]
+ [status.name for status in Status.completed_statuses()]
),
)
.order_by(
@@ -182,13 +224,22 @@ def load(
)
)
+ # TODO: Add a test for this parameter.
+
+ # Note: if we have a very large number of trials, this may encounter
+ # SQL text length limits, so we may need to chunk this.
+ if omit_registered_trial_ids is not None:
+ stmt = stmt.where(self._schema.trial.c.trial_id.notin_(omit_registered_trial_ids))
+
+ cur_trials = conn.execute(stmt)
+
trial_ids: list[int] = []
configs: list[dict[str, Any]] = []
scores: list[dict[str, Any] | None] = []
status: list[Status] = []
for trial in cur_trials.fetchall():
- stat = Status.from_str(trial.status)
+ stat = Status.parse(trial.status)
status.append(stat)
trial_ids.append(trial.trial_id)
configs.append(
@@ -272,7 +323,7 @@ def get_trial_by_id(
config_id=trial.config_id,
trial_runner_id=trial.trial_runner_id,
opt_targets=self._opt_targets,
- status=Status.from_str(trial.status),
+ status=Status.parse(trial.status),
restoring=True,
config=config,
)
@@ -330,7 +381,7 @@ def pending_trials(
config_id=trial.config_id,
trial_runner_id=trial.trial_runner_id,
opt_targets=self._opt_targets,
- status=Status.from_str(trial.status),
+ status=Status.parse(trial.status),
restoring=True,
config=config,
)
diff --git a/mlos_bench/mlos_bench/tests/config/schemas/environments/test-cases/bad/invalid/mock_env-bad-trial-data-fields.jsonc b/mlos_bench/mlos_bench/tests/config/schemas/environments/test-cases/bad/invalid/mock_env-bad-trial-data-fields.jsonc
new file mode 100644
index 00000000000..d36559cf334
--- /dev/null
+++ b/mlos_bench/mlos_bench/tests/config/schemas/environments/test-cases/bad/invalid/mock_env-bad-trial-data-fields.jsonc
@@ -0,0 +1,24 @@
+{
+ "class": "mlos_bench.environments.mock_env.MockEnv",
+ "config": {
+ "mock_trial_data": {
+ "1": {
+ "run": {
+ // bad types
+ "status": null,
+ "metrics": [],
+ "exception": null,
+ "sleep": "1",
+ },
+ // missing fields
+ "setup": {},
+ "teardown": {
+ "status": "UNKNOWN",
+ "metrics": {
+ "unexpected": "field"
+ }
+ }
+ }
+ }
+ }
+}
diff --git a/mlos_bench/mlos_bench/tests/config/schemas/environments/test-cases/bad/invalid/mock_env-bad-trial-data-ids.jsonc b/mlos_bench/mlos_bench/tests/config/schemas/environments/test-cases/bad/invalid/mock_env-bad-trial-data-ids.jsonc
new file mode 100644
index 00000000000..400e557d0fa
--- /dev/null
+++ b/mlos_bench/mlos_bench/tests/config/schemas/environments/test-cases/bad/invalid/mock_env-bad-trial-data-ids.jsonc
@@ -0,0 +1,13 @@
+{
+ "class": "mlos_bench.environments.mock_env.MockEnv",
+ "config": {
+ "mock_trial_data": {
+ // invalid trial id
+ "trial_id_1": {
+ "run": {
+ "status": "FAILED"
+ }
+ }
+ }
+ }
+}
diff --git a/mlos_bench/mlos_bench/tests/config/schemas/environments/test-cases/bad/unhandled/mock_env-trial-data-extras.jsonc b/mlos_bench/mlos_bench/tests/config/schemas/environments/test-cases/bad/unhandled/mock_env-trial-data-extras.jsonc
new file mode 100644
index 00000000000..ecdf4cd0f51
--- /dev/null
+++ b/mlos_bench/mlos_bench/tests/config/schemas/environments/test-cases/bad/unhandled/mock_env-trial-data-extras.jsonc
@@ -0,0 +1,15 @@
+{
+ "class": "mlos_bench.environments.mock_env.MockEnv",
+ "config": {
+ "mock_trial_data": {
+ "1": {
+ "new_phase": {
+ "status": "FAILED"
+ },
+ "run": {
+ "expected": "property"
+ }
+ }
+ }
+ }
+}
diff --git a/mlos_bench/mlos_bench/tests/config/schemas/environments/test-cases/good/full/mock_env-full.jsonc b/mlos_bench/mlos_bench/tests/config/schemas/environments/test-cases/good/full/mock_env-full.jsonc
index a00f8ca60c0..a23971f0362 100644
--- a/mlos_bench/mlos_bench/tests/config/schemas/environments/test-cases/good/full/mock_env-full.jsonc
+++ b/mlos_bench/mlos_bench/tests/config/schemas/environments/test-cases/good/full/mock_env-full.jsonc
@@ -25,6 +25,38 @@
"mock_env_metrics": [
"latency",
"cost"
- ]
+ ],
+ "mock_trial_data": {
+ "1": {
+ "setup": {
+ "sleep": 0.1
+ },
+ "status": {
+ "metrics": {
+ "latency": 0.2,
+ "cost": 0.3
+ }
+ },
+ "run": {
+ "sleep": 0.2,
+ "status": "SUCCEEDED",
+ "metrics": {
+ "latency": 0.1,
+ "cost": 0.2
+ }
+ },
+ "teardown": {
+ "sleep": 0.1
+ }
+ },
+ "2": {
+ "setup": {
+ "exception": "Some exception"
+ },
+ "teardown": {
+ "exception": "Some other exception"
+ }
+ }
+ }
}
}
diff --git a/mlos_bench/mlos_bench/tests/config/schemas/globals/test-cases/good/full/globals-with-schema.jsonc b/mlos_bench/mlos_bench/tests/config/schemas/globals/test-cases/good/full/globals-with-schema.jsonc
index 58a0a31bb36..4ed580e09a9 100644
--- a/mlos_bench/mlos_bench/tests/config/schemas/globals/test-cases/good/full/globals-with-schema.jsonc
+++ b/mlos_bench/mlos_bench/tests/config/schemas/globals/test-cases/good/full/globals-with-schema.jsonc
@@ -10,5 +10,20 @@
"mysql": ["mysql-innodb", "mysql-myisam", "mysql-binlog", "mysql-hugepages"]
},
"experiment_id": "ExperimentName",
- "trial_id": 1
+ "trial_id": 1,
+
+ "mock_trial_data": {
+ "1": {
+ "setup": {
+ "sleep": 1
+ },
+ "run": {
+ "status": "SUCCEEDED",
+ "metrics": {
+ "score": 0.9,
+ "color": "green"
+ }
+ }
+ }
+ }
}
diff --git a/mlos_bench/mlos_bench/tests/environments/test_status.py b/mlos_bench/mlos_bench/tests/environments/test_status.py
index 3c0a9bccf3c..785275825c0 100644
--- a/mlos_bench/mlos_bench/tests/environments/test_status.py
+++ b/mlos_bench/mlos_bench/tests/environments/test_status.py
@@ -51,16 +51,19 @@ def test_status_from_str_valid(input_str: str, expected_status: Status) -> None:
Expected Status enum value.
"""
assert (
- Status.from_str(input_str) == expected_status
+ Status.parse(input_str) == expected_status
), f"Expected {expected_status} for input: {input_str}"
# Check lowercase representation
assert (
- Status.from_str(input_str.lower()) == expected_status
+ Status.parse(input_str.lower()) == expected_status
), f"Expected {expected_status} for input: {input_str.lower()}"
+ assert (
+ Status.parse(expected_status) == expected_status
+ ), f"Expected {expected_status} for input: {expected_status}"
if input_str.isdigit():
# Also test the numeric representation
assert (
- Status.from_str(int(input_str)) == expected_status
+ Status.parse(int(input_str)) == expected_status
), f"Expected {expected_status} for input: {int(input_str)}"
@@ -83,7 +86,7 @@ def test_status_from_str_invalid(invalid_input: Any) -> None:
input.
"""
assert (
- Status.from_str(invalid_input) == Status.UNKNOWN
+ Status.parse(invalid_input) == Status.UNKNOWN
), f"Expected Status.UNKNOWN for invalid input: {invalid_input}"
diff --git a/mlos_bench/mlos_bench/tests/schedulers/__init__.py b/mlos_bench/mlos_bench/tests/schedulers/__init__.py
new file mode 100644
index 00000000000..4bc0076079f
--- /dev/null
+++ b/mlos_bench/mlos_bench/tests/schedulers/__init__.py
@@ -0,0 +1,5 @@
+#
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+#
+"""mlos_bench.tests.schedulers."""
diff --git a/mlos_bench/mlos_bench/tests/schedulers/conftest.py b/mlos_bench/mlos_bench/tests/schedulers/conftest.py
new file mode 100644
index 00000000000..df6bd2776fd
--- /dev/null
+++ b/mlos_bench/mlos_bench/tests/schedulers/conftest.py
@@ -0,0 +1,123 @@
+#
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+#
+"""Pytest fixtures for mlos_bench.schedulers tests."""
+# pylint: disable=redefined-outer-name
+
+import json
+import re
+
+import pytest
+from pytest import FixtureRequest
+
+from mlos_bench.environments.mock_env import MockEnv
+from mlos_bench.schedulers.trial_runner import TrialRunner
+from mlos_bench.services.config_persistence import ConfigPersistenceService
+from mlos_bench.tunables.tunable_groups import TunableGroups
+
+NUM_TRIAL_RUNNERS = 4
+
+
+@pytest.fixture
+def mock_env_config() -> dict:
+ """A config for a MockEnv with mock_trial_data."""
+ return {
+ "name": "Test MockEnv With Explicit Mock Trial Data",
+ "class": "mlos_bench.environments.mock_env.MockEnv",
+ "config": {
+ # Reference the covariant groups from the `tunable_groups` fixture.
+ # See Also:
+ # - mlos_bench/tests/conftest.py
+ # - mlos_bench/tests/tunable_groups_fixtures.py
+ "tunable_params": ["provision", "boot", "kernel"],
+ "mock_env_seed": -1,
+ "mock_env_range": [0, 10],
+ "mock_env_metrics": ["score"],
+ # TODO: Add more mock trial data here:
+ "mock_trial_data": {
+ "1": {
+ "run": {
+ "sleep": 0.15,
+ "status": "SUCCEEDED",
+ "metrics": {
+ "score": 1.0,
+ },
+ },
+ },
+ "2": {
+ "run": {
+ "sleep": 0.2,
+ "status": "SUCCEEDED",
+ "metrics": {
+ "score": 2.0,
+ },
+ },
+ },
+ "3": {
+ "run": {
+ "sleep": 0.1,
+ "status": "SUCCEEDED",
+ "metrics": {
+ "score": 3.0,
+ },
+ },
+ },
+ },
+ },
+ }
+
+
+@pytest.fixture
+def global_config(request: FixtureRequest) -> dict:
+ """A global config for a MockEnv."""
+ test_name = request.node.name
+ test_name = re.sub(r"[^a-zA-Z0-9]", "_", test_name)
+ experiment_id = f"TestExperiment-{test_name}"
+ return {
+ "experiment_id": experiment_id,
+ "trial_id": 1,
+ }
+
+
+@pytest.fixture
+def mock_env_json_config(mock_env_config: dict) -> str:
+ """A JSON string of the mock_env_config."""
+ return json.dumps(mock_env_config)
+
+
+@pytest.fixture
+def mock_env(
+ mock_env_json_config: str,
+ tunable_groups: TunableGroups,
+ global_config: dict,
+) -> MockEnv:
+ """A fixture to create a MockEnv instance using the mock_env_json_config."""
+ config_loader_service = ConfigPersistenceService()
+ mock_env = config_loader_service.load_environment(
+ mock_env_json_config,
+ tunable_groups,
+ service=config_loader_service,
+ global_config=global_config,
+ )
+ assert isinstance(mock_env, MockEnv)
+ return mock_env
+
+
+@pytest.fixture
+def trial_runners(
+ mock_env_json_config: str,
+ tunable_groups: TunableGroups,
+ global_config: dict,
+) -> list[TrialRunner]:
+ """A fixture to create a list of TrialRunner instances using the
+ mock_env_json_config.
+ """
+ config_loader_service = ConfigPersistenceService(global_config=global_config)
+ return TrialRunner.create_from_json(
+ config_loader=config_loader_service,
+ env_json=mock_env_json_config,
+ tunable_groups=tunable_groups,
+ num_trial_runners=NUM_TRIAL_RUNNERS,
+ global_config=global_config,
+ )
diff --git a/mlos_bench/mlos_bench/tests/schedulers/test_scheduler.py b/mlos_bench/mlos_bench/tests/schedulers/test_scheduler.py
new file mode 100644
index 00000000000..9a2cc1dbaec
--- /dev/null
+++ b/mlos_bench/mlos_bench/tests/schedulers/test_scheduler.py
@@ -0,0 +1,213 @@
+#
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+#
+"""Unit tests for :py:class:`mlos_bench.schedulers` and their internals."""
+
+import sys
+from logging import warning
+
+import pytest
+
+import mlos_bench.tests.optimizers.fixtures as optimizers_fixtures
+import mlos_bench.tests.storage.sql.fixtures as sql_storage_fixtures
+from mlos_bench.environments.mock_env import MockEnv
+from mlos_bench.optimizers.mock_optimizer import MockOptimizer
+from mlos_bench.schedulers.base_scheduler import Scheduler
+from mlos_bench.schedulers.trial_runner import TrialRunner
+from mlos_bench.storage.base_trial_data import TrialData
+from mlos_bench.storage.sql.storage import SqlStorage
+from mlos_core.tests import get_all_concrete_subclasses
+
+mock_opt = optimizers_fixtures.mock_opt
+sqlite_storage = sql_storage_fixtures.sqlite_storage
+
+DEBUG_WARNINGS_ENABLED = False
+
+# pylint: disable=redefined-outer-name
+
+
+def create_scheduler(
+ scheduler_type: type[Scheduler],
+ trial_runners: list[TrialRunner],
+ mock_opt: MockOptimizer,
+ sqlite_storage: SqlStorage,
+ global_config: dict,
+) -> Scheduler:
+ """Create a Scheduler instance using trial_runners, mock_opt, and sqlite_storage
+ fixtures.
+ """
+ env = trial_runners[0].environment
+ assert isinstance(env, MockEnv), "Environment is not a MockEnv instance."
+ max_trials = max(env.mock_trial_data.keys())
+ max_trials = min(max_trials, mock_opt.max_suggestions)
+
+ return scheduler_type(
+ config={
+ "max_trials": max_trials,
+ },
+ global_config=global_config,
+ trial_runners=trial_runners,
+ optimizer=mock_opt,
+ storage=sqlite_storage,
+ root_env_config="",
+ )
+
+
+def debug_warn(*args: object) -> None:
+ """Optionally issue warnings for debugging."""
+ if DEBUG_WARNINGS_ENABLED:
+ warning(*args)
+
+
+def mock_opt_has_registered_trial_score(
+ mock_opt: MockOptimizer,
+ trial_data: TrialData,
+) -> bool:
+ """Check that the MockOptimizer has registered a given MockTrialData."""
+ # pylint: disable=consider-using-any-or-all
+ # Split out for easier debugging.
+ for registered_score in mock_opt.registered_scores:
+ match = True
+ if registered_score.status != trial_data.status:
+ match = False
+ debug_warn(
+ "Registered status %s does not match trial data %s.",
+ registered_score.status,
+ trial_data.results_dict,
+ )
+ elif registered_score.score != trial_data.results_dict:
+ debug_warn(
+ "Registered score %s does not match trial data %s.",
+ registered_score.score,
+ trial_data.results_dict,
+ )
+ match = False
+ elif registered_score.config.get_param_values() != trial_data.tunable_config.config_dict:
+ debug_warn(
+ "Registered config %s does not match trial data %s.",
+ registered_score.config.get_param_values(),
+ trial_data.results_dict,
+ )
+ match = False
+ if match:
+ return True
+ return False
+
+
+scheduler_classes = get_all_concrete_subclasses(
+ Scheduler, # type: ignore[type-abstract]
+ pkg_name="mlos_bench",
+)
+assert scheduler_classes, "No Scheduler classes found in mlos_bench."
+
+
+@pytest.mark.parametrize(
+ "scheduler_class",
+ scheduler_classes,
+)
+@pytest.mark.skipif(
+ sys.platform == "win32",
+ reason="Skipping test on Windows - SQLite storage is not accessible in parallel tests there.",
+)
+def test_scheduler_with_mock_trial_data(
+ scheduler_class: type[Scheduler],
+ trial_runners: list[TrialRunner],
+ mock_opt: MockOptimizer,
+ sqlite_storage: SqlStorage,
+ global_config: dict,
+) -> None:
+ """
+ Full integration test for Scheduler: runs trials, checks storage, optimizer
+ registration, and internal bookkeeping.
+ """
+ # pylint: disable=too-many-locals
+
+ # Create the scheduler.
+ scheduler = create_scheduler(
+ scheduler_class,
+ trial_runners,
+ mock_opt,
+ sqlite_storage,
+ global_config,
+ )
+
+ root_env = scheduler.root_environment
+ experiment_id = root_env.experiment_id
+ assert isinstance(root_env, MockEnv), f"Root environment {root_env} is not a MockEnv."
+ assert root_env.mock_trial_data, "No mock trial data found in root environment."
+
+ # Run the scheduler
+ with scheduler:
+ scheduler.start()
+ scheduler.teardown()
+
+ # Now check the overall results.
+ ran_trials = {trial.trial_id for trial in scheduler.ran_trials}
+ assert (
+ experiment_id in sqlite_storage.experiments
+ ), f"Experiment {experiment_id} not found in storage."
+ exp_data = sqlite_storage.experiments[experiment_id]
+
+ for mock_trial_data in root_env.mock_trial_data.values():
+ trial_id = mock_trial_data.trial_id
+
+ # Check the bookkeeping for ran_trials.
+ assert trial_id in ran_trials, f"Trial {trial_id} not found in Scheduler.ran_trials."
+
+ # Check the results in storage.
+ assert trial_id in exp_data.trials, f"Trial {trial_id} not found in storage."
+ trial_data = exp_data.trials[trial_id]
+
+ # Check the results.
+ metrics = mock_trial_data.run.metrics
+ if metrics:
+ for result_key, result_value in metrics.items():
+ assert (
+ result_key in trial_data.results_dict
+ ), f"Result {result_key} not found in storage for trial {trial_data}."
+ assert (
+ trial_data.results_dict[result_key] == result_value
+ ), f"Result value for {result_key} does not match expected value."
+ # TODO: Should we check the reverse - no extra metrics were registered?
+ # Even when metrics weren't explicit in the mock trial data, a score must
+ # have been stored for each optimization target, so check that regardless.
+ for opt_target in mock_opt.targets:
+ assert (
+ opt_target in trial_data.results_dict
+ ), f"Result column {opt_target} not found in storage."
+ assert (
+ trial_data.results_dict[opt_target] is not None
+ ), f"Result value for {opt_target} is None."
+
+ # Check that the appropriate sleeps occurred.
+ trial_time_lb = 0.0
+ trial_time_lb += mock_trial_data.setup.sleep or 0
+ trial_time_lb += mock_trial_data.run.sleep or 0
+ trial_time_lb += mock_trial_data.status.sleep or 0
+ trial_time_lb += mock_trial_data.teardown.sleep or 0
+ assert trial_data.ts_end is not None, f"Trial {trial_id} has no end time."
+ trial_duration = trial_data.ts_end - trial_data.ts_start
+ trial_dur_secs = trial_duration.total_seconds()
+ assert (
+ trial_dur_secs >= trial_time_lb
+ ), f"Trial {trial_id} took less time ({trial_dur_secs}) than expected ({trial_time_lb}). "
+
+ # Check that the trial status matches what we expected.
+ assert (
+ trial_data.status == mock_trial_data.run.status
+ ), f"Trial {trial_id} status {trial_data.status} was not {mock_trial_data.run.status}."
+
+ # TODO: Check the trial status telemetry.
+
+ # Check the optimizer registration.
+ assert mock_opt_has_registered_trial_score(
+ mock_opt,
+ trial_data,
+ ), f"Trial {trial_id} was not registered in the optimizer."
+
+ # TODO: And check the intermediary results.
+ # Check the bookkeeping for add_new_optimizer_suggestions and
+ # _longest_finished_trial_sequence_id.
+ # This last part may require patching and intercepting during the start()
+ # loop to validate in-progress book keeping instead of just overall.
diff --git a/mlos_bench/mlos_bench/tests/storage/exp_load_test.py b/mlos_bench/mlos_bench/tests/storage/exp_load_test.py
index e07cf80c70a..f82833dd64f 100644
--- a/mlos_bench/mlos_bench/tests/storage/exp_load_test.py
+++ b/mlos_bench/mlos_bench/tests/storage/exp_load_test.py
@@ -3,7 +3,8 @@
# Licensed under the MIT License.
#
"""Unit tests for the storage subsystem."""
-from datetime import datetime, tzinfo
+from datetime import datetime, timedelta, tzinfo
+from random import random
import pytest
from pytz import UTC
@@ -157,3 +158,265 @@ def test_exp_trial_pending_3(
assert status == [Status.FAILED, Status.SUCCEEDED]
assert tunable_groups.copy().assign(configs[0]).reset() == trial_fail.tunables
assert tunable_groups.copy().assign(configs[1]).reset() == trial_succ.tunables
+
+
+def test_empty_get_longest_prefix_finished_trial_id(
+ storage: Storage,
+ exp_storage: Storage.Experiment,
+) -> None:
+ """Test that the longest prefix of finished trials is empty when no trials are
+ present.
+ """
+ assert not storage.experiments[
+ exp_storage.experiment_id
+ ].trials, "Expected no trials in the experiment."
+
+ # Retrieve the longest prefix of finished trials when no trials are present
+ longest_prefix_id = exp_storage.get_longest_prefix_finished_trial_id()
+
+ # Assert that the longest prefix is empty
+ assert (
+ longest_prefix_id == -1
+ ), f"Expected longest prefix to be -1, but got {longest_prefix_id}"
+
+
+def test_sync_success_get_longest_prefix_finished_trial_id(
+ exp_storage: Storage.Experiment,
+ tunable_groups: TunableGroups,
+) -> None:
+ """Test that the longest prefix of finished trials is returned correctly when all
+ trials are finished.
+ """
+ timestamp = datetime.now(UTC)
+ config = {}
+ metrics = {metric: random() for metric in exp_storage.opt_targets}
+
+ # Create several trials
+ trials = [exp_storage.new_trial(tunable_groups, config=config) for _ in range(0, 4)]
+
+ # Mark all of the trials as finished
+ trials[0].update(Status.SUCCEEDED, timestamp + timedelta(minutes=1), metrics=metrics)
+ trials[1].update(Status.FAILED, timestamp + timedelta(minutes=2), metrics=metrics)
+ trials[2].update(Status.TIMED_OUT, timestamp + timedelta(minutes=3), metrics=metrics)
+ trials[3].update(Status.CANCELED, timestamp + timedelta(minutes=4), metrics=metrics)
+
+ # Retrieve the longest prefix of finished trials
+ longest_prefix_id = exp_storage.get_longest_prefix_finished_trial_id()
+
+ # Assert that the longest prefix includes all of the trials
+ assert longest_prefix_id == trials[3].trial_id, (
+ f"Expected longest prefix to end at trial_id {trials[3].trial_id}, "
+ f"but got {longest_prefix_id}"
+ )
+
+
+def test_async_get_longest_prefix_finished_trial_id(
+ exp_storage: Storage.Experiment,
+ tunable_groups: TunableGroups,
+) -> None:
+ """Test that the longest prefix of finished trials is returned correctly when trial
+ finish out of order.
+ """
+ timestamp = datetime.now(UTC)
+ config = {}
+ metrics = {metric: random() for metric in exp_storage.opt_targets}
+
+ # Create several trials
+ trials = [exp_storage.new_trial(tunable_groups, config=config) for _ in range(0, 10)]
+
+ # Mark some trials at the beginning and end as finished
+ trials[0].update(Status.SUCCEEDED, timestamp + timedelta(minutes=1), metrics=metrics)
+ trials[1].update(Status.FAILED, timestamp + timedelta(minutes=2), metrics=metrics)
+ trials[2].update(Status.TIMED_OUT, timestamp + timedelta(minutes=3), metrics=metrics)
+ trials[3].update(Status.CANCELED, timestamp + timedelta(minutes=4), metrics=metrics)
+ # Leave trials[4] through trials[8] as PENDING
+ trials[9].update(Status.SUCCEEDED, timestamp + timedelta(minutes=5), metrics=metrics)
+
+ # Retrieve the longest prefix of finished trials
+ longest_prefix_id = exp_storage.get_longest_prefix_finished_trial_id()
+
+ # Assert that the longest prefix includes only the first four trials
+ assert longest_prefix_id == trials[3].trial_id, (
+ f"Expected longest prefix to end at trial_id {trials[3].trial_id}, "
+ f"but got {longest_prefix_id}"
+ )
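+ # For reference, the illustrative helper defined above agrees with this result:
+ #   _reference_longest_prefix_finished_trial_id(
+ #       [(t.trial_id, i in (0, 1, 2, 3, 9)) for i, t in enumerate(trials)]
+ #   ) == trials[3].trial_id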
+
+
+# TODO: Can we simplify this to use something like SyncScheduler and
+# bulk_register_completed_trials?
+# TODO: Update to use MockScheduler
+def test_exp_load_async(
+ exp_storage: Storage.Experiment,
+ tunable_groups: TunableGroups,
+) -> None:
+ """
+ Test the ``omit_registered_trial_ids`` argument of the ``Experiment.load()`` method.
+
+ Create several trials with mixed statuses (PENDING and completed).
+ Verify that completed trials whose IDs are in a local set of registered
+ trial IDs are omitted from the `load` operation.
+ """
+ # pylint: disable=too-many-locals,too-many-statements
+
+ last_trial_id = exp_storage.get_longest_prefix_finished_trial_id()
+ assert last_trial_id == -1, "Expected no trials in the experiment."
+ registered_trial_ids: set[int] = set()
+
+ # Load trials, omitting registered ones
+ trial_ids, configs, scores, status = exp_storage.load(
+ last_trial_id=last_trial_id,
+ omit_registered_trial_ids=registered_trial_ids,
+ )
+
+ assert trial_ids == []
+ assert configs == []
+ assert scores == []
+ assert status == []
+
+ # Create trials with mixed statuses
+ trial_1_success = exp_storage.new_trial(tunable_groups)
+ trial_2_failed = exp_storage.new_trial(tunable_groups)
+ trial_3_pending = exp_storage.new_trial(tunable_groups)
+ trial_4_timedout = exp_storage.new_trial(tunable_groups)
+ trial_5_pending = exp_storage.new_trial(tunable_groups)
+
+ # Update statuses for completed trials
+ trial_1_success.update(Status.SUCCEEDED, datetime.now(UTC), {"score": 95.0})
+ trial_2_failed.update(Status.FAILED, datetime.now(UTC), {"score": -1})
+ trial_4_timedout.update(Status.TIMED_OUT, datetime.now(UTC), {"score": -1})
+
+ # Now evaluate some different sequences of loading trials by simulating what
+ # we expect a Scheduler to do.
+ # See Also: Scheduler.add_new_optimizer_suggestions()
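+ #
+ # In outline, each round of the simulation below does:
+ #
+ #   (trial_ids, configs, scores, status) = exp_storage.load(
+ #       last_trial_id=last_trial_id,
+ #       omit_registered_trial_ids=registered_trial_ids,
+ #   )
+ #   # ... register the returned configs/scores/status with the optimizer ...
+ #   last_trial_id = exp_storage.get_longest_prefix_finished_trial_id()
+ #   registered_trial_ids |= set(trial_ids)
+ #   registered_trial_ids = {i for i in registered_trial_ids if i > last_trial_id}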
+
+ trial_ids, configs, scores, status = exp_storage.load(
+ last_trial_id=last_trial_id,
+ omit_registered_trial_ids=registered_trial_ids,
+ )
+
+ # Verify that all completed trials are returned.
+ completed_trials = [
+ trial_1_success,
+ trial_2_failed,
+ trial_4_timedout,
+ ]
+ assert trial_ids == [trial.trial_id for trial in completed_trials]
+ assert len(configs) == len(completed_trials)
+ assert status == [trial.status for trial in completed_trials]
+
+ last_trial_id = exp_storage.get_longest_prefix_finished_trial_id()
+ assert last_trial_id == trial_2_failed.trial_id, (
+ f"Expected longest prefix to end at trial_id {trial_2_failed.trial_id}, "
+ f"but got {last_trial_id}"
+ )
+ registered_trial_ids |= {completed_trial.trial_id for completed_trial in completed_trials}
+ registered_trial_ids = {i for i in registered_trial_ids if i > last_trial_id}
+
+ # Create some more trials and update their statuses.
+ # Note: we are leaving some trials in the middle in the PENDING state.
+ trial_6_canceled = exp_storage.new_trial(tunable_groups)
+ trial_7_success2 = exp_storage.new_trial(tunable_groups)
+ trial_6_canceled.update(Status.CANCELED, datetime.now(UTC), {"score": -1})
+ trial_7_success2.update(Status.SUCCEEDED, datetime.now(UTC), {"score": 90.0})
+
+ # Load trials, omitting registered ones
+ trial_ids, configs, scores, status = exp_storage.load(
+ last_trial_id=last_trial_id,
+ omit_registered_trial_ids=registered_trial_ids,
+ )
+ # Verify that only unregistered completed trials are returned
+ completed_trials = [
+ trial_6_canceled,
+ trial_7_success2,
+ ]
+ assert trial_ids == [trial.trial_id for trial in completed_trials]
+ assert len(configs) == len(completed_trials)
+ assert status == [trial.status for trial in completed_trials]
+
+ # Update our tracking of registered trials
+ last_trial_id = exp_storage.get_longest_prefix_finished_trial_id()
+ # Should still be the same as before since we haven't adjusted the PENDING
+ # trials at the beginning yet.
+ assert last_trial_id == trial_2_failed.trial_id, (
+ f"Expected longest prefix to end at trial_id {trial_2_failed.trial_id}, "
+ f"but got {last_trial_id}"
+ )
+ registered_trial_ids |= {completed_trial.trial_id for completed_trial in completed_trials}
+ registered_trial_ids = {i for i in registered_trial_ids if i > last_trial_id}
+
+ trial_ids, configs, scores, status = exp_storage.load(
+ last_trial_id=last_trial_id,
+ omit_registered_trial_ids=registered_trial_ids,
+ )
+
+ # Verify that only unregistered completed trials are returned
+ completed_trials = []
+ assert trial_ids == [trial.trial_id for trial in completed_trials]
+ assert len(configs) == len(completed_trials)
+ assert status == [trial.status for trial in completed_trials]
+
+ # Now update one of the PENDING trials (trial_3) to be TIMED_OUT; trial_5 stays PENDING.
+ trial_3_pending.update(Status.TIMED_OUT, datetime.now(UTC), {"score": -1})
+
+ # Load trials, omitting registered ones
+ trial_ids, configs, scores, status = exp_storage.load(
+ last_trial_id=last_trial_id,
+ omit_registered_trial_ids=registered_trial_ids,
+ )
+
+ # Verify that only unregistered completed trials are returned
+ completed_trials = [
+ trial_3_pending,
+ ]
+ assert trial_ids == [trial.trial_id for trial in completed_trials]
+ assert len(configs) == len(completed_trials)
+ assert status == [trial.status for trial in completed_trials]
+
+ # Update our tracking of registered trials
+ last_trial_id = exp_storage.get_longest_prefix_finished_trial_id()
+ assert last_trial_id == trial_4_timedout.trial_id, (
+ f"Expected longest prefix to end at trial_id {trial_4_timedout.trial_id}, "
+ f"but got {last_trial_id}"
+ )
+ registered_trial_ids |= {completed_trial.trial_id for completed_trial in completed_trials}
+ registered_trial_ids = {i for i in registered_trial_ids if i > last_trial_id}
+
+ # Load trials, omitting registered ones
+ trial_ids, configs, scores, status = exp_storage.load(
+ last_trial_id=last_trial_id,
+ omit_registered_trial_ids=registered_trial_ids,
+ )
+ # Verify that only unregistered completed trials are returned
+ completed_trials = []
+ assert trial_ids == [trial.trial_id for trial in completed_trials]
+ assert len(configs) == len(completed_trials)
+ assert status == [trial.status for trial in completed_trials]
+ # And that the longest prefix is still the same.
+ assert last_trial_id == trial_4_timedout.trial_id, (
+ f"Expected longest prefix to end at trial_id {trial_4_timedout.trial_id}, "
+ f"but got {last_trial_id}"
+ )
+
+ # Mark the last trial as finished.
+ trial_5_pending.update(Status.SUCCEEDED, datetime.now(UTC), {"score": 95.0})
+ # Load trials, omitting registered ones
+ trial_ids, configs, scores, status = exp_storage.load(
+ last_trial_id=last_trial_id,
+ omit_registered_trial_ids=registered_trial_ids,
+ )
+ # Verify that only unregistered completed trials are returned
+ completed_trials = [
+ trial_5_pending,
+ ]
+ assert trial_ids == [trial.trial_id for trial in completed_trials]
+ assert len(configs) == len(completed_trials)
+ assert status == [trial.status for trial in completed_trials]
+ # And that the longest prefix is now the last trial.
+ last_trial_id = exp_storage.get_longest_prefix_finished_trial_id()
+ assert last_trial_id == trial_7_success2.trial_id, (
+ f"Expected longest prefix to end at trial_id {trial_7_success2.trial_id}, "
+ f"but got {last_trial_id}"
+ )
+ registered_trial_ids |= {completed_trial.trial_id for completed_trial in completed_trials}
+ registered_trial_ids = {i for i in registered_trial_ids if i > last_trial_id}
+ assert registered_trial_ids == set()
diff --git a/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py b/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py
index 0bebeeff824..db6dc5fa2e3 100644
--- a/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py
+++ b/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py
@@ -30,7 +30,7 @@
# pylint: disable=redefined-outer-name
-@pytest.fixture
+@pytest.fixture(scope="function")
def sqlite_storage() -> Generator[SqlStorage]:
"""
Fixture for file based SQLite storage in a temporary directory.
diff --git a/mlos_bench/mlos_bench/tests/storage/trial_schedule_test.py b/mlos_bench/mlos_bench/tests/storage/trial_schedule_test.py
index e22b045f1af..90bf84e7bb4 100644
--- a/mlos_bench/mlos_bench/tests/storage/trial_schedule_test.py
+++ b/mlos_bench/mlos_bench/tests/storage/trial_schedule_test.py
@@ -24,14 +24,19 @@ def _trial_ids(trials: Iterator[Storage.Trial]) -> set[int]:
return {t.trial_id for t in trials}
-def test_schedule_trial(
+def test_storage_schedule(
storage: Storage,
exp_storage: Storage.Experiment,
tunable_groups: TunableGroups,
) -> None:
# pylint: disable=too-many-locals,too-many-statements
- """Schedule several trials for future execution and retrieve them later at certain
- timestamps.
+ """
+ Test some storage functions that schedule several trials for future execution and
+ retrieve them later at certain timestamps.
+
+ Notes
+ -----
+ This doesn't actually test the Scheduler.
"""
timestamp = datetime.now(UTC)
timedelta_1min = timedelta(minutes=1)