def _overwrite_trajectory_reward(sequence_example: tf.train.SequenceExample,
                                 reward: float) -> tf.train.SequenceExample:
  """Overwrite the reward in the trace (sequence_example) with the given one.

  After the call, the 'reward' feature list contains exactly one single-float
  feature per step of the trace, each holding `reward`. Rewards that were
  already present are discarded rather than appended to, so calling this
  function repeatedly on the same proto is idempotent.

  The number of steps is read from the first feature list other than 'reward'.
  If 'reward' is the only feature list, its current length is kept. A trace
  with no feature lists at all is treated as having zero steps.

  Args:
    sequence_example: A tf.SequenceExample proto describing compilation trace.
      It is modified in place.
    reward: The reward to overwrite with.

  Returns:
    The same tf.SequenceExample proto after post-processing.
  """
  feature_lists = sequence_example.feature_lists.feature_list

  # Measure the trajectory on a list other than 'reward': a pre-existing reward
  # list may be stale or of a different length, and it is about to be replaced.
  reference_list = next(
      (flist for name, flist in feature_lists.items() if name != 'reward'),
      None)
  if reference_list is None and 'reward' in feature_lists:
    # Only rewards are present; preserve their count.
    reference_list = feature_lists['reward']
  sequence_length = (
      len(reference_list.feature) if reference_list is not None else 0)

  # Indexing the proto map creates the 'reward' entry when it is missing.
  reward_list = feature_lists['reward']
  # Drop any previous rewards so that this is a true overwrite, not an append.
  reward_list.ClearField('feature')
  for _ in range(sequence_length):
    reward_list.feature.add().float_list.value.append(reward)

  return sequence_example