Skip to content

Commit 24865b4

Browse files
Revert previous two commits
Revert "Revert "Accept decay rate parameter in MockCompilationRunner"" This reverts commit 94cb0d6. HEAD~1 != HEAD... Revert "Remove flag allow_override calls" This reverts commit 7091b15. See previous commit message.
1 parent 94cb0d6 commit 24865b4

File tree

3 files changed

+14
-1
lines changed

3 files changed

+14
-1
lines changed

compiler_opt/rl/imitation_learning/generate_bc_trajectories_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@
3333
from compiler_opt.rl import env
3434
from compiler_opt.rl import env_test
3535

36+
# flags.FLAGS['gin_files'].allow_override = True
37+
# flags.FLAGS['gin_bindings'].allow_override = True
38+
3639
_eps = 1e-5
3740

3841

compiler_opt/tools/generate_default_trace_test.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import os
1717
from unittest import mock
1818

19+
from absl import flags
1920
from absl.testing import absltest
2021
from absl.testing import flagsaver
2122
import gin
@@ -28,12 +29,17 @@
2829

2930
from tf_agents.system import system_multiprocessing as multiprocessing
3031

32+
flags.FLAGS['num_workers'].allow_override = True
33+
flags.FLAGS['gin_files'].allow_override = True
34+
flags.FLAGS['gin_bindings'].allow_override = True
35+
3136

3237
@gin.configurable(module='runners')
3338
class MockCompilationRunner(compilation_runner.CompilationRunner):
3439
"""A compilation runner just for test."""
3540

36-
def __init__(self, sentinel=None):
41+
def __init__(self, moving_average_decay_rate: float, sentinel=None):
42+
del moving_average_decay_rate # Unused.
3743
assert sentinel == 42
3844
super().__init__()
3945

compiler_opt/tools/generate_test_model_test.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,17 @@
1616
An integration test for model saving, to detect TFLite model conversion.
1717
"""
1818

19+
from absl import flags
1920
from absl.testing import absltest
2021
from absl.testing import flagsaver
2122
from absl.testing import parameterized
2223
import gin
2324

2425
from compiler_opt.tools import generate_test_model
2526

27+
flags.FLAGS['gin_files'].allow_override = True
28+
flags.FLAGS['gin_bindings'].allow_override = True
29+
2630

2731
def _get_test_settings():
2832
test_setting = []

0 commit comments

Comments
 (0)