Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/explorer/workflow_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ async def run_async(self):
await asyncio.sleep(0.1)
memory.append({"role": "user", "content": content})
memory.append({"role": "assistant", "content": content.upper()})
experience = self.process_messages_to_experience(memory, 0, {})
experience = await self.process_messages_to_experience_async(memory, 0, {})
experience_list.append(experience)
return experience_list

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
generate_default_empty_experience,
get_jinja_env,
parse_response,
process_messages_to_experience,
process_messages_to_experience_async,
validate_trajectory_format,
)
from trinity.common.workflows.workflow import Task, Workflow
Expand Down Expand Up @@ -202,7 +202,7 @@ async def run_async(self) -> List[Experience]:

if reward >= 1 and traj_format_valid:
print("✅ Task completed successfully in the first attempt!")
experience = process_messages_to_experience(
experience = await process_messages_to_experience_async(
self.model, trajectory, info={"success": success, "reward": reward, "steps": steps}
)
return [experience]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
generate_default_empty_experience,
generate_reward_feedback,
parse_response,
process_messages_to_experience,
process_messages_to_experience_async,
save_task_data,
validate_trajectory_format,
)
Expand Down Expand Up @@ -215,9 +215,9 @@ def _should_keep_for_sft(self, second_traj_format_valid: bool, re_explore_info:
or (re_explore_info["efficiency_improved"] and re_explore_info["new_reward"] >= 1.0)
)

def _generate_experience_from_sft(self, sft_messages: list, metrics: dict) -> Experience:
async def _generate_experience_from_sft(self, sft_messages: list, metrics: dict) -> Experience:
"""Generate experience from SFT messages"""
return process_messages_to_experience(self.model, sft_messages, info=metrics)
return await process_messages_to_experience_async(self.model, sft_messages, info=metrics)

async def run_async(self) -> List[Experience]:
"""Run the RAFT alfworld workflow and return experiences"""
Expand Down Expand Up @@ -245,7 +245,7 @@ async def run_async(self) -> List[Experience]:
# Handle first attempt success cases
if reward >= 1 and traj_format_valid:
print("✅ Task completed successfully in the first attempt!")
experience = process_messages_to_experience(
experience = await process_messages_to_experience_async(
self.model, trajectory, info={"success": success, "reward": reward, "steps": steps}
)
return [experience]
Expand Down Expand Up @@ -275,7 +275,7 @@ async def run_async(self) -> List[Experience]:
kept_for_sft = self._should_keep_for_sft(second_traj_format_valid, re_explore_info)

if kept_for_sft:
experience = self._generate_experience_from_sft(sft_messages, metrics)
experience = await self._generate_experience_from_sft(sft_messages, metrics)
experiences.append(experience)
print(
f"✅ Generated good training data: orig={reward}, steps={steps}, new={re_explore_info['new_reward']}, new_steps={re_explore_info['new_steps']}"
Expand Down
4 changes: 2 additions & 2 deletions trinity/common/workflows/envs/alfworld/RAFT_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,13 +107,13 @@ def create_alfworld_environment(game_file):
raise ImportError(error_message)


def process_messages_to_experience(model, messages, info=None) -> Experience:
async def process_messages_to_experience_async(model, messages, info=None) -> Experience:
"""Convert messages to experience for training, with fallback to default empty experience"""
if info is None:
info = {}

try:
converted_experience = model.convert_messages_to_experience(messages)
converted_experience = await model.convert_messages_to_experience_async(messages)

metrics = {}
for k, v in info.items():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ async def generate_env_inference_samples(self, env) -> List[Experience]:
if done:
final_reward = reward
break
experience = self.process_messages_to_experience(
experience = await self.process_messages_to_experience_async(
memory, final_reward, {"env_rounds": r, "env_done": 1 if done else 0}
)
# Close the env to save cpu memory
Expand Down
4 changes: 2 additions & 2 deletions trinity/common/workflows/envs/frozen_lake/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,9 +353,9 @@ async def run_async(self) -> List[Experience]:
# Create experience from messages
final_reward = sum(self.step_rewards)
# print(f"final_reward: {final_reward}, terminate_reason: {terminate_reason}")
experience = self.process_messages_to_experience(
experience = await self.process_messages_to_experience_async(
messages=messages,
reward=final_reward,
reward=float(final_reward),
info={
"env_steps": self.step_count,
"env_done": 1 if self.done else 0,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ async def generate_env_inference_samples(self, env, rollout_num) -> List[Experie
if done:
break
final_reward = final_reward / 100.0
experience = self.process_messages_to_experience(
experience = await self.process_messages_to_experience_async(
memory,
final_reward,
{"env_rounds": r, "env_done": 1 if done else 0, "golden_rounds": golden_rounds},
Expand Down
4 changes: 2 additions & 2 deletions trinity/common/workflows/envs/webshop/webshop_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,9 +258,9 @@ async def generate_env_inference_samples(
final_reward = 0
else:
final_reward = -0.1
experience = self.process_messages_to_experience(
experience = await self.process_messages_to_experience_async(
memory,
final_reward,
float(final_reward),
{"session_id": session_id, "env_rounds": r, "env_done": 1 if done else 0},
)
experience_list.append(experience)
Expand Down
33 changes: 33 additions & 0 deletions trinity/common/workflows/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,39 @@ def process_messages_to_experience(
)
return experience

async def process_messages_to_experience_async(
    self, messages, reward, info=None, truncate_status=None
) -> Experience:
    """Build an ``Experience`` from ``messages`` using the model's async converter.

    Args:
        messages: Passed unchanged to
            ``self.model.convert_messages_to_experience_async``.
        reward: Reward stored on the returned experience. It is replaced by
            ``0.0`` when the converted experience reports a
            ``"response_truncated"`` truncate status.
        info: Optional mapping stored on the experience as ``info``. Its
            int/float values are additionally copied into ``metrics`` as
            floats. When omitted, a fresh empty dict is used.
        truncate_status: Fallback value, used only when the converted
            experience's own ``truncate_status`` is falsy.

    Returns:
        An ``Experience`` carrying the converted tokens, action mask, prompt
        metadata, and the log-probs multiplied by the action mask.
    """
    # Use a fresh dict per call: the dict is stored on the returned
    # Experience, so a shared ``{}`` default would be aliased by every
    # experience created without an explicit ``info``.
    if info is None:
        info = {}

    converted_experience = await self.model.convert_messages_to_experience_async(messages)

    if converted_experience.truncate_status == "response_truncated":
        reward = 0.0

    # The mask is required for the multiplication below.
    assert converted_experience.action_mask is not None
    generation_mask = converted_experience.action_mask
    # Multiply log-probs by the mask, as the original code did.
    log_probs = converted_experience.logprobs * generation_mask

    # Only numeric entries of ``info`` are reported as metrics.
    metrics = {k: float(v) for k, v in info.items() if isinstance(v, (float, int))}

    return Experience(
        tokens=converted_experience.tokens,
        action_mask=generation_mask,
        prompt_length=converted_experience.prompt_length,
        prompt_text=converted_experience.prompt_text,
        response_text=converted_experience.response_text,
        truncate_status=converted_experience.truncate_status or truncate_status,
        reward=reward,
        logprobs=log_probs,
        info=info,
        metrics=metrics,
    )


class BaseSimpleWorkflow(Workflow):
def __init__(
Expand Down