Skip to content

Commit ded6d23

Browse files
authored
Fix process_messages_to_experience in MultiTurnWorkflow (#468)
1 parent 39dd1d4 commit ded6d23

File tree

10 files changed

+46
-20
lines changed

10 files changed

+46
-20
lines changed

tests/explorer/workflow_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ async def run_async(self):
160160
await asyncio.sleep(0.1)
161161
memory.append({"role": "user", "content": content})
162162
memory.append({"role": "assistant", "content": content.upper()})
163-
experience = self.process_messages_to_experience(memory, 0, {})
163+
experience = await self.process_messages_to_experience_async(memory, 0, {})
164164
experience_list.append(experience)
165165
return experience_list
166166

trinity/common/workflows/envs/alfworld/RAFT_alfworld_workflow.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
generate_default_empty_experience,
1111
get_jinja_env,
1212
parse_response,
13-
process_messages_to_experience,
13+
process_messages_to_experience_async,
1414
validate_trajectory_format,
1515
)
1616
from trinity.common.workflows.workflow import Task, Workflow
@@ -202,7 +202,7 @@ async def run_async(self) -> List[Experience]:
202202

203203
if reward >= 1 and traj_format_valid:
204204
print("✅ Task completed successfully in the first attempt!")
205-
experience = process_messages_to_experience(
205+
experience = await process_messages_to_experience_async(
206206
self.model, trajectory, info={"success": success, "reward": reward, "steps": steps}
207207
)
208208
return [experience]

trinity/common/workflows/envs/alfworld/RAFT_reflect_alfworld_workflow.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
generate_default_empty_experience,
1515
generate_reward_feedback,
1616
parse_response,
17-
process_messages_to_experience,
17+
process_messages_to_experience_async,
1818
save_task_data,
1919
validate_trajectory_format,
2020
)
@@ -215,9 +215,9 @@ def _should_keep_for_sft(self, second_traj_format_valid: bool, re_explore_info:
215215
or (re_explore_info["efficiency_improved"] and re_explore_info["new_reward"] >= 1.0)
216216
)
217217

218-
def _generate_experience_from_sft(self, sft_messages: list, metrics: dict) -> Experience:
218+
async def _generate_experience_from_sft(self, sft_messages: list, metrics: dict) -> Experience:
219219
"""Generate experience from SFT messages"""
220-
return process_messages_to_experience(self.model, sft_messages, info=metrics)
220+
return await process_messages_to_experience_async(self.model, sft_messages, info=metrics)
221221

