Description
Why is `reward_stat` calculated at all? I see that `llvm_trainer.train` uses the reward stored in `sequence_example.reward`, which is overwritten here:
"""sequence_example = _overwrite_trajectory_reward(
sequence_example=sequence_example,
reward=_calculate_reward(
policy=policy_reward, baseline=moving_average_reward))
"""
```python
for k, v in policy_result.items():
  sequence_example = v[0]
  policy_reward = v[1]
  if k not in reward_stat:
    raise ValueError(
        (f'Example {k} does not exist under default policy for '
         f'cmd line: {final_cmd_line}'))
  default_reward = reward_stat[k].default_reward
  moving_average_reward = reward_stat[k].moving_average_reward
  sequence_example = _overwrite_trajectory_reward(
      sequence_example=sequence_example,
      reward=_calculate_reward(
          policy=policy_reward, baseline=moving_average_reward))
  sequence_example_list.append(sequence_example)
  reward_stat[k].moving_average_reward = (
      moving_average_reward * self._moving_average_decay_rate +
      policy_reward * (1 - self._moving_average_decay_rate))
  rewards.append(
      _calculate_reward(policy=policy_reward, baseline=default_reward))
  policy_rewards.append(policy_reward)
  keys.append(k)

result = CompilationResult(
    sequence_examples=sequence_example_list,
    reward_stats=reward_stat,
    rewards=rewards,
    policy_rewards=policy_rewards,
    keys=keys,
    model_id=model_id)

for observer in self._observers:
  observer.observe(result)

return result
```
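From the pasted code, `reward_stat` appears to serve two roles: it supplies the `moving_average_reward` baseline used to shape the reward that ends up in `sequence_example.reward`, and it is itself updated as an exponential moving average so the baseline tracks the current policy, while `default_reward` is only used for the separately reported `rewards` list. A minimal, self-contained sketch of that bookkeeping, with helper names of my own (not from the repo):

```python
import dataclasses


@dataclasses.dataclass
class RewardStat:
  """Per-example reward bookkeeping (field names follow the code above)."""
  # Reward obtained with the default (non-ML) policy; stays fixed.
  default_reward: float
  # Exponential moving average of recent policy rewards; updated on every
  # data collection round and used as the baseline for reward shaping.
  moving_average_reward: float


def update_moving_average(stat: RewardStat, policy_reward: float,
                          decay_rate: float) -> None:
  # Same EMA update as in the loop above, with decay_rate standing in for
  # self._moving_average_decay_rate.
  stat.moving_average_reward = (
      stat.moving_average_reward * decay_rate +
      policy_reward * (1 - decay_rate))
```

So is the point that, even though `llvm_trainer.train` only reads `sequence_example.reward`, that reward has already been shaped against the baselines stored in `reward_stat`, which is why `reward_stat` has to be computed and carried along?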