diff --git a/mlos_bench/mlos_bench/mlos_benchd.py b/mlos_bench/mlos_bench/mlos_benchd.py new file mode 100644 index 0000000000..88d49af62c --- /dev/null +++ b/mlos_bench/mlos_bench/mlos_benchd.py @@ -0,0 +1,70 @@ +#!/usr/bin/env python3 +# +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# +""" +mlos_bench background execution daemon. + +This script is responsible for polling the storage for runnable experiments and +executing them in parallel. + +See the current ``--help`` output for details. +""" +import argparse +import time +from concurrent.futures import ProcessPoolExecutor + +from mlos_bench.run import _main as mlos_bench_main +from mlos_bench.storage import from_config + + +def _main(args: argparse.Namespace) -> None: + storage = from_config(config=args.storage) + poll_interval = float(args.poll_interval) + num_workers = int(args.num_workers) + + with ProcessPoolExecutor(max_workers=num_workers) as executor: + while True: + exp_id = storage.get_runnable_experiment() + if exp_id is None: + print(f"No runnable experiment found. Sleeping for {poll_interval} second(s).") + time.sleep(poll_interval) + continue + + exp = storage.experiments[exp_id] + root_env_config, _, _ = exp.root_env_config + + executor.submit( + mlos_bench_main, + [ + "--storage", + args.storage, + "--environment", + root_env_config, + "--experiment_id", + exp_id, + ], + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="mlos_benchd") + parser.add_argument( + "--storage", + required=True, + help="Path to the storage configuration file.", + ) + parser.add_argument( + "--num_workers", + required=False, + default=1, + help="Number of workers to use. Default is 1.", + ) + parser.add_argument( + "--poll_interval", + required=False, + default=1, + help="Polling interval in seconds. 
Default is 1.", + ) + _main(parser.parse_args()) diff --git a/mlos_bench/mlos_bench/storage/base_storage.py b/mlos_bench/mlos_bench/storage/base_storage.py index ba84675425..86ac76735c 100644 --- a/mlos_bench/mlos_bench/storage/base_storage.py +++ b/mlos_bench/mlos_bench/storage/base_storage.py @@ -23,6 +23,8 @@ """ import logging +import os +import platform from abc import ABCMeta, abstractmethod from collections.abc import Iterator, Mapping from contextlib import AbstractContextManager as ContextManager @@ -30,6 +32,8 @@ from types import TracebackType from typing import Any, Literal +from pytz import UTC + from mlos_bench.config.schemas import ConfigSchema from mlos_bench.dict_templater import DictTemplater from mlos_bench.environments.status import Status @@ -133,6 +137,17 @@ def experiment( # pylint: disable=too-many-arguments the results of the experiment and related data. """ + @abstractmethod + def get_runnable_experiment(self) -> str | None: + """ + Get the ID of the experiment that can be run. + + Returns + ------- + experiment_id : str | None + ID of the experiment that can be run. + """ + class Experiment(ContextManager, metaclass=ABCMeta): # pylint: disable=too-many-instance-attributes """ @@ -150,6 +165,7 @@ def __init__( # pylint: disable=too-many-arguments root_env_config: str, description: str, opt_targets: dict[str, Literal["min", "max"]], + ts_start: datetime | None = None, ): self._tunables = tunables.copy() self._trial_id = trial_id @@ -159,6 +175,11 @@ def __init__( # pylint: disable=too-many-arguments ) self._description = description self._opt_targets = opt_targets + self._ts_start = ts_start or datetime.now(UTC) + self._ts_end: datetime | None = None + self._status = Status.PENDING + self._driver_name: str | None = None + self._driver_pid: int | None = None self._in_context = False def __enter__(self) -> "Storage.Experiment": @@ -209,6 +230,9 @@ def _setup(self) -> None: This method is called by `Storage.Experiment.__enter__()`. 
""" + self._status = Status.RUNNING + self._driver_name = platform.node() + self._driver_pid = os.getpid() def _teardown(self, is_ok: bool) -> None: """ @@ -221,6 +245,11 @@ def _teardown(self, is_ok: bool) -> None: is_ok : bool True if there were no exceptions during the experiment, False otherwise. """ + if is_ok: + self._status = Status.SUCCEEDED + else: + self._status = Status.FAILED + self._ts_end = datetime.now(UTC) @property def experiment_id(self) -> str: @@ -394,6 +423,10 @@ def _new_trial( the results of the experiment trial run. """ + @abstractmethod + def save(self) -> None: + """Save the experiment to the storage, without running it.""" + class Trial(metaclass=ABCMeta): # pylint: disable=too-many-instance-attributes """ diff --git a/mlos_bench/mlos_bench/storage/sql/experiment.py b/mlos_bench/mlos_bench/storage/sql/experiment.py index 62daa0232c..e67742213e 100644 --- a/mlos_bench/mlos_bench/storage/sql/experiment.py +++ b/mlos_bench/mlos_bench/storage/sql/experiment.py @@ -41,6 +41,7 @@ def __init__( # pylint: disable=too-many-arguments root_env_config: str, description: str, opt_targets: dict[str, Literal["min", "max"]], + ts_start: datetime | None = None, ): super().__init__( tunables=tunables, @@ -49,12 +50,12 @@ def __init__( # pylint: disable=too-many-arguments root_env_config=root_env_config, description=description, opt_targets=opt_targets, + ts_start=ts_start, ) self._engine = engine self._schema = schema - def _setup(self) -> None: - super()._setup() + def _ensure_persisted(self) -> None: with self._engine.begin() as conn: # Get git info and the last trial ID for the experiment. 
# pylint: disable=not-callable @@ -90,6 +91,8 @@ def _setup(self) -> None: git_repo=self._git_repo, git_commit=self._git_commit, root_env_config=self._root_env_config, + ts_start=self._ts_start, + status=self._status.name, ) ) conn.execute( @@ -125,6 +128,39 @@ def _setup(self) -> None: exp_info.git_commit, ) + def save(self) -> None: + self._ensure_persisted() + + def _setup(self) -> None: + super()._setup() + self._ensure_persisted() + with self._engine.begin() as conn: + conn.execute( + self._schema.experiment.update() + .where(self._schema.experiment.c.exp_id == self._experiment_id) + .values( + { + self._schema.experiment.c.status: self._status.name, + self._schema.experiment.c.driver_name: self._driver_name, + self._schema.experiment.c.driver_pid: self._driver_pid, + } + ) + ) + + def _teardown(self, is_ok: bool) -> None: + super()._teardown(is_ok) + with self._engine.begin() as conn: + conn.execute( + self._schema.experiment.update() + .where(self._schema.experiment.c.exp_id == self._experiment_id) + .values( + { + self._schema.experiment.c.status: self._status.name, + self._schema.experiment.c.ts_end: self._ts_end, + } + ) + ) + def merge(self, experiment_ids: list[str]) -> None: _LOG.info("Merge: %s <- %s", self._experiment_id, experiment_ids) raise NotImplementedError("TODO: Merging experiments not implemented yet.") diff --git a/mlos_bench/mlos_bench/storage/sql/storage.py b/mlos_bench/mlos_bench/storage/sql/storage.py index b3bf63d0ed..813c0bad52 100644 --- a/mlos_bench/mlos_bench/storage/sql/storage.py +++ b/mlos_bench/mlos_bench/storage/sql/storage.py @@ -5,10 +5,14 @@ """Saving and restoring the benchmark data in SQL database.""" import logging +import platform +from datetime import datetime from typing import Literal -from sqlalchemy import URL, create_engine +from pytz import UTC +from sqlalchemy import URL, create_engine, exc +from mlos_bench.environments.status import Status from mlos_bench.services.base_service import Service from 
mlos_bench.storage.base_experiment_data import ExperimentData from mlos_bench.storage.base_storage import Storage @@ -109,3 +113,48 @@ def experiments(self) -> dict[str, ExperimentData]: ) for exp in cur_exp.fetchall() } + + def get_runnable_experiment(self) -> str | None: + with self._engine.connect() as conn: + with conn.begin() as trans: + try: + experiment_row = conn.execute( + self._schema.experiment.select() + .where( + self._schema.experiment.c.status == Status.PENDING.name, + self._schema.experiment.c.driver_name.is_(None), + self._schema.experiment.c.ts_start <= datetime.now(UTC), + ) + .order_by(self._schema.experiment.c.ts_start.asc()) + .limit(1) + ).fetchone() + if experiment_row: + # try to grab + result = conn.execute( + self._schema.experiment.update() + .where( + self._schema.experiment.c.driver_name.is_(None), + self._schema.experiment.c.exp_id == experiment_row.exp_id, + ) + .values( + { + self._schema.experiment.c.driver_name: platform.node(), + self._schema.experiment.c.status: Status.READY.name, + } + ) + ) + if result: + # succeeded, commit the transaction and return + trans.commit() + # return this to calling code to spawn a new `mlos_bench` + # process to fork and execute this Experiment on this host + # in the background + return str(experiment_row.exp_id) + else: + # someone else probably grabbed it + trans.rollback() + except exc.SQLAlchemyError: + # probably a conflict + trans.rollback() + + return None