@@ -219,6 +219,7 @@ def __post_init__(self, sequence_examples: List[tf.train.SequenceExample]):
219
219
]
220
220
object .__setattr__ (self , 'length' , sum (lengths ))
221
221
222
+ # TODO: is it necessary to return keys AND reward_stats(which has the keys)?
222
223
assert (len (self .serialized_sequence_examples ) == len (self .rewards ) ==
223
224
(len (self .keys )))
224
225
assert set (self .keys ) == set (self .reward_stats .keys ())
@@ -228,6 +229,14 @@ def __post_init__(self, sequence_examples: List[tf.train.SequenceExample]):
228
229
class CompilationRunnerStub (metaclass = abc .ABCMeta ):
229
230
"""The interface of a stub to CompilationRunner, for type checkers."""
230
231
232
@abc.abstractmethod
def collect_results(self,
                    module_spec: corpus.ModuleSpec,
                    tf_policy_path: str,
                    collect_default_result: bool,
                    reward_only: bool = False) -> Tuple[Dict, Dict]:
  """Stub declaration of the result-collection entry point.

  Exists so type checkers can see the method's signature; the stub itself
  never produces a result.

  Args:
    module_spec: a ModuleSpec.
    tf_policy_path: path to the tensorflow policy.
    collect_default_result: whether to get the default result as well.
    reward_only: whether to only collect the rewards in the results.

  Raises:
    NotImplementedError: always, when the stub body is invoked directly.
  """
  raise NotImplementedError
239
+
231
240
@abc .abstractmethod
232
241
def collect_data (
233
242
self , module_spec : corpus .ModuleSpec , tf_policy_path : str ,
@@ -275,6 +284,47 @@ def enable(self):
275
284
def cancel_all_work (self ):
276
285
self ._cancellation_manager .kill_all_processes ()
277
286
287
@staticmethod
def get_rewards(result: Optional[Dict]) -> List[float]:
  """Extracts the reward of every entry in a compilation result.

  Args:
    result: a result dict whose values are indexable, with the reward stored
      at index 1 of each value. May be None or empty, e.g. when that result
      was not collected.

  Returns:
    The rewards in the dict's iteration order; an empty list when `result`
    is None or empty.
  """
  # A result that was not collected is None rather than an empty dict, so
  # guard on truthiness instead of calling len().
  if not result:
    return []
  return [entry[1] for entry in result.values()]
292
+
293
def collect_results(self,
                    module_spec: corpus.ModuleSpec,
                    tf_policy_path: str,
                    collect_default_result: bool,
                    reward_only: bool = False) -> Tuple[Dict, Dict]:
  """Compiles a module under the default policy and/or the given policy.

  Args:
    module_spec: a ModuleSpec.
    tf_policy_path: path to the tensorflow policy; when empty, no separate
      policy compilation is run.
    collect_default_result: whether to get the default result as well.
    reward_only: whether to only collect the rewards in the results.

  Returns:
    A tuple (default_result, policy_result).
    default_result is None when collect_default_result is false.
    policy_result is the same object as default_result when tf_policy_path
    is empty, so both are None if neither compilation was requested.
    NOTE(review): the annotation does not reflect these None cases.
  """
  has_policy = bool(tf_policy_path)

  baseline = None
  if collect_default_result:
    # When a policy run follows, the default compilation is always run
    # reward-only, regardless of the caller's reward_only flag.
    baseline = self._compile_fn(
        module_spec,
        tf_policy_path='',
        reward_only=has_policy or reward_only,
        cancellation_manager=self._cancellation_manager)

  if not has_policy:
    # No policy to evaluate: the default result (possibly None) is returned
    # in both positions.
    return baseline, baseline

  trained = self._compile_fn(
      module_spec,
      tf_policy_path,
      reward_only=reward_only,
      cancellation_manager=self._cancellation_manager)
  return baseline, trained
327
+
278
328
def collect_data (
279
329
self , module_spec : corpus .ModuleSpec , tf_policy_path : str ,
280
330
reward_stat : Optional [Dict [str , RewardStat ]]) -> CompilationResult :
@@ -284,8 +334,6 @@ def collect_data(
284
334
module_spec: a ModuleSpec.
285
335
tf_policy_path: path to the tensorflow policy.
286
336
reward_stat: reward stat of this module, None if unknown.
287
- cancellation_token: a CancellationToken through which workers may be
288
- signaled early termination
289
337
290
338
Returns:
291
339
A CompilationResult. In particular:
@@ -297,25 +345,18 @@ def collect_data(
297
345
compilation_runner.ProcessKilledException is passed through.
298
346
ValueError if example under default policy and ml policy does not match.
299
347
"""
348
+ default_result , policy_result = self .collect_results (
349
+ module_spec ,
350
+ tf_policy_path ,
351
+ collect_default_result = reward_stat is None ,
352
+ reward_only = False )
300
353
if reward_stat is None :
301
- default_result = self ._compile_fn (
302
- module_spec ,
303
- tf_policy_path = '' ,
304
- reward_only = bool (tf_policy_path ),
305
- cancellation_manager = self ._cancellation_manager )
354
+ # TODO: Add structure to default_result and policy_result.
355
+ # get_rewards above should be updated/removed when this is resolved.
306
356
reward_stat = {
307
357
k : RewardStat (v [1 ], v [1 ]) for (k , v ) in default_result .items ()
308
358
}
309
359
310
- if tf_policy_path :
311
- policy_result = self ._compile_fn (
312
- module_spec ,
313
- tf_policy_path ,
314
- reward_only = False ,
315
- cancellation_manager = self ._cancellation_manager )
316
- else :
317
- policy_result = default_result
318
-
319
360
sequence_example_list = []
320
361
rewards = []
321
362
keys = []
0 commit comments