Skip to content

Commit 55cfce9

Browse files
authored
Promote locals in tests to module scope (#90)
Local-scoped classes won't be pickle-able following an internal change.
1 parent 4566f1b commit 55cfce9

File tree

2 files changed

+40
-36
lines changed

2 files changed

+40
-36
lines changed

compiler_opt/distributed/local/local_worker_manager_test.py

Lines changed: 34 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -24,30 +24,46 @@
2424
from tf_agents.system import system_multiprocessing as multiprocessing
2525

2626

27-
class LocalWorkerManagerTest(absltest.TestCase):
27+
class JobNormal(Worker):
28+
"""Test worker."""
2829

29-
def test_pool(self):
30+
def __init__(self):
31+
self._token = 0
32+
33+
@classmethod
34+
def is_priority_method(cls, method_name: str) -> bool:
35+
return method_name == 'priority_method'
36+
37+
def priority_method(self):
38+
return f'priority {self._token}'
39+
40+
def get_token(self):
41+
return self._token
3042

31-
class Job(Worker):
32-
"""Test worker."""
43+
def set_token(self, value):
44+
self._token = value
3345

34-
def __init__(self):
35-
self._token = 0
3646

37-
@classmethod
38-
def is_priority_method(cls, method_name: str) -> bool:
39-
return method_name == 'priority_method'
47+
class JobFail(Worker):
4048

41-
def priority_method(self):
42-
return f'priority {self._token}'
49+
def __init__(self, wont_be_passed):
50+
self._arg = wont_be_passed
4351

44-
def get_token(self):
45-
return self._token
52+
def method(self):
53+
return self._arg
4654

47-
def set_token(self, value):
48-
self._token = value
4955

50-
with local_worker_manager.LocalWorkerPool(Job, 2) as pool:
56+
class JobSlow(Worker):
57+
58+
def method(self):
59+
time.sleep(3600)
60+
61+
62+
class LocalWorkerManagerTest(absltest.TestCase):
63+
64+
def test_pool(self):
65+
66+
with local_worker_manager.LocalWorkerPool(JobNormal, 2) as pool:
5167
p1 = pool[0]
5268
p2 = pool[1]
5369
set_futures = [p1.set_token(1), p2.set_token(2)]
@@ -66,28 +82,15 @@ def set_token(self, value):
6682

6783
def test_failure(self):
6884

69-
class Job(Worker):
70-
71-
def __init__(self, wont_be_passed):
72-
self._arg = wont_be_passed
73-
74-
def method(self):
75-
return self._arg
76-
77-
with local_worker_manager.LocalWorkerPool(Job, 2) as pool:
85+
with local_worker_manager.LocalWorkerPool(JobFail, 2) as pool:
7886
with self.assertRaises(concurrent.futures.CancelledError):
7987
# this will fail because we didn't pass the arg to the ctor, so the
8088
# worker hosting process will crash.
8189
pool[0].method().result()
8290

8391
def test_worker_crash_while_waiting(self):
8492

85-
class Job(Worker):
86-
87-
def method(self):
88-
time.sleep(3600)
89-
90-
with local_worker_manager.LocalWorkerPool(Job, 2) as pool:
93+
with local_worker_manager.LocalWorkerPool(JobSlow, 2) as pool:
9194
p = pool[0]
9295
f = p.method()
9396
self.assertFalse(f.done())

compiler_opt/rl/local_data_collector_test.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -89,14 +89,15 @@ def collect_data(self, module_spec, tf_policy_path, reward_stat):
8989
sequence_examples=[], reward_stats={}, rewards=[], keys=[])
9090

9191

92-
class LocalDataCollectorTest(tf.test.TestCase):
92+
class MyRunner(compilation_runner.CompilationRunner):
9393

94-
def test_local_data_collector(self):
94+
def collect_data(self, *args, **kwargs):
95+
return mock_collect_data(*args, **kwargs)
9596

96-
class MyRunner(compilation_runner.CompilationRunner):
9797

98-
def collect_data(self, *args, **kwargs):
99-
return mock_collect_data(*args, **kwargs)
98+
class LocalDataCollectorTest(tf.test.TestCase):
99+
100+
def test_local_data_collector(self):
100101

101102
def create_test_iterator_fn():
102103

0 commit comments

Comments
 (0)