@@ -211,16 +211,20 @@ class CompilationResult:
   keys: List[str]

   def __post_init__(self, sequence_examples: List[tf.train.SequenceExample]):
-    object.__setattr__(self, 'serialized_sequence_examples',
-                       [x.SerializeToString() for x in sequence_examples])
+    object.__setattr__(
+        self, 'serialized_sequence_examples',
+        [x.SerializeToString() for x in sequence_examples if x is not None])
     lengths = [
         len(next(iter(x.feature_lists.feature_list.values())).feature)
         for x in sequence_examples
+        if x is not None
     ]
     object.__setattr__(self, 'length', sum(lengths))

-    assert (len(self.serialized_sequence_examples) == len(self.rewards) ==
-            (len(self.keys)))
+    # TODO: is it necessary to return keys AND reward_stats (which has the keys)?
+    # sequence_examples' length could also just not be checked; this allows
+    # raw_reward_only to do less work
+    assert (len(sequence_examples) == len(self.rewards) == (len(self.keys)))
     assert set(self.keys) == set(self.reward_stats.keys())
     assert not hasattr(self, 'sequence_examples')

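The `__post_init__` above must go through `object.__setattr__` because the dataclass is frozen, and it now tolerates `None` placeholders produced by the `raw_reward_only` path. A minimal standalone sketch of that pattern (simplified types, no TensorFlow; `ResultSketch` and its fields are invented here for illustration):

```python
import dataclasses
from typing import List, Optional


@dataclasses.dataclass(frozen=True)
class ResultSketch:
  """Simplified stand-in for CompilationResult (illustration only)."""
  # InitVar: passed through to __post_init__ but not stored as a field.
  examples: dataclasses.InitVar[List[Optional[str]]]
  rewards: List[float]

  def __post_init__(self, examples: List[Optional[str]]):
    # Frozen dataclasses forbid normal assignment, hence object.__setattr__.
    object.__setattr__(self, 'serialized',
                       [x.encode() for x in examples if x is not None])
    # The length check runs on the unfiltered list: with raw_reward_only,
    # each entry is a None placeholder but still pairs with one reward.
    assert len(examples) == len(self.rewards)


r = ResultSketch(examples=['a', None, 'c'], rewards=[1.0, 2.0, 3.0])
print(r.serialized)  # [b'a', b'c'] -- None placeholders are filtered out
```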
@@ -230,9 +234,11 @@ class CompilationRunnerStub(metaclass=abc.ABCMeta):

   @abc.abstractmethod
   def collect_data(
-      self, module_spec: corpus.ModuleSpec, tf_policy_path: str,
-      reward_stat: Optional[Dict[str, RewardStat]]
-  ) -> WorkerFuture[CompilationResult]:
+      self,
+      module_spec: corpus.ModuleSpec,
+      tf_policy_path: str,
+      reward_stat: Optional[Dict[str, RewardStat]],
+      raw_reward_only: bool = False) -> WorkerFuture[CompilationResult]:
     raise NotImplementedError()

   @abc.abstractmethod
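A sketch of the compatibility reasoning behind the keyword default: callers written against the old three-argument signature keep working unchanged, while new callers can opt into reward-only collection. `RunnerStubSketch` and `FakeRunner` are invented names, not part of this commit:

```python
import abc


class RunnerStubSketch(metaclass=abc.ABCMeta):
  """Illustrative stand-in for CompilationRunnerStub."""

  @abc.abstractmethod
  def collect_data(self, module_spec, tf_policy_path, reward_stat,
                   raw_reward_only: bool = False):
    raise NotImplementedError()


class FakeRunner(RunnerStubSketch):

  def collect_data(self, module_spec, tf_policy_path, reward_stat,
                   raw_reward_only: bool = False):
    return 'rewards only' if raw_reward_only else 'full trajectories'


runner = FakeRunner()
print(runner.collect_data('m.bc', '/tmp/policy', None))        # old call shape
print(runner.collect_data('m.bc', '/tmp/policy', None, True))  # new opt-in
```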
@@ -275,17 +281,18 @@ def enable(self):
   def cancel_all_work(self):
     self._cancellation_manager.kill_all_processes()

-  def collect_data(
-      self, module_spec: corpus.ModuleSpec, tf_policy_path: str,
-      reward_stat: Optional[Dict[str, RewardStat]]) -> CompilationResult:
+  def collect_data(self,
+                   module_spec: corpus.ModuleSpec,
+                   tf_policy_path: str,
+                   reward_stat: Optional[Dict[str, RewardStat]],
+                   raw_reward_only=False) -> CompilationResult:
     """Collect data for the given IR file and policy.

     Args:
       module_spec: a ModuleSpec.
       tf_policy_path: path to the tensorflow policy.
       reward_stat: reward stat of this module, None if unknown.
-      cancellation_token: a CancellationToken through which workers may be
-        signaled early termination
+      raw_reward_only: whether to return the raw reward value without examples.

     Returns:
       A CompilationResult. In particular:
@@ -311,7 +318,7 @@ def collect_data(
       policy_result = self._compile_fn(
           module_spec,
           tf_policy_path,
-          reward_only=False,
+          reward_only=raw_reward_only,
           cancellation_manager=self._cancellation_manager)
     else:
       policy_result = default_result
@@ -326,6 +333,11 @@ def collect_data(
           raise ValueError(
               (f'Example {k} does not exist under default policy for '
                f'module {module_spec.name}'))
+        if raw_reward_only:
+          sequence_example_list.append(None)
+          rewards.append(policy_reward)
+          keys.append(k)
+          continue
         default_reward = reward_stat[k].default_reward
         moving_average_reward = reward_stat[k].moving_average_reward
         sequence_example = _overwrite_trajectory_reward(
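Putting the pieces together, a self-contained sketch of the fast path added above (every name other than `raw_reward_only` is invented here): a `None` placeholder keeps the three parallel lists aligned so the `CompilationResult` assertions still hold, while the trajectory and normalization work is skipped entirely.

```python
from typing import Dict, List, Optional, Tuple


def gather(policy_rewards: Dict[str, float],
           raw_reward_only: bool
          ) -> Tuple[List[Optional[bytes]], List[float], List[str]]:
  sequence_example_list: List[Optional[bytes]] = []
  rewards: List[float] = []
  keys: List[str] = []
  for k, policy_reward in policy_rewards.items():
    if raw_reward_only:
      sequence_example_list.append(None)  # placeholder, filtered out later
      rewards.append(policy_reward)       # raw value, no normalization
      keys.append(k)
      continue
    # Full path (elided here): overwrite trajectory rewards, update the
    # moving average, and append a real serialized SequenceExample.
    sequence_example_list.append(b'<serialized example>')
    rewards.append(policy_reward)  # the real code appends a relative reward
    keys.append(k)
  return sequence_example_list, rewards, keys


print(gather({'foo': 1.5, 'bar': 0.25}, raw_reward_only=True))
# ([None, None], [1.5, 0.25], ['foo', 'bar'])
```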