14
14
# limitations under the License.
15
15
"""Module for running compilation and collecting training data."""
16
16
17
- import concurrent
17
+ import abc
18
18
import dataclasses
19
19
import json
20
- import multiprocessing
21
20
import subprocess
22
21
import threading
23
22
from typing import Dict , List , Optional , Tuple
24
23
25
24
from absl import flags
26
- import tensorflow as tf
27
-
25
+ from compiler_opt .distributed .worker import Worker , WorkerFuture
28
26
from compiler_opt .rl import constant
27
+ import tensorflow as tf
29
28
30
29
_COMPILATION_TIMEOUT = flags .DEFINE_integer (
31
30
'compilation_timeout' , 60 ,
@@ -122,18 +121,6 @@ def __init__(self):
122
121
Exception .__init__ (self )
123
122
124
123
125
- class ProcessCancellationToken :
126
-
127
- def __init__ (self ):
128
- self ._event = multiprocessing .Manager ().Event ()
129
-
130
- def signal (self ):
131
- self ._event .set ()
132
-
133
- def wait (self ):
134
- self ._event .wait ()
135
-
136
-
137
124
def kill_process_ignore_exceptions (p : 'subprocess.Popen[bytes]' ):
138
125
# kill the process and ignore exceptions. Exceptions would be thrown if the
139
126
# process has already been killed/finished (which is inherently in a race
@@ -160,6 +147,10 @@ def __init__(self):
160
147
self ._done = False
161
148
self ._lock = threading .Lock ()
162
149
150
+ def enable (self ):
+ """Reset the done flag to False while holding the lock.
+
+ This undoes the `_done = True` state set by kill_all_processes().
+ """
151
+ with self ._lock :
152
+ self ._done = False
153
+
163
154
def register_process (self , p : 'subprocess.Popen[bytes]' ):
164
155
"""Register a process for potential cancellation."""
165
156
with self ._lock :
@@ -168,7 +159,7 @@ def register_process(self, p: 'subprocess.Popen[bytes]'):
168
159
return
169
160
kill_process_ignore_exceptions (p )
170
161
171
- def signal (self ):
162
+ def kill_all_processes (self ):
172
163
"""Cancel any pending work."""
173
164
with self ._lock :
174
165
self ._done = True
@@ -265,21 +256,31 @@ def __post_init__(self, sequence_examples: List[tf.train.SequenceExample]):
265
256
assert not hasattr (self , 'sequence_examples' )
266
257
267
258
268
- class CompilationRunner :
269
- """Base class for collecting compilation data ."""
259
+ class CompilationRunnerStub ( metaclass = abc . ABCMeta ) :
260
+ """The interface of a stub to CompilationRunner, for type checkers."""
270
261
271
- _POOL : concurrent .futures .ThreadPoolExecutor = None
262
+ @abc .abstractmethod
263
+ def collect_data (
264
+ self , file_paths : Tuple [str , ...], tf_policy_path : str ,
265
+ reward_stat : Optional [Dict [str , RewardStat ]]
266
+ ) -> WorkerFuture [CompilationResult ]:
+ """Stub counterpart of CompilationRunner.collect_data.
+
+ Takes the same arguments, but the return is typed as a
+ WorkerFuture[CompilationResult] rather than a CompilationResult.
+ """
267
+ raise NotImplementedError ()
272
268
273
- @staticmethod
274
- def init_pool ():
275
- """Worker process initialization."""
276
- CompilationRunner ._POOL = concurrent .futures .ThreadPoolExecutor ()
269
+ @abc .abstractmethod
270
+ def cancel_all_work (self ) -> WorkerFuture :
+ """Stub counterpart of CompilationRunner.cancel_all_work."""
271
+ raise NotImplementedError ()
277
272
278
- @staticmethod
279
- def _get_pool ():
280
- """Internal API for fetching the cancellation token waiting pool."""
281
- assert CompilationRunner ._POOL
282
- return CompilationRunner ._POOL
273
+ @abc .abstractmethod
274
+ def enable (self ) -> WorkerFuture :
+ """Stub counterpart of CompilationRunner.enable."""
275
+ raise NotImplementedError ()
276
+
277
+
278
+ class CompilationRunner (Worker ):
279
+ """Base class for collecting compilation data."""
280
+
281
+ @classmethod
282
+ def is_priority_method (cls , method_name : str ) -> bool :
+ """Return True only when `method_name` is 'cancel_all_work'.
+
+ NOTE(review): how priority methods are scheduled is decided by the Worker
+ base class / its hosting infrastructure, which is not visible here.
+ """
283
+ return method_name == 'cancel_all_work'
283
284
284
285
def __init__ (self ,
285
286
clang_path : Optional [str ] = None ,
@@ -302,40 +303,18 @@ def __init__(self,
302
303
self ._additional_flags = additional_flags
303
304
self ._delete_flags = delete_flags
304
305
self ._compilation_timeout = _COMPILATION_TIMEOUT .value
306
+ self ._cancellation_manager = WorkerCancellationManager ()
305
307
306
- def _get_cancellation_manager (
307
- self , cancellation_token : Optional [ProcessCancellationToken ]
308
- ) -> Optional [WorkerCancellationManager ]:
309
- """Convert the ProcessCancellationToken into a WorkerCancellationManager.
310
-
311
- The conversion also registers the ProcessCancellationToken wait() on a
312
- thread which will call the WorkerCancellationManager upon completion.
313
- Since the token is always signaled, the thread always completes its work.
314
-
315
- Args:
316
- cancellation_token: the ProcessCancellationToken to convert.
317
-
318
- Returns:
319
- a WorkerCancellationManager, if a ProcessCancellationToken was given.
320
- """
321
- if not cancellation_token :
322
- return None
323
- ret = WorkerCancellationManager ()
324
-
325
- def signaler ():
326
- cancellation_token .wait ()
327
- ret .signal ()
308
+ # Re-allow the cancellation manager to accept work.
309
+ def enable (self ):
+ """Forward to the enable() method of this runner's cancellation manager."""
310
+ self ._cancellation_manager .enable ()
328
311
329
- CompilationRunner . _get_pool (). submit ( signaler )
330
- return ret
312
+ def cancel_all_work ( self ):
+ """Cancel pending work via the cancellation manager's kill_all_processes()."""
313
+ self . _cancellation_manager . kill_all_processes ()
331
314
332
315
def collect_data (
333
- self ,
334
- file_paths : Tuple [str , ...],
335
- tf_policy_path : str ,
336
- reward_stat : Optional [Dict [str , RewardStat ]],
337
- cancellation_token : Optional [ProcessCancellationToken ] = None
338
- ) -> CompilationResult :
316
+ self , file_paths : Tuple [str , ...], tf_policy_path : str ,
317
+ reward_stat : Optional [Dict [str , RewardStat ]]) -> CompilationResult :
339
318
"""Collect data for the given IR file and policy.
340
319
341
320
Args:
@@ -355,14 +334,12 @@ def collect_data(
355
334
compilation_runner.ProcessKilledException is passed through.
356
335
ValueError if example under default policy and ml policy does not match.
357
336
"""
358
- cancellation_manager = self ._get_cancellation_manager (cancellation_token )
359
-
360
337
if reward_stat is None :
361
338
default_result = self ._compile_fn (
362
339
file_paths ,
363
340
tf_policy_path = '' ,
364
341
reward_only = bool (tf_policy_path ),
365
- cancellation_manager = cancellation_manager )
342
+ cancellation_manager = self . _cancellation_manager )
366
343
reward_stat = {
367
344
k : RewardStat (v [1 ], v [1 ]) for (k , v ) in default_result .items ()
368
345
}
@@ -372,7 +349,7 @@ def collect_data(
372
349
file_paths ,
373
350
tf_policy_path ,
374
351
reward_only = False ,
375
- cancellation_manager = cancellation_manager )
352
+ cancellation_manager = self . _cancellation_manager )
376
353
else :
377
354
policy_result = default_result
378
355
0 commit comments