2 changes: 1 addition & 1 deletion tests/explorer/workflow_test.py
@@ -160,7 +160,7 @@ async def run_async(self):
             await asyncio.sleep(0.1)
             memory.append({"role": "user", "content": content})
             memory.append({"role": "assistant", "content": content.upper()})
-            experience = self.process_messages_to_experience(memory, 0, {})
+            experience = await self.process_messages_to_experience_async(memory, 0, {})
             experience_list.append(experience)
         return experience_list

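Note on the call-site pattern: every workflow now awaits the `_async` variant from inside `run_async`. A minimal, self-contained sketch of that pattern, using a hypothetical `DummyWorkflow` in place of the real test class:

```python
import asyncio

class DummyWorkflow:
    """Hypothetical stand-in for the test workflow above, not the real class."""

    async def process_messages_to_experience_async(self, memory, reward, info):
        await asyncio.sleep(0)  # stands in for the real awaited model call
        return {"messages": memory, "reward": reward, "info": info}

    async def run_async(self):
        experience_list = []
        for content in ["hello", "world"]:
            memory = [
                {"role": "user", "content": content},
                {"role": "assistant", "content": content.upper()},
            ]
            experience = await self.process_messages_to_experience_async(memory, 0, {})
            experience_list.append(experience)
        return experience_list

print(len(asyncio.run(DummyWorkflow().run_async())))  # -> 2
```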
@@ -10,7 +10,7 @@
     generate_default_empty_experience,
     get_jinja_env,
     parse_response,
-    process_messages_to_experience,
+    process_messages_to_experience_async,
     validate_trajectory_format,
 )
 from trinity.common.workflows.workflow import Task, Workflow
@@ -202,7 +202,7 @@ async def run_async(self) -> List[Experience]:

         if reward >= 1 and traj_format_valid:
             print("✅ Task completed successfully in the first attempt!")
-            experience = process_messages_to_experience(
+            experience = await process_messages_to_experience_async(
                 self.model, trajectory, info={"success": success, "reward": reward, "steps": steps}
             )
             return [experience]
@@ -14,7 +14,7 @@
     generate_default_empty_experience,
     generate_reward_feedback,
     parse_response,
-    process_messages_to_experience,
+    process_messages_to_experience_async,
     save_task_data,
     validate_trajectory_format,
 )
@@ -215,9 +215,9 @@ def _should_keep_for_sft(self, second_traj_format_valid: bool, re_explore_info:
             or (re_explore_info["efficiency_improved"] and re_explore_info["new_reward"] >= 1.0)
         )
 
-    def _generate_experience_from_sft(self, sft_messages: list, metrics: dict) -> Experience:
+    async def _generate_experience_from_sft(self, sft_messages: list, metrics: dict) -> Experience:
         """Generate experience from SFT messages"""
-        return process_messages_to_experience(self.model, sft_messages, info=metrics)
+        return await process_messages_to_experience_async(self.model, sft_messages, info=metrics)
 
     async def run_async(self) -> List[Experience]:
         """Run the RAFT alfworld workflow and return experiences"""
@@ -245,7 +245,7 @@ async def run_async(self) -> List[Experience]:
         # Handle first attempt success cases
         if reward >= 1 and traj_format_valid:
             print("✅ Task completed successfully in the first attempt!")
-            experience = process_messages_to_experience(
+            experience = await process_messages_to_experience_async(
                 self.model, trajectory, info={"success": success, "reward": reward, "steps": steps}
             )
             return [experience]
@@ -275,7 +275,7 @@ async def run_async(self) -> List[Experience]:
         kept_for_sft = self._should_keep_for_sft(second_traj_format_valid, re_explore_info)
 
         if kept_for_sft:
-            experience = self._generate_experience_from_sft(sft_messages, metrics)
+            experience = await self._generate_experience_from_sft(sft_messages, metrics)
             experiences.append(experience)
             print(
                 f"✅ Generated good training data: orig={reward}, steps={steps}, new={re_explore_info['new_reward']}, new_steps={re_explore_info['new_steps']}"
4 changes: 2 additions & 2 deletions trinity/common/workflows/envs/alfworld/RAFT_utils.py
@@ -107,13 +107,13 @@ def create_alfworld_environment(game_file):
         raise ImportError(error_message)
 
 
-def process_messages_to_experience(model, messages, info=None) -> Experience:
+async def process_messages_to_experience_async(model, messages, info=None) -> Experience:
     """Convert messages to experience for training, with fallback to default empty experience"""
     if info is None:
         info = {}
 
     try:
-        converted_experience = model.convert_messages_to_experience(messages)
+        converted_experience = await model.convert_messages_to_experience_async(messages)
 
         metrics = {}
         for k, v in info.items():
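The hunk above shows only the head of the utility; a minimal sketch of the convert-with-fallback shape it appears to follow, where `FakeModel` and the dict-valued "experience" are stand-ins for the real model and `Experience` types:

```python
import asyncio

async def process_messages_to_experience_async(model, messages, info=None):
    """Sketch: await the model's async conversion, falling back to an
    empty experience if it raises (mirroring the try/except above)."""
    if info is None:
        info = {}
    try:
        converted = await model.convert_messages_to_experience_async(messages)
        converted["metrics"] = dict(info)  # stand-in for the metrics loop
        return converted
    except Exception as err:
        print(f"Conversion failed, returning empty experience: {err}")
        return {"tokens": 0, "reward": 0.0, "metrics": dict(info)}

class FakeModel:
    """Hypothetical model exposing only the async conversion method."""

    async def convert_messages_to_experience_async(self, messages):
        await asyncio.sleep(0)  # stands in for real async inference
        return {"tokens": sum(len(m["content"].split()) for m in messages),
                "reward": 0.0}

msgs = [{"role": "user", "content": "go north"}]
print(asyncio.run(process_messages_to_experience_async(FakeModel(), msgs)))
```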
@@ -135,7 +135,7 @@ async def generate_env_inference_samples(self, env) -> List[Experience]:
             if done:
                 final_reward = reward
                 break
-        experience = self.process_messages_to_experience(
+        experience = await self.process_messages_to_experience_async(
             memory, final_reward, {"env_rounds": r, "env_done": 1 if done else 0}
         )
         # Close the env to save cpu memory
4 changes: 2 additions & 2 deletions trinity/common/workflows/envs/frozen_lake/workflow.py
@@ -353,9 +353,9 @@ async def run_async(self) -> List[Experience]:
         # Create experience from messages
         final_reward = sum(self.step_rewards)
         # print(f"final_reward: {final_reward}, terminate_reason: {terminate_reason}")
-        experience = self.process_messages_to_experience(
+        experience = await self.process_messages_to_experience_async(
             messages=messages,
-            reward=final_reward,
+            reward=float(final_reward),
             info={
                 "env_steps": self.step_count,
                 "env_done": 1 if self.done else 0,
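Note on `reward=float(final_reward)`: `sum()` over integer step rewards returns an `int`, and the cast pins the reward to a float before it reaches the trainer (the same concern the `dtype=torch.float32` change in `trinity/trainer/verl/utils.py` below addresses on the tensor side). A quick illustration:

```python
step_rewards = [0, 0, 1]            # per-step rewards are often plain ints
final_reward = sum(step_rewards)    # -> 1, an int
print(type(final_reward).__name__)  # int
print(float(final_reward))          # 1.0, the type Experience.reward should carry
```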
@@ -107,7 +107,7 @@ async def generate_env_inference_samples(self, env, rollout_num) -> List[Experience]:
             if done:
                 break
         final_reward = final_reward / 100.0
-        experience = self.process_messages_to_experience(
+        experience = await self.process_messages_to_experience_async(
             memory,
             final_reward,
             {"env_rounds": r, "env_done": 1 if done else 0, "golden_rounds": golden_rounds},
4 changes: 2 additions & 2 deletions trinity/common/workflows/envs/webshop/webshop_workflow.py
@@ -258,9 +258,9 @@ async def generate_env_inference_samples(
             final_reward = 0
         else:
             final_reward = -0.1
-        experience = self.process_messages_to_experience(
+        experience = await self.process_messages_to_experience_async(
             memory,
-            final_reward,
+            float(final_reward),
             {"session_id": session_id, "env_rounds": r, "env_done": 1 if done else 0},
         )
         experience_list.append(experience)
31 changes: 28 additions & 3 deletions trinity/common/workflows/workflow.py
@@ -176,11 +176,20 @@ def set_repeat_times(self, repeat_times, run_id_base):
         self.repeat_times = repeat_times
         self.run_id_base = run_id_base
 
-    def process_messages_to_experience(
-        self, messages, reward, info={}, truncate_status=None
+    def _build_experience_from_converted(
+        self, converted_experience, reward, info={}, truncate_status=None
     ) -> Experience:
-        converted_experience = self.model.convert_messages_to_experience(messages)
+        """Private helper method to build Experience from converted_experience.
+
+        Args:
+            converted_experience: The converted experience from the model.
+            reward: The reward value.
+            info: Additional info dictionary.
+            truncate_status: Optional truncate status to override.
+
+        Returns:
+            Experience: The constructed Experience object.
+        """
         if converted_experience.truncate_status == "response_truncated":
             reward = 0.0
 
@@ -209,6 +218,22 @@ def process_messages_to_experience(
         )
         return experience

+    def process_messages_to_experience(
+        self, messages, reward, info={}, truncate_status=None
+    ) -> Experience:
+        converted_experience = self.model.convert_messages_to_experience(messages)
+        return self._build_experience_from_converted(
+            converted_experience, reward, info, truncate_status
+        )
+
+    async def process_messages_to_experience_async(
+        self, messages, reward, info={}, truncate_status=None
+    ) -> Experience:
+        converted_experience = await self.model.convert_messages_to_experience_async(messages)
+        return self._build_experience_from_converted(
+            converted_experience, reward, info, truncate_status
+        )


 class BaseSimpleWorkflow(Workflow):
     def __init__(
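Design note: both entry points now funnel through `_build_experience_from_converted`, so the response-truncation and reward handling live in one place instead of being duplicated across the sync and async paths. A minimal sketch of a subclass using the new async entry point (the subclass, its messages, and the import path for `Experience` are illustrative assumptions, not part of the PR):

```python
from typing import List

from trinity.common.experience import Experience  # import path assumed
from trinity.common.workflows.workflow import Workflow

class EchoWorkflow(Workflow):
    """Illustrative subclass; real workflows build messages from an env."""

    async def run_async(self) -> List[Experience]:
        messages = [
            {"role": "user", "content": "hello"},
            {"role": "assistant", "content": "HELLO"},
        ]
        # Awaits the model's async conversion, then reuses the shared builder.
        experience = await self.process_messages_to_experience_async(
            messages, reward=1.0, info={"env_done": 1}
        )
        return [experience]
```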
3 changes: 2 additions & 1 deletion trinity/trainer/verl/utils.py
@@ -66,7 +66,8 @@ def to_data_proto(
     token_level_rewards = torch.zeros(attention_mask.shape, dtype=torch.float32)
     eos_mask_idx = cumsum.argmax(dim=-1)
     token_level_rewards[torch.arange(len(experiences)), eos_mask_idx] = torch.tensor(
-        [exp.reward for exp in experiences]
+        [exp.reward for exp in experiences],
+        dtype=torch.float32,
     )
     token_level_rewards = token_level_rewards[:, max_prompt_length:]
     batch_dict.update(
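Note on the explicit `dtype=torch.float32`: if every `exp.reward` happens to be a Python `int`, `torch.tensor(...)` infers `int64`, and advanced-index assignment into the `float32` `token_level_rewards` buffer requires matching dtypes (recent PyTorch raises "Index put requires the source and destination dtypes match"). Pinning the dtype, together with the `float(...)` casts in the workflows above, rules that out:

```python
import torch

rewards = [1, 0, 1]  # rewards may arrive as plain ints
print(torch.tensor(rewards).dtype)  # torch.int64 (inferred)

buf = torch.zeros(3, dtype=torch.float32)
# With an inferred int64 source this assignment can raise a dtype-mismatch
# error; the explicit float32 cast makes it safe:
buf[torch.arange(3)] = torch.tensor(rewards, dtype=torch.float32)
print(buf)  # tensor([1., 0., 1.])
```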