24
24
from tf_agents .system import system_multiprocessing as multiprocessing
25
25
26
26
27
- class LocalWorkerManagerTest (absltest .TestCase ):
27
+ class JobNormal (Worker ):
28
+ """Test worker."""
28
29
29
- def test_pool (self ):
30
+ def __init__ (self ):
31
+ self ._token = 0
32
+
33
+ @classmethod
34
+ def is_priority_method (cls , method_name : str ) -> bool :
35
+ return method_name == 'priority_method'
36
+
37
+ def priority_method (self ):
38
+ return f'priority { self ._token } '
39
+
40
+ def get_token (self ):
41
+ return self ._token
30
42
31
- class Job ( Worker ):
32
- """Test worker."""
43
+ def set_token ( self , value ):
44
+ self . _token = value
33
45
34
- def __init__ (self ):
35
- self ._token = 0
36
46
37
- @classmethod
38
- def is_priority_method (cls , method_name : str ) -> bool :
39
- return method_name == 'priority_method'
47
+ class JobFail (Worker ):
40
48
41
- def priority_method (self ):
42
- return f'priority { self ._token } '
49
+ def __init__ (self , wont_be_passed ):
50
+ self ._arg = wont_be_passed
43
51
44
- def get_token (self ):
45
- return self ._token
52
+ def method (self ):
53
+ return self ._arg
46
54
47
- def set_token (self , value ):
48
- self ._token = value
49
55
50
- with local_worker_manager .LocalWorkerPool (Job , 2 ) as pool :
56
+ class JobSlow (Worker ):
57
+
58
+ def method (self ):
59
+ time .sleep (3600 )
60
+
61
+
62
+ class LocalWorkerManagerTest (absltest .TestCase ):
63
+
64
+ def test_pool (self ):
65
+
66
+ with local_worker_manager .LocalWorkerPool (JobNormal , 2 ) as pool :
51
67
p1 = pool [0 ]
52
68
p2 = pool [1 ]
53
69
set_futures = [p1 .set_token (1 ), p2 .set_token (2 )]
@@ -66,28 +82,15 @@ def set_token(self, value):
66
82
67
83
def test_failure (self ):
68
84
69
- class Job (Worker ):
70
-
71
- def __init__ (self , wont_be_passed ):
72
- self ._arg = wont_be_passed
73
-
74
- def method (self ):
75
- return self ._arg
76
-
77
- with local_worker_manager .LocalWorkerPool (Job , 2 ) as pool :
85
+ with local_worker_manager .LocalWorkerPool (JobFail , 2 ) as pool :
78
86
with self .assertRaises (concurrent .futures .CancelledError ):
79
87
# this will fail because we didn't pass the arg to the ctor, so the
80
88
# worker hosting process will crash.
81
89
pool [0 ].method ().result ()
82
90
83
91
def test_worker_crash_while_waiting (self ):
84
92
85
- class Job (Worker ):
86
-
87
- def method (self ):
88
- time .sleep (3600 )
89
-
90
- with local_worker_manager .LocalWorkerPool (Job , 2 ) as pool :
93
+ with local_worker_manager .LocalWorkerPool (JobSlow , 2 ) as pool :
91
94
p = pool [0 ]
92
95
f = p .method ()
93
96
self .assertFalse (f .done ())
0 commit comments