|
15 | 15 | """Test for local worker manager."""
|
16 | 16 |
|
17 | 17 | import concurrent.futures
|
18 |
| -import multiprocessing |
19 | 18 | import time
|
20 | 19 |
|
21 | 20 | from absl.testing import absltest
|
@@ -59,6 +58,25 @@ def method(self):
|
59 | 58 | time.sleep(3600)
|
60 | 59 |
|
61 | 60 |
|
| 61 | +class JobCounter(Worker): |
| 62 | + """Test worker.""" |
| 63 | + |
| 64 | + def __init__(self): |
| 65 | + self.times = [] |
| 66 | + |
| 67 | + @classmethod |
| 68 | + def is_priority_method(cls, method_name: str) -> bool: |
| 69 | + return method_name == 'get_times' |
| 70 | + |
| 71 | + def start(self): |
| 72 | + while True: |
| 73 | + self.times.append(time.time()) |
| 74 | + time.sleep(0.05) |
| 75 | + |
| 76 | + def get_times(self): |
| 77 | + return self.times |
| 78 | + |
| 79 | + |
62 | 80 | class LocalWorkerManagerTest(absltest.TestCase):
|
63 | 81 |
|
64 | 82 | def test_pool(self):
|
@@ -100,6 +118,34 @@ def test_worker_crash_while_waiting(self):
|
100 | 118 | with self.assertRaises(concurrent.futures.CancelledError):
|
101 | 119 | _ = f.result()
|
102 | 120 |
|
| 121 | + def test_pause_resume(self): |
| 122 | + |
| 123 | + with local_worker_manager.LocalWorkerPool(JobCounter, 1) as pool: |
| 124 | + p = pool[0] |
| 125 | + |
| 126 | + # Fill the q for 1 second |
| 127 | + p.start() |
| 128 | + time.sleep(1) |
| 129 | + |
| 130 | + # Then pause the process for 1 second |
| 131 | + p.pause() |
| 132 | + time.sleep(1) |
| 133 | + |
| 134 | + # Then resume the process and wait 1 more second |
| 135 | + p.resume() |
| 136 | + time.sleep(1) |
| 137 | + |
| 138 | + times = p.get_times().result() |
| 139 | + |
| 140 | + # If pause/resume worked, there should be a gap of at least 0.5 seconds. |
| 141 | + # Otherwise, this will throw an exception. |
| 142 | + self.assertNotEmpty(times) |
| 143 | + last_time = times[0] |
| 144 | + for cur_time in times: |
| 145 | + if cur_time - last_time > 0.5: |
| 146 | + return |
| 147 | + raise ValueError('Failed to find a 2 second gap in times.') |
| 148 | + |
103 | 149 |
|
104 | 150 | if __name__ == '__main__':
|
105 | 151 | multiprocessing.handle_test_main(absltest.main)
|
0 commit comments