     generate_default_empty_experience,
     generate_reward_feedback,
     parse_response,
-    process_messages_to_experience,
+    process_messages_to_experience_async,
     save_task_data,
     validate_trajectory_format,
 )
@@ -215,9 +215,9 @@ def _should_keep_for_sft(self, second_traj_format_valid: bool, re_explore_info:
             or (re_explore_info["efficiency_improved"] and re_explore_info["new_reward"] >= 1.0)
         )
 
-    def _generate_experience_from_sft(self, sft_messages: list, metrics: dict) -> Experience:
+    async def _generate_experience_from_sft(self, sft_messages: list, metrics: dict) -> Experience:
         """Generate experience from SFT messages"""
-        return process_messages_to_experience(self.model, sft_messages, info=metrics)
+        return await process_messages_to_experience_async(self.model, sft_messages, info=metrics)
 
     async def run_async(self) -> List[Experience]:
         """Run the RAFT alfworld workflow and return experiences"""
@@ -245,7 +245,7 @@ async def run_async(self) -> List[Experience]:
         # Handle first attempt success cases
         if reward >= 1 and traj_format_valid:
             print("✅ Task completed successfully in the first attempt!")
-            experience = process_messages_to_experience(
+            experience = await process_messages_to_experience_async(
                 self.model, trajectory, info={"success": success, "reward": reward, "steps": steps}
            )
            return [experience]
@@ -275,7 +275,7 @@ async def run_async(self) -> List[Experience]:
         kept_for_sft = self._should_keep_for_sft(second_traj_format_valid, re_explore_info)
 
         if kept_for_sft:
-            experience = self._generate_experience_from_sft(sft_messages, metrics)
+            experience = await self._generate_experience_from_sft(sft_messages, metrics)
             experiences.append(experience)
             print(
                 f"✅ Generated good training data: orig={reward}, steps={steps}, new={re_explore_info['new_reward']}, new_steps={re_explore_info['new_steps']}"
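The hunks above swap the blocking process_messages_to_experience helper for an awaited process_messages_to_experience_async, so experience generation no longer blocks the event loop inside run_async. A minimal sketch of how such an async variant could wrap the existing synchronous helper, assuming the blocking call is safe to run on a worker thread; the body below is illustrative, not the repository's actual implementation:

import asyncio

# Hypothetical shim (name mirrors the diff, body is an assumption): run the
# existing blocking helper on a worker thread so callers inside run_async()
# can await it without stalling the event loop.
async def process_messages_to_experience_async(model, messages, info=None):
    return await asyncio.to_thread(process_messages_to_experience, model, messages, info=info)

With a wrapper along these lines, _generate_experience_from_sft and the first-attempt success path can simply await the call, as shown in the hunks above.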