Skip to content

Commit 3004b74

Browse files
authored
test for generate_default_trace (#50)
Also: - Handle flags appearing as duplicate to pytest The flags appear duplicate because pytest loads all the test modules, and for tests for tools like generate_default_trace or extract_ir, which define flags, some "popular" flag names may be thus duplicate. - dropping tf_agents system_multiprocessing because it doesn't work well with pytest.
1 parent 89634da commit 3004b74

File tree

4 files changed

+123
-17
lines changed

4 files changed

+123
-17
lines changed

compiler_opt/tools/extract_ir_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,13 @@
1616

1717
# pylint: disable=protected-access
1818

19+
from absl import flags
1920
from absl.testing import absltest
2021

2122
from compiler_opt.tools import extract_ir
2223

24+
flags.FLAGS['num_workers'].allow_override = True
25+
2326

2427
class ExtractIrTest(absltest.TestCase):
2528

compiler_opt/tools/generate_default_trace.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@
2727
from absl import flags
2828
from absl import logging
2929
import gin
30+
import multiprocessing
3031
import tensorflow as tf
31-
from tf_agents.system import system_multiprocessing as multiprocessing
3232

3333
from compiler_opt.rl import compilation_runner
3434
from compiler_opt.rl import problem_configuration
@@ -69,8 +69,17 @@
6969
Dict[str, compilation_runner.RewardStat]]]
7070

7171

72-
def worker(runner_cls: Type[compilation_runner.CompilationRunner],
73-
policy_path: str, work_queue: 'queue.Queue[Tuple[str, ...]]',
72+
def get_runner() -> compilation_runner.CompilationRunner:
73+
problem_config = registry.get_configuration()
74+
return problem_config.get_runner_type()(
75+
moving_average_decay_rate=0,
76+
additional_flags=(),
77+
delete_flags=('-split-dwarf-file', '-split-dwarf-output',
78+
'-fthinlto-index', '-fprofile-sample-use',
79+
'-fprofile-remapping-file'))
80+
81+
82+
def worker(policy_path: str, work_queue: 'queue.Queue[Tuple[str, ...]]',
7483
results_queue: 'queue.Queue[Optional[List[str]]]',
7584
key_filter: Optional[str]):
7685
"""Describes the job each paralleled worker process does.
@@ -87,12 +96,7 @@ def worker(runner_cls: Type[compilation_runner.CompilationRunner],
8796
results_queue: the queue where results are deposited.
8897
key_filter: regex filter for key names to include, or None to include all.
8998
"""
90-
runner = runner_cls(
91-
moving_average_decay_rate=0,
92-
additional_flags=(),
93-
delete_flags=('-split-dwarf-file', '-split-dwarf-output',
94-
'-fthinlto-index', '-fprofile-sample-use',
95-
'-fprofile-remapping-file'))
99+
runner = get_runner()
96100
m = re.compile(key_filter) if key_filter else None
97101

98102
while True:
@@ -131,10 +135,6 @@ def main(_):
131135
_GIN_FILES.value, bindings=_GIN_BINDINGS.value, skip_unknown=False)
132136
logging.info(gin.config_str())
133137

134-
problem_config = registry.get_configuration()
135-
runner = problem_config.get_runner_type()
136-
assert runner
137-
138138
with open(
139139
os.path.join(_DATA_PATH.value, 'module_paths'), 'r',
140140
encoding='utf-8') as f:
@@ -185,9 +185,9 @@ def main(_):
185185
# pylint:disable=g-complex-comprehension
186186
processes = [
187187
ctx.Process(
188-
target=functools.partial(
189-
worker, runner, _POLICY_PATH.value, work_queue, results_queue,
190-
_KEY_FILTER.value)) for _ in range(0, worker_count)
188+
target=functools.partial(worker, _POLICY_PATH.value, work_queue,
189+
results_queue, _KEY_FILTER.value))
190+
for _ in range(0, worker_count)
191191
]
192192
# pylint:enable=g-complex-comprehension
193193

