From cce156c126adc7ed88156272b150a1fc0ccdb248 Mon Sep 17 00:00:00 2001 From: hiyuchang Date: Wed, 7 Jan 2026 15:27:07 +0800 Subject: [PATCH 1/2] fix multiturnworkflow --- tests/explorer/workflow_test.py | 2 +- .../envs/alfworld/RAFT_alfworld_workflow.py | 4 +-- .../RAFT_reflect_alfworld_workflow.py | 10 +++--- .../workflows/envs/alfworld/RAFT_utils.py | 4 +-- .../envs/alfworld/alfworld_workflow.py | 2 +- .../workflows/envs/frozen_lake/workflow.py | 4 +-- .../envs/sciworld/sciworld_workflow.py | 2 +- .../envs/webshop/webshop_workflow.py | 4 +-- trinity/common/workflows/workflow.py | 33 +++++++++++++++++++ 9 files changed, 49 insertions(+), 16 deletions(-) diff --git a/tests/explorer/workflow_test.py b/tests/explorer/workflow_test.py index ae5cb5a343..1150261bb8 100644 --- a/tests/explorer/workflow_test.py +++ b/tests/explorer/workflow_test.py @@ -160,7 +160,7 @@ async def run_async(self): await asyncio.sleep(0.1) memory.append({"role": "user", "content": content}) memory.append({"role": "assistant", "content": content.upper()}) - experience = self.process_messages_to_experience(memory, 0, {}) + experience = await self.process_messages_to_experience_async(memory, 0, {}) experience_list.append(experience) return experience_list diff --git a/trinity/common/workflows/envs/alfworld/RAFT_alfworld_workflow.py b/trinity/common/workflows/envs/alfworld/RAFT_alfworld_workflow.py index 15050dc0eb..4bc34833cb 100644 --- a/trinity/common/workflows/envs/alfworld/RAFT_alfworld_workflow.py +++ b/trinity/common/workflows/envs/alfworld/RAFT_alfworld_workflow.py @@ -10,7 +10,7 @@ generate_default_empty_experience, get_jinja_env, parse_response, - process_messages_to_experience, + process_messages_to_experience_async, validate_trajectory_format, ) from trinity.common.workflows.workflow import Task, Workflow @@ -202,7 +202,7 @@ async def run_async(self) -> List[Experience]: if reward >= 1 and traj_format_valid: print("✅ Task completed successfully in the first attempt!") - experience = process_messages_to_experience( + experience = await process_messages_to_experience_async( self.model, trajectory, info={"success": success, "reward": reward, "steps": steps} ) return [experience] diff --git a/trinity/common/workflows/envs/alfworld/RAFT_reflect_alfworld_workflow.py b/trinity/common/workflows/envs/alfworld/RAFT_reflect_alfworld_workflow.py index 589fb9a8e6..be2889a2b5 100644 --- a/trinity/common/workflows/envs/alfworld/RAFT_reflect_alfworld_workflow.py +++ b/trinity/common/workflows/envs/alfworld/RAFT_reflect_alfworld_workflow.py @@ -14,7 +14,7 @@ generate_default_empty_experience, generate_reward_feedback, parse_response, - process_messages_to_experience, + process_messages_to_experience_async, save_task_data, validate_trajectory_format, ) @@ -215,9 +215,9 @@ def _should_keep_for_sft(self, second_traj_format_valid: bool, re_explore_info: or (re_explore_info["efficiency_improved"] and re_explore_info["new_reward"] >= 1.0) ) - def _generate_experience_from_sft(self, sft_messages: list, metrics: dict) -> Experience: + async def _generate_experience_from_sft(self, sft_messages: list, metrics: dict) -> Experience: """Generate experience from SFT messages""" - return process_messages_to_experience(self.model, sft_messages, info=metrics) + return await process_messages_to_experience_async(self.model, sft_messages, info=metrics) async def run_async(self) -> List[Experience]: """Run the RAFT alfworld workflow and return experiences""" @@ -245,7 +245,7 @@ async def run_async(self) -> List[Experience]: # Handle first attempt success cases if reward >= 1 and traj_format_valid: print("✅ Task completed successfully in the first attempt!") - experience = process_messages_to_experience( + experience = await process_messages_to_experience_async( self.model, trajectory, info={"success": success, "reward": reward, "steps": steps} ) return [experience] @@ -275,7 +275,7 @@ async def run_async(self) -> List[Experience]: kept_for_sft = self._should_keep_for_sft(second_traj_format_valid, re_explore_info) if kept_for_sft: - experience = self._generate_experience_from_sft(sft_messages, metrics) + experience = await self._generate_experience_from_sft(sft_messages, metrics) experiences.append(experience) print( f"✅ Generated good training data: orig={reward}, steps={steps}, new={re_explore_info['new_reward']}, new_steps={re_explore_info['new_steps']}" diff --git a/trinity/common/workflows/envs/alfworld/RAFT_utils.py b/trinity/common/workflows/envs/alfworld/RAFT_utils.py index 46b6f356a6..5e57ba597a 100644 --- a/trinity/common/workflows/envs/alfworld/RAFT_utils.py +++ b/trinity/common/workflows/envs/alfworld/RAFT_utils.py @@ -107,13 +107,13 @@ def create_alfworld_environment(game_file): raise ImportError(error_message) -def process_messages_to_experience(model, messages, info=None) -> Experience: +async def process_messages_to_experience_async(model, messages, info=None) -> Experience: """Convert messages to experience for training, with fallback to default empty experience""" if info is None: info = {} try: - converted_experience = model.convert_messages_to_experience(messages) + converted_experience = await model.convert_messages_to_experience_async(messages) metrics = {} for k, v in info.items(): diff --git a/trinity/common/workflows/envs/alfworld/alfworld_workflow.py b/trinity/common/workflows/envs/alfworld/alfworld_workflow.py index 64fe07a6ce..9266483ee0 100644 --- a/trinity/common/workflows/envs/alfworld/alfworld_workflow.py +++ b/trinity/common/workflows/envs/alfworld/alfworld_workflow.py @@ -135,7 +135,7 @@ async def generate_env_inference_samples(self, env) -> List[Experience]: if done: final_reward = reward break - experience = self.process_messages_to_experience( + experience = await self.process_messages_to_experience_async( memory, final_reward, {"env_rounds": r, "env_done": 1 if done else 0} ) # Close the env to save cpu memory diff --git a/trinity/common/workflows/envs/frozen_lake/workflow.py b/trinity/common/workflows/envs/frozen_lake/workflow.py index 604b50282d..6d26a4775b 100644 --- a/trinity/common/workflows/envs/frozen_lake/workflow.py +++ b/trinity/common/workflows/envs/frozen_lake/workflow.py @@ -353,9 +353,9 @@ async def run_async(self) -> List[Experience]: # Create experience from messages final_reward = sum(self.step_rewards) # print(f"final_reward: {final_reward}, terminate_reason: {terminate_reason}") - experience = self.process_messages_to_experience( + experience = await self.process_messages_to_experience_async( messages=messages, - reward=final_reward, + reward=float(final_reward), info={ "env_steps": self.step_count, "env_done": 1 if self.done else 0, diff --git a/trinity/common/workflows/envs/sciworld/sciworld_workflow.py b/trinity/common/workflows/envs/sciworld/sciworld_workflow.py index c9d9cdc684..beefc55295 100644 --- a/trinity/common/workflows/envs/sciworld/sciworld_workflow.py +++ b/trinity/common/workflows/envs/sciworld/sciworld_workflow.py @@ -107,7 +107,7 @@ async def generate_env_inference_samples(self, env, rollout_num) -> List[Experie if done: break final_reward = final_reward / 100.0 - experience = self.process_messages_to_experience( + experience = await self.process_messages_to_experience_async( memory, final_reward, {"env_rounds": r, "env_done": 1 if done else 0, "golden_rounds": golden_rounds}, diff --git a/trinity/common/workflows/envs/webshop/webshop_workflow.py b/trinity/common/workflows/envs/webshop/webshop_workflow.py index 7514965eba..e5d48dd9c1 100644 --- a/trinity/common/workflows/envs/webshop/webshop_workflow.py +++ b/trinity/common/workflows/envs/webshop/webshop_workflow.py @@ -258,9 +258,9 @@ async def generate_env_inference_samples( final_reward = 0 else: final_reward = -0.1 - experience = self.process_messages_to_experience( + experience = await self.process_messages_to_experience_async( memory, - final_reward, + float(final_reward), {"session_id": session_id, "env_rounds": r, "env_done": 1 if done else 0}, ) experience_list.append(experience) diff --git a/trinity/common/workflows/workflow.py b/trinity/common/workflows/workflow.py index cf3b4d449b..88b4d7e716 100644 --- a/trinity/common/workflows/workflow.py +++ b/trinity/common/workflows/workflow.py @@ -209,6 +209,39 @@ def process_messages_to_experience( ) return experience + async def process_messages_to_experience_async( + self, messages, reward, info={}, truncate_status=None + ) -> Experience: + converted_experience = await self.model.convert_messages_to_experience_async(messages) + + if converted_experience.truncate_status == "response_truncated": + reward = 0.0 + + tokens = converted_experience.tokens + log_probs = converted_experience.logprobs + assert converted_experience.action_mask is not None + generation_mask = converted_experience.action_mask + log_probs = log_probs * generation_mask + + metrics = {} + for k, v in info.items(): + if isinstance(v, float) or isinstance(v, int): + metrics[k] = float(v) + + experience = Experience( + tokens=tokens, + action_mask=generation_mask, + prompt_length=converted_experience.prompt_length, + prompt_text=converted_experience.prompt_text, + response_text=converted_experience.response_text, + truncate_status=converted_experience.truncate_status or truncate_status, + reward=reward, + logprobs=log_probs, + info=info, + metrics=metrics, + ) + return experience + class BaseSimpleWorkflow(Workflow): def __init__( From 41f36cc9a6809e26790bb18bfb8872c3f0d1ac37 Mon Sep 17 00:00:00 2001 From: hiyuchang Date: Wed, 7 Jan 2026 16:00:16 +0800 Subject: [PATCH 2/2] fix comment --- trinity/common/workflows/workflow.py | 52 ++++++++++++---------------- trinity/trainer/verl/utils.py | 3 +- 2 files changed, 24 insertions(+), 31 deletions(-) diff --git a/trinity/common/workflows/workflow.py b/trinity/common/workflows/workflow.py index 88b4d7e716..0c45466a70 100644 --- a/trinity/common/workflows/workflow.py +++ b/trinity/common/workflows/workflow.py @@ -176,11 +176,20 @@ def set_repeat_times(self, repeat_times, run_id_base): self.repeat_times = repeat_times self.run_id_base = run_id_base - def process_messages_to_experience( - self, messages, reward, info={}, truncate_status=None + def _build_experience_from_converted( + self, converted_experience, reward, info={}, truncate_status=None ) -> Experience: - converted_experience = self.model.convert_messages_to_experience(messages) + """Private helper method to build Experience from converted_experience. + + Args: + converted_experience: The converted experience from the model. + reward: The reward value. + info: Additional info dictionary. + truncate_status: Optional truncate status to override. + Returns: + Experience: The constructed Experience object. + """ if converted_experience.truncate_status == "response_truncated": reward = 0.0 @@ -209,38 +218,21 @@ def process_messages_to_experience( ) return experience + def process_messages_to_experience( + self, messages, reward, info={}, truncate_status=None + ) -> Experience: + converted_experience = self.model.convert_messages_to_experience(messages) + return self._build_experience_from_converted( + converted_experience, reward, info, truncate_status + ) + async def process_messages_to_experience_async( self, messages, reward, info={}, truncate_status=None ) -> Experience: converted_experience = await self.model.convert_messages_to_experience_async(messages) - - if converted_experience.truncate_status == "response_truncated": - reward = 0.0 - - tokens = converted_experience.tokens - log_probs = converted_experience.logprobs - assert converted_experience.action_mask is not None - generation_mask = converted_experience.action_mask - log_probs = log_probs * generation_mask - - metrics = {} - for k, v in info.items(): - if isinstance(v, float) or isinstance(v, int): - metrics[k] = float(v) - - experience = Experience( - tokens=tokens, - action_mask=generation_mask, - prompt_length=converted_experience.prompt_length, - prompt_text=converted_experience.prompt_text, - response_text=converted_experience.response_text, - truncate_status=converted_experience.truncate_status or truncate_status, - reward=reward, - logprobs=log_probs, - info=info, - metrics=metrics, + return self._build_experience_from_converted( + converted_experience, reward, info, truncate_status ) - return experience class BaseSimpleWorkflow(Workflow): diff --git a/trinity/trainer/verl/utils.py b/trinity/trainer/verl/utils.py index 640ee2b748..57aa3e0467 100644 --- a/trinity/trainer/verl/utils.py +++ b/trinity/trainer/verl/utils.py @@ -66,7 +66,8 @@ def to_data_proto( token_level_rewards = torch.zeros(attention_mask.shape, dtype=torch.float32) eos_mask_idx = cumsum.argmax(dim=-1) token_level_rewards[torch.arange(len(experiences)), eos_mask_idx] = torch.tensor( - [exp.reward for exp in experiences] + [exp.reward for exp in experiences], + dtype=torch.float32, ) token_level_rewards = token_level_rewards[:, max_prompt_length:] batch_dict.update(