File tree Expand file tree Collapse file tree 2 files changed +12
-3
lines changed Expand file tree Collapse file tree 2 files changed +12
-3
lines changed Original file line number Diff line number Diff line change 43
43
'Put temporary files into given directory and keep them past exit.' )
44
44
45
45
46
- def _calculate_reward (policy : float , baseline : float ) -> float :
46
+ def calculate_reward (policy : float , baseline : float ) -> float :
47
47
# This assumption allows us to imply baseline + constant.DELTA > 0.
48
48
assert baseline >= 0
49
49
return 1 - (policy + constant .DELTA ) / (baseline + constant .DELTA )
@@ -465,14 +465,14 @@ def collect_data(self,
465
465
moving_average_reward = reward_stat [k ].moving_average_reward
466
466
sequence_example = _overwrite_trajectory_reward (
467
467
sequence_example = sequence_example ,
468
- reward = _calculate_reward (
468
+ reward = calculate_reward (
469
469
policy = policy_reward , baseline = moving_average_reward ))
470
470
sequence_example_list .append (sequence_example )
471
471
reward_stat [k ].moving_average_reward = (
472
472
moving_average_reward * self ._moving_average_decay_rate +
473
473
policy_reward * (1 - self ._moving_average_decay_rate ))
474
474
rewards .append (
475
- _calculate_reward (policy = policy_reward , baseline = default_reward ))
475
+ calculate_reward (policy = policy_reward , baseline = default_reward ))
476
476
policy_rewards .append (policy_reward )
477
477
keys .append (k )
478
478
Original file line number Diff line number Diff line change 13
13
# limitations under the License.
14
14
"""Tests for compiler_opt.rl.compilation_runner."""
15
15
16
+ import math
16
17
import os
17
18
import string
18
19
import subprocess
@@ -254,6 +255,14 @@ def stop_and_start():
254
255
# should be at least 1 second due to the pause.
255
256
self .assertGreater (time .time () - start_time , 1 )
256
257
258
+ def test_calculate_reward_zero_delta (self ):
259
+ reward = compilation_runner .calculate_reward (3 , 0 )
260
+ self .assertTrue (math .isfinite (reward ))
261
+
262
+ def test_calculate_reward (self ):
263
+ reward = compilation_runner .calculate_reward (1 , 2 )
264
+ self .assertAlmostEqual (reward , 0.5 , 2 )
265
+
257
266
258
267
if __name__ == '__main__' :
259
268
tf .test .main ()
You can’t perform that action at this time.
0 commit comments