222222
async def run_async(self) -> List[Experience]:
223223
"""Run the RAFT alfworld workflow and return experiences"""
@@ -245,7 +245,7 @@ async def run_async(self) -> List[Experience]:
245245
# Handle first attempt success cases
246246
if reward >= 1 and traj_format_valid:
247247
print("✅ Task completed successfully in the first attempt!")
248-
experience = process_messages_to_experience(
248+
experience = await process_messages_to_experience_async(
249249
self.model, trajectory, info={"success": success, "reward": reward, "steps": steps}
250250
)
251251
return [experience]
@@ -275,7 +275,7 @@ async def run_async(self) -> List[Experience]:
275275
kept_for_sft = self._should_keep_for_sft(second_traj_format_valid, re_explore_info)
276276

277277
if kept_for_sft:
278-
experience = self._generate_experience_from_sft(sft_messages, metrics)
278+
experience = await self._generate_experience_from_sft(sft_messages, metrics)
279279
experiences.append(experience)
280280
print(
281281
f"✅ Generated good training data: orig={reward}, steps={steps}, new={re_explore_info['new_reward']}, new_steps={re_explore_info['new_steps']}"

trinity/common/workflows/envs/alfworld/RAFT_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,13 +107,13 @@ def create_alfworld_environment(game_file):
107107
raise ImportError(error_message)
108108

109109

110-
def process_messages_to_experience(model, messages, info=None) -> Experience:
110+
async def process_messages_to_experience_async(model, messages, info=None) -> Experience:
111111
"""Convert messages to experience for training, with fallback to default empty experience"""
112112
if info is None:
113113
info = {}
114114

115115
try:
116-
converted_experience = model.convert_messages_to_experience(messages)
116+
converted_experience = await model.convert_messages_to_experience_async(messages)
117117

118118
metrics = {}
119119
for k, v in info.items():

trinity/common/workflows/envs/alfworld/alfworld_workflow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ async def generate_env_inference_samples(self, env) -> List[Experience]:
135135
if done:
136136
final_reward = reward
137137
break
138-
experience = self.process_messages_to_experience(
138+
experience = await self.process_messages_to_experience_async(
139139
memory, final_reward, {"env_rounds": r, "env_done": 1 if done else 0}
140140
)
141141
# Close the env to save cpu memory

trinity/common/workflows/envs/frozen_lake/workflow.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -353,9 +353,9 @@ async def run_async(self) -> List[Experience]:
353353
# Create experience from messages
354354
final_reward = sum(self.step_rewards)
355355
# print(f"final_reward: {final_reward}, terminate_reason: {terminate_reason}")
356-
experience = self.process_messages_to_experience(
356+
experience = await self.process_messages_to_experience_async(
357357
messages=messages,
358-
reward=final_reward,
358+
reward=float(final_reward),
359359
info={
360360
"env_steps": self.step_count,
361361
"env_done": 1 if self.done else 0,

trinity/common/workflows/envs/sciworld/sciworld_workflow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ async def generate_env_inference_samples(self, env, rollout_num) -> List[Experie
107107
if done:
108108
break
109109
final_reward = final_reward / 100.0
110-
experience = self.process_messages_to_experience(
110+
experience = await self.process_messages_to_experience_async(
111111
memory,
112112
final_reward,
113113
{"env_rounds": r, "env_done": 1 if done else 0, "golden_rounds": golden_rounds},

trinity/common/workflows/envs/webshop/webshop_workflow.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -258,9 +258,9 @@ async def generate_env_inference_samples(
258258
final_reward = 0
259259
else:
260260
final_reward = -0.1
261-
experience = self.process_messages_to_experience(
261+
experience = await self.process_messages_to_experience_async(
262262
memory,
263-
final_reward,
263+
float(final_reward),
264264
{"session_id": session_id, "env_rounds": r, "env_done": 1 if done else 0},
265265
)
266266
experience_list.append(experience)

trinity/common/workflows/workflow.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -176,11 +176,20 @@ def set_repeat_times(self, repeat_times, run_id_base):
176176
self.repeat_times = repeat_times
177177
self.run_id_base = run_id_base
178178

179-
def process_messages_to_experience(
180-
self, messages, reward, info={}, truncate_status=None
179+
def _build_experience_from_converted(
180+
self, converted_experience, reward, info={}, truncate_status=None
181181
) -> Experience:
182-
converted_experience = self.model.convert_messages_to_experience(messages)
182+
"""Private helper method to build Experience from converted_experience.
183+
184+
Args:
185+
converted_experience: The converted experience from the model.
186+
reward: The reward value.
187+
info: Additional info dictionary.
188+
truncate_status: Optional truncate status to override.
183189
190+
Returns:
191+
Experience: The constructed Experience object.
192+
"""
184193
if converted_experience.truncate_status == "response_truncated":
185194
reward = 0.0
186195

@@ -209,6 +218,22 @@ def process_messages_to_experience(
209218
)
210219
return experience
211220

221+
def process_messages_to_experience(
222+
self, messages, reward, info={}, truncate_status=None
223+
) -> Experience:
224+
converted_experience = self.model.convert_messages_to_experience(messages)
225+
return self._build_experience_from_converted(
226+
converted_experience, reward, info, truncate_status
227+
)
228+
229+
async def process_messages_to_experience_async(
230+
self, messages, reward, info={}, truncate_status=None
231+
) -> Experience:
232+
converted_experience = await self.model.convert_messages_to_experience_async(messages)
233+
return self._build_experience_from_converted(
234+
converted_experience, reward, info, truncate_status
235+
)
236+
212237

213238
class BaseSimpleWorkflow(Workflow):
214239
def __init__(

trinity/trainer/verl/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,8 @@ def to_data_proto(
6666
token_level_rewards = torch.zeros(attention_mask.shape, dtype=torch.float32)
6767
eos_mask_idx = cumsum.argmax(dim=-1)
6868
token_level_rewards[torch.arange(len(experiences)), eos_mask_idx] = torch.tensor(
69-
[exp.reward for exp in experiences]
69+
[exp.reward for exp in experiences],
70+
dtype=torch.float32,
7071
)
7172
token_level_rewards = token_level_rewards[:, max_prompt_length:]
7273
batch_dict.update(

0 commit comments

Comments (0)