diff --git a/mlos_bench/mlos_bench/schedulers/base_scheduler.py b/mlos_bench/mlos_bench/schedulers/base_scheduler.py
index eaa5527c6d..b711388609 100644
--- a/mlos_bench/mlos_bench/schedulers/base_scheduler.py
+++ b/mlos_bench/mlos_bench/schedulers/base_scheduler.py
@@ -200,8 +200,9 @@ def __enter__(self) -> "Scheduler":
         _LOG.debug("Scheduler START :: %s", self)
         assert self.experiment is None
         assert not self._in_context
-        for trial_runner in self._trial_runners.values():
-            trial_runner.__enter__()
+        # NOTE: We delay entering the context of trial_runners until it's time
+        # to run the trial in order to avoid incompatibilities with
+        # multiprocessing.Pool.
         self._optimizer.__enter__()
         # Start new or resume the existing experiment. Verify that the
         # experiment configuration is compatible with the previous runs.
@@ -235,7 +236,8 @@ def __exit__(
         self._experiment.__exit__(ex_type, ex_val, ex_tb)
         self._optimizer.__exit__(ex_type, ex_val, ex_tb)
         for trial_runner in self._trial_runners.values():
-            trial_runner.__exit__(ex_type, ex_val, ex_tb)
+            # TrialRunners should have already exited their context after running the Trial.
+            assert not trial_runner._in_context  # pylint: disable=protected-access
         self._experiment = None
         self._in_context = False
         return False  # Do not suppress exceptions
@@ -267,7 +269,8 @@ def teardown(self) -> None:
         if self._do_teardown:
             for trial_runner in self._trial_runners.values():
                 assert not trial_runner.is_running
-                trial_runner.teardown()
+                with trial_runner:
+                    trial_runner.teardown()
 
     def get_best_observation(self) -> tuple[dict[str, float] | None, TunableGroups | None]:
         """Get the best observation from the optimizer."""
diff --git a/mlos_bench/mlos_bench/schedulers/sync_scheduler.py b/mlos_bench/mlos_bench/schedulers/sync_scheduler.py
index 4b864942dc..f450b28b8f 100644
--- a/mlos_bench/mlos_bench/schedulers/sync_scheduler.py
+++ b/mlos_bench/mlos_bench/schedulers/sync_scheduler.py
@@ -39,5 +39,6 @@ def run_trial(self, trial: Storage.Trial) -> None:
         super().run_trial(trial)
         # In the sync scheduler we run each trial on its own TrialRunner in sequence.
         trial_runner = self.get_trial_runner(trial)
-        trial_runner.run_trial(trial, self.global_config)
-        _LOG.info("QUEUE: Finished trial: %s on %s", trial, trial_runner)
+        with trial_runner:
+            trial_runner.run_trial(trial, self.global_config)
+            _LOG.info("QUEUE: Finished trial: %s on %s", trial, trial_runner)