Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 21 additions & 27 deletions trinity/common/workflows/envs/alfworld/alfworld_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ class AlfworldWorkflow(MultiTurnWorkflow):
"""A workflow for alfworld task."""

is_async: bool = True
can_repeat: bool = False

def __init__(
self,
Expand All @@ -120,39 +121,32 @@ async def get_model_response(self, messages):
async def get_model_response_text(self, messages):
    """Return the text of the first model response produced for ``messages``.

    Delegates to ``self.get_model_response`` and reads ``response_text``
    from the first element of whatever sequence it returns.
    """
    responses = await self.get_model_response(messages)
    first_response = responses[0]
    return first_response.response_text

async def generate_env_inference_samples(self, env, rollout_num) -> List[Experience]:
# TODO: Make this parallel
print("Generating env inference samples...")
experience_list = []
for i in range(rollout_num):
observation, info = env.reset()
final_reward = -0.1
memory = []
memory.append({"role": "system", "content": AlfWORLD_SYSTEM_PROMPT})
for r in range(self.max_env_steps):
format_obs = format_observation(observation)
memory = memory + [{"role": "user", "content": format_obs}]
response_text = await self.get_model_response_text(memory)
memory.append({"role": "assistant", "content": response_text})
action = parse_action(response_text)
observation, reward, done, info = env.step(action)
if done:
final_reward = reward
break
experience = self.process_messages_to_experience(
memory, final_reward, {"env_rounds": r, "env_done": 1 if done else 0}
)
experience_list.append(experience)
async def generate_env_inference_samples(self, env) -> List[Experience]:
    """Roll out one episode in ``env`` and return it as a single experience.

    The environment is reset, then the model and the environment alternate
    for at most ``self.max_env_steps`` turns: each observation is formatted
    into a user message, the model's reply is recorded as an assistant
    message, and the action parsed from that reply is stepped in ``env``.

    Args:
        env: Environment object providing ``reset()``, ``step(action)`` and
            ``close()``. It is always closed before this method returns or
            raises, so the caller must not reuse it.

    Returns:
        A one-element list containing the experience built from the
        dialogue. Its reward is the environment's reward if the episode
        reported ``done``, otherwise ``-0.1``.
    """
    # Defaults cover ``max_env_steps <= 0``: the loop body never runs, and
    # ``r`` / ``done`` would otherwise be unbound when building the metadata.
    r = 0
    done = False
    # Reward used when the episode never reports ``done``.
    final_reward = -0.1
    memory = [{"role": "system", "content": AlfWORLD_SYSTEM_PROMPT}]
    try:
        observation, _ = env.reset()
        for r in range(self.max_env_steps):
            format_obs = format_observation(observation)
            memory = memory + [{"role": "user", "content": format_obs}]
            response_text = await self.get_model_response_text(memory)
            memory.append({"role": "assistant", "content": response_text})
            action = parse_action(response_text)
            observation, reward, done, _ = env.step(action)
            if done:
                final_reward = reward
                break
        experience = self.process_messages_to_experience(
            memory, final_reward, {"env_rounds": r, "env_done": 1 if done else 0}
        )
    finally:
        # Close the env to save cpu memory, even when the rollout fails.
        env.close()
    return [experience]

async def run_async(self) -> List[Experience]:
# assume the task_description is the game_file_path generated.
# see Trinity-RFT/examples/grpo_alfworld/get_alfworld_data.py
game_file_path = self.task_desc
rollout_n = self.repeat_times
# TODO: Make parallel envs
try:
import textworld
import textworld.gym
Expand All @@ -179,7 +173,7 @@ def create_environment(game_file):
error_message = f"Error importing AlfworldTWEnv {str(e)}. Please make sure you have installed the alfworld package successfully, following the instructions in https://github.com/alfworld/alfworld"
raise ImportError(error_message)
env = create_environment(game_file_path)
return await self.generate_env_inference_samples(env, rollout_n)
return await self.generate_env_inference_samples(env)


@WORKFLOWS.register_module("step_wise_alfworld_workflow")
Expand Down