File tree Expand file tree Collapse file tree 2 files changed +8
-6
lines changed
Expand file tree Collapse file tree 2 files changed +8
-6
lines changed Original file line number Diff line number Diff line change @@ -148,7 +148,7 @@ class RewardActor(ForgeActor):
148148 @endpoint
149149 async def evaluate_response (
150150 self , prompt : str , response : str , target : str
151- ) -> dict [str , float ]:
151+ ) -> ( dict [str , float ], float ) :
152152 total_rewards = 0.0
153153 reward_breakdown = {} # reward breakdown by function
154154 for reward_fn in self .reward_functions :
@@ -189,9 +189,8 @@ async def evaluate_response(
189189 Reduce .SUM ,
190190 )
191191
192- avg_reward = total_rewards / len (self .reward_functions )
193- reward_breakdown ["reward" ] = avg_reward
194- return reward_breakdown
192+ avg_reward : float = total_rewards / len (self .reward_functions )
193+ return reward_breakdown , avg_reward
195194
196195
197196@dataclass
@@ -397,10 +396,12 @@ async def continuous_rollouts():
397396 response = response .text ,
398397 completion = response ,
399398 )
400- episode .reward_breakdown = await reward_actor .evaluate_response .route (
399+ (
400+ episode .reward_breakdown ,
401+ episode .reward ,
402+ ) = await reward_actor .evaluate_response .route (
401403 prompt = prompt , response = response .text , target = target
402404 )
403- episode .reward = episode .reward_breakdown ["reward" ]
404405 episodes .append (episode )
405406
406407 # Build input_ids for reference logprobs
Original file line number Diff line number Diff line change @@ -223,6 +223,7 @@ def record_episode_sample(table_name: str, episode):
223223 ** (
224224 episode .reward_breakdown or {}
225225 ), # per-fn breakdown including the average reward
226+ "reward" : episode .reward ,
226227 "advantage" : episode .advantage ,
227228 "request_len" : episode .request_len ,
228229 "response_len" : episode .response_len ,
You can’t perform that action at this time.
0 commit comments