Skip to content

Commit 8e06640

Browse files
Make calculate_reward publicly exposed in compilation_runner
This patch renames _calculate_reward to calculate_reward, changing it from a module-private helper into a public function of compilation_runner, and updates its call sites accordingly. This is necessary because we want to reuse it in TraceBlackboxEvaluator.
1 parent 711e230 commit 8e06640

File tree

2 files changed

+12
-3
lines changed

2 files changed

+12
-3
lines changed

compiler_opt/rl/compilation_runner.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
'Put temporary files into given directory and keep them past exit.')
4444

4545

46-
def _calculate_reward(policy: float, baseline: float) -> float:
46+
def calculate_reward(policy: float, baseline: float) -> float:
4747
# This assumption allows us to imply baseline + constant.DELTA > 0.
4848
assert baseline >= 0
4949
return 1 - (policy + constant.DELTA) / (baseline + constant.DELTA)
@@ -465,14 +465,14 @@ def collect_data(self,
465465
moving_average_reward = reward_stat[k].moving_average_reward
466466
sequence_example = _overwrite_trajectory_reward(
467467
sequence_example=sequence_example,
468-
reward=_calculate_reward(
468+
reward=calculate_reward(
469469
policy=policy_reward, baseline=moving_average_reward))
470470
sequence_example_list.append(sequence_example)
471471
reward_stat[k].moving_average_reward = (
472472
moving_average_reward * self._moving_average_decay_rate +
473473
policy_reward * (1 - self._moving_average_decay_rate))
474474
rewards.append(
475-
_calculate_reward(policy=policy_reward, baseline=default_reward))
475+
calculate_reward(policy=policy_reward, baseline=default_reward))
476476
policy_rewards.append(policy_reward)
477477
keys.append(k)
478478

compiler_opt/rl/compilation_runner_test.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
"""Tests for compiler_opt.rl.compilation_runner."""
1515

16+
import math
1617
import os
1718
import string
1819
import subprocess
@@ -254,6 +255,14 @@ def stop_and_start():
254255
# should be at least 1 second due to the pause.
255256
self.assertGreater(time.time() - start_time, 1)
256257

258+
def test_calculate_reward_zero_delta(self):
259+
reward = compilation_runner.calculate_reward(3, 0)
260+
self.assertTrue(math.isfinite(reward))
261+
262+
def test_calculate_reward(self):
263+
reward = compilation_runner.calculate_reward(1, 2)
264+
self.assertAlmostEqual(reward, 0.5, 2)
265+
257266

258267
if __name__ == '__main__':
259268
tf.test.main()

0 commit comments

Comments
 (0)