@@ -230,4 +230,4 @@ def main(_):
230230

231231
if __name__ == '__main__':
232232
flags.mark_flag_as_required('data_path')
233-
multiprocessing.handle_main(functools.partial(app.run, main))
233+
app.run(main)
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
# coding=utf-8
2+
# Copyright 2020 Google LLC
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
"""Tests for generate_default_trace."""
16+
17+
import os
18+
from unittest import mock
19+
20+
from absl import flags
21+
from absl.testing import absltest
22+
from absl.testing import flagsaver
23+
import gin
24+
import tensorflow as tf
25+
26+
# This is https://github.com/google/pytype/issues/764
27+
from google.protobuf import text_format # pytype: disable=pyi-error
28+
from compiler_opt.rl import compilation_runner
29+
from compiler_opt.tools import generate_default_trace
30+
31+
flags.FLAGS['num_workers'].allow_override = True
32+
33+
34+
class MockCompilationRunner(compilation_runner.CompilationRunner):
35+
"""A compilation runner just for test."""
36+
37+
def collect_data(self, file_paths, tf_policy_path, reward_stat):
38+
sequence_example_text = """
39+
feature_lists {
40+
feature_list {
41+
key: "feature_0"
42+
value {
43+
feature { int64_list { value: 1 } }
44+
feature { int64_list { value: 1 } }
45+
}
46+
}
47+
}"""
48+
sequence_example = text_format.Parse(sequence_example_text,
49+
tf.train.SequenceExample())
50+
51+
return compilation_runner.CompilationResult(
52+
sequence_examples=[sequence_example],
53+
reward_stats={
54+
'default':
55+
compilation_runner.RewardStat(
56+
default_reward=1, moving_average_reward=2)
57+
},
58+
rewards=[1.2],
59+
keys=['default'])
60+
61+
62+
class GenerateDefaultTraceTest(absltest.TestCase):
63+
64+
@mock.patch('compiler_opt.tools.generate_default_trace.get_runner')
65+
def test_api(self, mock_get_runner):
66+
tmp_dir = self.create_tempdir()
67+
module_names = ['a', 'b', 'c', 'd']
68+
69+
with tf.io.gfile.GFile(
70+
os.path.join(tmp_dir.full_path, 'module_paths'), 'w') as f:
71+
f.write('\n'.join(module_names))
72+
73+
for module_name in module_names:
74+
with tf.io.gfile.GFile(
75+
os.path.join(tmp_dir.full_path, module_name + '.bc'), 'w') as f:
76+
f.write(module_name)
77+
78+
mock_compilation_runner = MockCompilationRunner()
79+
mock_get_runner.return_value = mock_compilation_runner
80+
81+
with flagsaver.flagsaver(
82+
data_path=tmp_dir.full_path,
83+
num_workers=2,
84+
output_path=os.path.join(tmp_dir.full_path, 'output'),
85+
output_performance_path=os.path.join(tmp_dir.full_path,
86+
'output_performance'),
87+
):
88+
generate_default_trace.main(None)
89+
90+
def test_get_runner(self):
91+
with gin.unlock_config():
92+
gin.parse_config_files_and_bindings(
93+
config_files=['compiler_opt/rl/inlining/gin_configs/common.gin'],
94+
bindings=None)
95+
runner = generate_default_trace.get_runner()
96+
self.assertIsInstance(runner, compilation_runner.CompilationRunner)
97+
98+
99+
if __name__ == '__main__':
100+
absltest.main()

pytest.ini

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,6 @@ filterwarnings =
1414

1515
# Issue #37
1616
ignore:Encoding a StructuredValue with type tf_agents.policies.greedy_policy.DeterministicWithLogProb_ACTTypeSpec:UserWarning
17+
18+
# Not much to do about this, it's caused by gin
19+
ignore:Using or importing the ABCs from 'collections':DeprecationWarning

0 commit comments

Comments
 (0)