
Commit c9ebbf2

Compute relative reward in trace blackbox evaluator
This patch makes the trace blackbox evaluator return a relative reward rather than the raw reward. This makes the rewards actually meaningful, and it also prevents overflowing calculations in blackbox_optimizer, which previously made the new model weights NaN.

Reviewers: mtrofin
Reviewed By: mtrofin
Pull Request: #463
1 parent 8e06640 commit c9ebbf2
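
For context, a minimal sketch of what "relative reward" means here, reverse-engineered from the values asserted in the new test below; the canonical formula lives in compilation_runner.calculate_reward, and the _DELTA constant and exact shape are assumptions:

# Hypothetical sketch of the relative reward; the real implementation is
# compilation_runner.calculate_reward. _DELTA and the formula are assumptions
# chosen to be consistent with the values asserted in the new test.
_DELTA = 0.01  # guards against division by zero for a zero baseline


def relative_reward(policy_result: float, baseline: float) -> float:
  """Fractional improvement over the baseline: >0 is better, <0 is worse."""
  return (baseline - policy_result) / (baseline + _DELTA)


print(relative_reward(2, 2))  # 0.0
print(round(relative_reward(3, 2), 2))  # -0.5

Because the result is a fraction of the baseline rather than a raw measurement, rewards stay in a small, bounded range regardless of the absolute scale of the underlying trace measurements.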

File tree

4 files changed: +41 -5 lines


compiler_opt/es/blackbox_evaluator.py

Lines changed: 14 additions & 0 deletions
@@ -23,6 +23,7 @@
 from compiler_opt.rl import corpus
 from compiler_opt.es import blackbox_optimizers
 from compiler_opt.distributed import buffered_scheduler
+from compiler_opt.rl import compilation_runner
 
 
 class BlackboxEvaluator(metaclass=abc.ABCMeta):
@@ -159,3 +160,16 @@ def set_baseline(self, pool: FixedWorkerPool) -> None:
                        f' got {len(futures)}')
 
     self._baseline = futures[0].result()
+
+  def get_rewards(
+      self, results: list[concurrent.futures.Future]) -> list[float | None]:
+    rewards = []
+
+    for result in results:
+      if result.exception() is not None:
+        raise result.exception()
+
+      rewards.append(
+          compilation_runner.calculate_reward(result.result(), self._baseline))
+
+    return rewards
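
For illustration, a hypothetical sketch of the contract get_rewards now implements (the evaluator and worker plumbing are stand-ins, not the real API):

# Completed futures hold raw measurements from workers; a failed future's
# exception is re-raised; successes become rewards relative to the baseline
# stored by set_baseline.
import concurrent.futures

ok = concurrent.futures.Future()
ok.set_result(9.5)  # raw measurement reported by a worker

failed = concurrent.futures.Future()
failed.set_exception(RuntimeError('compile failed'))

# evaluator.get_rewards([ok])         -> [reward relative to the baseline]
# evaluator.get_rewards([ok, failed]) -> raises RuntimeError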

compiler_opt/es/blackbox_evaluator_test.py

Lines changed: 22 additions & 0 deletions
@@ -88,3 +88,25 @@ def test_trace_set_baseline(self):
     evaluator.set_baseline(pool)
     # pylint: disable=protected-access
     self.assertAlmostEqual(evaluator._baseline, 10)
+
+  def test_trace_get_rewards(self):
+    f1 = concurrent.futures.Future()
+    f1.set_result(2)
+    f2 = concurrent.futures.Future()
+    f2.set_result(3)
+    results = [f1, f2]
+    test_corpus = corpus.create_corpus_for_testing(
+        location=self.create_tempdir().full_path,
+        elements=[corpus.ModuleSpec(name='name1', size=1)])
+    evaluator = blackbox_evaluator.TraceBlackboxEvaluator(
+        test_corpus, blackbox_optimizers.EstimatorType.FORWARD_FD,
+        'fake_bb_trace_path', 'fake_function_index_path')
+
+    # pylint: disable=protected-access
+    evaluator._baseline = 2
+    rewards = evaluator.get_rewards(results)
+
+    # Only check for two decimal places as the reward calculation uses a
+    # reasonably large delta (0.01) when calculating the difference to
+    # prevent division by zero.
+    self.assertSequenceAlmostEqual(rewards, [0, -0.5], 2)

compiler_opt/es/blackbox_learner.py

Lines changed: 4 additions & 5 deletions
@@ -230,6 +230,9 @@ def _save_model(self) -> None:
   def get_model_weights(self) -> npt.NDArray[np.float32]:
     return self._model_weights
 
+  def set_baseline(self, pool: FixedWorkerPool) -> None:
+    self._evaluator.set_baseline(pool)
+
   def run_step(self, pool: FixedWorkerPool) -> None:
     """Run a single step of blackbox learning.
     This does not instantaneously return due to several I/O
@@ -245,12 +248,8 @@ def run_step(self, pool: FixedWorkerPool) -> None:
         p for p in initial_perturbations for p in (p, -p)
     ]
 
-    # TODO(boomanaiden154): This should be adding the perturbation to
-    # the existing model weights. That currently results in the model
-    # weights all being NaN, presumably due to rewards not being scaled for
-    # the regalloc_trace problem.
     perturbations_as_bytes = [
-        perturbation.astype(np.float32).tobytes()
+        (self._model_weights + perturbation).astype(np.float32).tobytes()
         for perturbation in initial_perturbations
     ]
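
A standalone numpy sketch of the changed serialization step, with illustrative shapes and values that are not from the repo:

# Workers now receive the perturbed weights (weights + p) rather than the
# bare perturbation p, which is what the removed TODO asked for and which
# the relative-reward scaling makes numerically safe.
import numpy as np

model_weights = np.zeros(4, dtype=np.float32)
rng = np.random.default_rng(0)
initial_perturbations = [rng.standard_normal(4) for _ in range(2)]

# Antithetic pairing (p, -p), as in run_step.
initial_perturbations = [p for p in initial_perturbations for p in (p, -p)]

perturbations_as_bytes = [
    (model_weights + perturbation).astype(np.float32).tobytes()
    for perturbation in initial_perturbations
]
assert len(perturbations_as_bytes) == 4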

compiler_opt/es/es_trainer_lib.py

Lines changed: 1 addition & 0 deletions
@@ -215,6 +215,7 @@ def train(additional_compilation_flags=(),
         worker_class,
         count=learner_config.total_num_perturbations,
         worker_kwargs=dict(gin_config=gin.operative_config_str())) as pool:
+    learner.set_baseline(pool)
     for _ in range(learner_config.total_steps):
       learner.run_step(pool)
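With this change, the baseline is measured once per training run, before the first perturbation step, so the rewards of every subsequent step are computed relative to the same reference measurement.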
