21
21
import random
22
22
import re
23
23
import subprocess
24
- from typing import Dict , List , Optional , Tuple , Type # pylint:disable=unused-import
24
+ from typing import Dict , List , Optional , Union , Tuple # pylint:disable=unused-import
25
25
26
26
from absl import app
27
27
from absl import flags
66
66
'gin_bindings' , [],
67
67
'Gin bindings to override the values set in the config files.' )
68
68
69
- ResultsQueueEntry = Optional [Tuple [str , List [str ],
70
- Dict [str , compilation_runner .RewardStat ]]]
69
+ ResultsQueueEntry = Union [Optional [Tuple [str , List [str ],
70
+ Dict [str ,
71
+ compilation_runner .RewardStat ]]],
72
+ BaseException ]
71
73
72
74
73
75
def get_runner () -> compilation_runner .CompilationRunner :
@@ -81,7 +83,7 @@ def get_runner() -> compilation_runner.CompilationRunner:
81
83
82
84
83
85
def worker (policy_path : str , work_queue : 'queue.Queue[corpus.ModuleSpec]' ,
84
- results_queue : 'queue.Queue[Optional[List[str]] ]' ,
86
+ results_queue : 'queue.Queue[ResultsQueueEntry ]' ,
85
87
key_filter : Optional [str ]):
86
88
"""Describes the job each paralleled worker process does.
87
89
@@ -97,35 +99,41 @@ def worker(policy_path: str, work_queue: 'queue.Queue[corpus.ModuleSpec]',
97
99
results_queue: the queue where results are deposited.
98
100
key_filter: regex filter for key names to include, or None to include all.
99
101
"""
100
- runner = get_runner ()
101
- m = re .compile (key_filter ) if key_filter else None
102
-
103
- while True :
104
- try :
105
- module_spec = work_queue .get_nowait ()
106
- except queue .Empty :
107
- return
108
- try :
109
- data = runner .collect_data (
110
- module_spec = module_spec , tf_policy_path = policy_path , reward_stat = None )
111
- if not m :
112
- results_queue .put ((module_spec .name , data .serialized_sequence_examples ,
113
- data .reward_stats ))
114
- continue
115
- new_reward_stats = {}
116
- new_sequence_examples = []
117
- for k , sequence_example in zip (data .keys ,
118
- data .serialized_sequence_examples ):
119
- if not m .match (k ):
102
+ try :
103
+ runner = get_runner ()
104
+ m = re .compile (key_filter ) if key_filter else None
105
+
106
+ while True :
107
+ try :
108
+ module_spec = work_queue .get_nowait ()
109
+ except queue .Empty :
110
+ return
111
+ try :
112
+ data = runner .collect_data (
113
+ module_spec = module_spec ,
114
+ tf_policy_path = policy_path ,
115
+ reward_stat = None )
116
+ if not m :
117
+ results_queue .put (
118
+ (module_spec .name , data .serialized_sequence_examples ,
119
+ data .reward_stats ))
120
120
continue
121
- new_reward_stats [k ] = data .reward_stats [k ]
122
- new_sequence_examples .append (sequence_example )
123
- results_queue .put (
124
- (module_spec .name , new_sequence_examples , new_reward_stats ))
125
- except (subprocess .CalledProcessError , subprocess .TimeoutExpired ,
126
- RuntimeError ):
127
- logging .error ('Failed to compile %s.' , module_spec .name )
128
- results_queue .put (None )
121
+ new_reward_stats = {}
122
+ new_sequence_examples = []
123
+ for k , sequence_example in zip (data .keys ,
124
+ data .serialized_sequence_examples ):
125
+ if not m .match (k ):
126
+ continue
127
+ new_reward_stats [k ] = data .reward_stats [k ]
128
+ new_sequence_examples .append (sequence_example )
129
+ results_queue .put (
130
+ (module_spec .name , new_sequence_examples , new_reward_stats ))
131
+ except (subprocess .CalledProcessError , subprocess .TimeoutExpired ,
132
+ RuntimeError ):
133
+ logging .error ('Failed to compile %s.' , module_spec .name )
134
+ results_queue .put (None )
135
+ except BaseException as e : # pylint: disable=broad-except
136
+ results_queue .put (e )
129
137
130
138
131
139
def main (_ ):
@@ -206,6 +214,8 @@ def main(_):
206
214
total_failed_examples , total_work )
207
215
208
216
results = results_queue .get ()
217
+ if isinstance (results , BaseException ):
218
+ logging .fatal (results )
209
219
if not results :
210
220
total_failed_examples += 1
211
221
continue
0 commit comments