Skip to content

Commit 4c590d6

Browse files
Add option to save best model in ES (#476)
This patch adds a config option, save_best_policy, to BlackboxLearnerConfig that enables saving the model whenever a new best reward is seen. It also adds support in BlackboxLearner for actually saving that model to the save directory.
1 parent d9cdc38 commit 4c590d6

File tree

2 files changed

+18
-1
lines changed

2 files changed

+18
-1
lines changed

compiler_opt/es/blackbox_learner.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,9 @@ class BlackboxLearnerConfig:
7676
# Learning rate
7777
step_size: float
7878

79+
# Whether or not to save a policy if it has the greatest reward seen so far.
80+
save_best_policy: bool
81+
7982

8083
def _prune_skipped_perturbations(perturbations: list[npt.NDArray[np.float32]],
8184
rewards: list[float | None]):
@@ -152,6 +155,7 @@ def __init__(self,
152155
self._step = initial_step
153156
self._deadline = deadline
154157
self._seed = seed
158+
self._global_max_reward = 0.0
155159

156160
self._summary_writer = tf.summary.create_file_writer(output_dir)
157161

@@ -270,6 +274,18 @@ def run_step(self, pool: FixedWorkerPool) -> None:
270274
self._log_rewards(rewards)
271275
self._log_tf_summary(rewards)
272276

277+
if self._config.save_best_policy and np.max(
278+
rewards) > self._global_max_reward:
279+
self._global_max = np.max(rewards)
280+
logging.info('Found new best model with reward %f at step '
281+
'%d, saving.', self._global_max, self._step)
282+
max_index = np.argmax(rewards)
283+
perturbation = initial_perturbations[max_index]
284+
self._policy_saver_fn(
285+
parameters=self._model_weights + perturbation,
286+
policy_name=f'best_policy_{self._global_max}_step_{self._step}',
287+
)
288+
273289
self._save_model()
274290

275291
self._step += 1

compiler_opt/es/blackbox_learner_test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,8 @@ def setUp(self):
6565
evaluator=blackbox_evaluator.SamplingBlackboxEvaluator,
6666
total_num_perturbations=3,
6767
precision_parameter=1,
68-
step_size=1.0)
68+
step_size=1.0,
69+
save_best_policy=False)
6970

7071
self._cps = corpus.create_corpus_for_testing(
7172
location=tempfile.gettempdir(),

0 commit comments

Comments
 (0)