# limitations under the License.
"""Generate initial training data from the behavior of the current heuristic."""

+import concurrent.futures
import contextlib
import functools
-import os
-import queue
import re
-import subprocess
from typing import Dict, List, Optional, Union, Tuple  # pylint:disable=unused-import

from absl import app
from absl import flags
from absl import logging
import gin
-import multiprocessing
+
import tensorflow as tf

+from compiler_opt.distributed import worker
+from compiler_opt.distributed import buffered_scheduler
+from compiler_opt.distributed.local import local_worker_manager
+
from compiler_opt.rl import compilation_runner
from compiler_opt.rl import corpus
from compiler_opt.rl import policy_saver
from compiler_opt.rl import registry

+from tf_agents.system import system_multiprocessing as multiprocessing
+
# see https://bugs.python.org/issue33315 - we do need these types, but must
# currently use them as string annotations
@@ -76,64 +80,45 @@ def get_runner() -> compilation_runner.CompilationRunner:
  return problem_config.get_runner_type()(moving_average_decay_rate=0)


-def worker(policy_path: Optional[str],
-           work_queue: 'queue.Queue[corpus.LoadedModuleSpec]',
-           results_queue: 'queue.Queue[ResultsQueueEntry]',
-           key_filter: Optional[str]):
-  """Describes the job each paralleled worker process does.
+class FilteringWorker(worker.Worker):
+  """Worker that performs a computation and optionally filters the result.

-  The worker picks a workitem from the work_queue, process it, and deposits
-  a result on the results_queue, in either success or failure cases.
-  The results_queue items are tuples (workitem, result). On failure, the result
-  is None.

  Args:
-    runner: the data collector.
    policy_path: the policy_path to generate trace with.
-    work_queue: the queue of unprocessed work items.
-    results_queue: the queue where results are deposited.
    key_filter: regex filter for key names to include, or None to include all.
  """
-  try:
-    runner = get_runner()
-    m = re.compile(key_filter) if key_filter else None
-    policy = policy_saver.Policy.from_filesystem(
+
+  def __init__(self, policy_path: Optional[str], key_filter: Optional[str]):
+    self._policy_path = policy_path
+    self._key_filter = re.compile(key_filter) if key_filter else None
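+    # Each worker builds its own compilation runner and, when a policy_path is
+    # given, loads the policy used to generate the trace.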
+    self._runner = get_runner()
+    self._policy = policy_saver.Policy.from_filesystem(
        policy_path) if policy_path else None
-  while True:
-    try:
-      loaded_module_spec = work_queue.get_nowait()
-    except queue.Empty:
-      return
-    try:
-      data = runner.collect_data(
-          loaded_module_spec=loaded_module_spec,
-          policy=policy,
-          reward_stat=None,
-          model_id=0)
-      if not m:
-        results_queue.put(
-            (loaded_module_spec.name, data.serialized_sequence_examples,
-             data.reward_stats))
-        continue
-      new_reward_stats = {}
-      new_sequence_examples = []
-      for k, sequence_example in zip(data.keys,
-                                     data.serialized_sequence_examples):
-        if not m.match(k):
-          continue
-        new_reward_stats[k] = data.reward_stats[k]
-        new_sequence_examples.append(sequence_example)
-      results_queue.put(
-          (loaded_module_spec.name, new_sequence_examples, new_reward_stats))
-    except (subprocess.CalledProcessError, subprocess.TimeoutExpired,
-            RuntimeError):
-      logging.error('Failed to compile %s.', loaded_module_spec.name)
-      results_queue.put(None)
-    except BaseException as e:  # pylint: disable=broad-except
-      results_queue.put(e)
-
-
-def main(_):
+
+  def compile_and_filter(
+      self, loaded_module_spec: corpus.LoadedModuleSpec
+  ) -> Tuple[str, List[str], Dict[str, compilation_runner.RewardStat]]:
+    data = self._runner.collect_data(
+        loaded_module_spec=loaded_module_spec,
+        policy=self._policy,
+        reward_stat=None,
+        model_id=0)
+    if self._key_filter is None:
+      return (loaded_module_spec.name, data.serialized_sequence_examples,
+              data.reward_stats)
+    new_reward_stats = {}
+    new_sequence_examples = []
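+    # Keep only the sequence examples (and reward stats) whose keys match the
+    # filter regex.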
+    for k, sequence_example in zip(data.keys,
+                                   data.serialized_sequence_examples):
+      if not self._key_filter.match(k):
+        continue
+      new_reward_stats[k] = data.reward_stats[k]
+      new_sequence_examples.append(sequence_example)
+    return (loaded_module_spec.name, new_sequence_examples, new_reward_stats)
+
+
+def main(worker_manager_class=local_worker_manager.LocalWorkerPoolManager):

  gin.parse_config_files_and_bindings(
      _GIN_FILES.value, bindings=_GIN_BINDINGS.value, skip_unknown=False)
@@ -160,74 +145,60 @@ def main(_):
  # other smaller files are processed in parallel
  corpus_elements = cps.sample(k=sampled_modules, sort=True)

-  worker_count = (
-      min(os.cpu_count(), _NUM_WORKERS.value)
-      if _NUM_WORKERS.value else os.cpu_count())
-
  tfrecord_context = (
      tf.io.TFRecordWriter(_OUTPUT_PATH.value)
      if _OUTPUT_PATH.value else contextlib.nullcontext())
  performance_context = (
      tf.io.gfile.GFile(_OUTPUT_PERFORMANCE_PATH.value, 'w')
      if _OUTPUT_PERFORMANCE_PATH.value else contextlib.nullcontext())
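+  # Load every sampled module spec eagerly; the list is dispatched to the
+  # worker pool below.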
+  work = [
+      cps.load_module_spec(corpus_element) for corpus_element in corpus_elements
+  ]

  with tfrecord_context as tfrecord_writer:
    with performance_context as performance_writer:
-      ctx = multiprocessing.get_context()
-      m = ctx.Manager()
-      results_queue: 'queue.Queue[ResultsQueueEntry]' = m.Queue()
-      work_queue: 'queue.Queue[corpus.LoadedModuleSpec]' = m.Queue()
-      for corpus_element in corpus_elements:
-        work_queue.put(cps.load_module_spec(corpus_element))
-
-      # pylint:disable=g-complex-comprehension
-      processes = [
-          ctx.Process(
-              target=functools.partial(worker, _POLICY_PATH.value, work_queue,
-                                       results_queue, _KEY_FILTER.value))
-          for _ in range(0, worker_count)
-      ]
-      # pylint:enable=g-complex-comprehension
-
-      for p in processes:
-        p.start()
-
-      total_successful_examples = 0
-      total_work = len(corpus_elements)
-      total_failed_examples = 0
-      total_training_examples = 0
-      for _ in range(total_work):
-        logging.log_every_n_seconds(logging.INFO,
-                                    '%d success, %d failed out of %d', 10,
-                                    total_successful_examples,
-                                    total_failed_examples, total_work)
-
-        results = results_queue.get()
-        if isinstance(results, BaseException):
-          logging.fatal(results)
-        if not results:
-          total_failed_examples += 1
-          continue
-
-        total_successful_examples += 1
-        module_name, records, reward_stat = results
-        if tfrecord_writer:
-          total_training_examples += len(records)
-          for r in records:
-            tfrecord_writer.write(r)
-        if performance_writer:
-          for key, value in reward_stat.items():
-            performance_writer.write(
-                (f'{module_name},{key},{value.default_reward},'
-                 f'{value.moving_average_reward}\n'))
-
-      print((f'{total_successful_examples} of {len(corpus_elements)} modules '
-             f'succeeded, and {total_training_examples} trainining examples '
-             'written'))
-      for p in processes:
-        p.join()
+      with worker_manager_class(
+          FilteringWorker,
+          _NUM_WORKERS.value,
+          policy_path=_POLICY_PATH.value,
+          key_filter=_KEY_FILTER.value) as lwm:
+
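+        # Schedule one compile-and-filter job per loaded module on the pool.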
+        _, result_futures = buffered_scheduler.schedule_on_worker_pool(
+            action=lambda w, j: w.compile_and_filter(j),
+            jobs=work,
+            worker_pool=lwm)
+        total_successful_examples = 0
+        total_work = len(corpus_elements)
+        total_failed_examples = 0
+        total_training_examples = 0
+        not_done = result_futures
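+        # Drain the result futures in 10-second rounds until none are left,
+        # tallying successes and failures as they complete.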
+        while not_done:
+          (done, not_done) = concurrent.futures.wait(not_done, 10)
+          succeeded = [
+              r for r in done if not r.cancelled() and r.exception() is None
+          ]
+          total_successful_examples += len(succeeded)
+          total_failed_examples += (len(done) - len(succeeded))
+          for r in succeeded:
+            module_name, records, reward_stat = r.result()
+            if tfrecord_writer:
+              total_training_examples += len(records)
+              for r in records:
+                tfrecord_writer.write(r)
+            if performance_writer:
+              for key, value in reward_stat.items():
+                performance_writer.write(
+                    (f'{module_name},{key},{value.default_reward},'
+                     f'{value.moving_average_reward}\n'))
+          logging.info('%d success, %d failed out of %d',
+                       total_successful_examples, total_failed_examples,
+                       total_work)
+
+        print((f'{total_successful_examples} of {len(corpus_elements)} modules '
+               f'succeeded, and {total_training_examples} training examples '
+               'written'))


if __name__ == '__main__':
  flags.mark_flag_as_required('data_path')
-  app.run(main)
+  multiprocessing.handle_main(functools.partial(app.run, main))