diff --git a/compiler_opt/es/blackbox_evaluator.py b/compiler_opt/es/blackbox_evaluator.py index 1208eb4d..f79365ee 100644 --- a/compiler_opt/es/blackbox_evaluator.py +++ b/compiler_opt/es/blackbox_evaluator.py @@ -17,6 +17,7 @@ import concurrent.futures import os import random +from typing import Any from absl import logging import gin @@ -29,6 +30,19 @@ from compiler_opt.rl import compilation_runner +def _extract_results(futures: list[concurrent.futures.Future]) -> list[Any]: + results = [None] * len(futures) + + for i in range(len(futures)): + if not futures[i].exception(): + results[i] = futures[i].result() + else: + logging.info('Error retrieving result from future: %s', + str(futures[i].exception())) + + return results + + class BlackboxEvaluator(metaclass=abc.ABCMeta): """Blockbox evaluator abstraction.""" @@ -46,19 +60,6 @@ def get_results( def set_baseline(self, pool: FixedWorkerPool) -> None: raise NotImplementedError() - def get_rewards( - self, results: list[concurrent.futures.Future]) -> list[float | None]: - rewards = [None] * len(results) - - for i in range(len(results)): - if not results[i].exception(): - rewards[i] = results[i].result() - else: - logging.info('Error retrieving result from future: %s', - str(results[i].exception())) - - return rewards - @gin.configurable class SamplingBlackboxEvaluator(BlackboxEvaluator): @@ -140,19 +141,20 @@ def set_baseline(self, pool: FixedWorkerPool) -> None: if self._baselines is not None: raise RuntimeError('The baseline has already been set.') self._load_samples() - results = self._launch_compilation_workers(pool) - self._baselines = super().get_rewards(results) + results_futures = self._launch_compilation_workers(pool) + self._baselines = _extract_results(results_futures) def get_rewards( - self, results: list[concurrent.futures.Future]) -> list[float | None]: + self, + results_futures: list[concurrent.futures.Future]) -> list[float | None]: if self._baselines is None: raise RuntimeError('The baseline has not been set.') - if len(results) != len(self._baselines): + if len(results_futures) != len(self._baselines): raise RuntimeError( 'The number of results does not match the number of baselines.') - policy_results = super().get_rewards(results) + policy_results = _extract_results(results_futures) rewards = [] for policy_result, baseline in zip(