diff --git a/trinity/common/workflows/envs/alfworld/alfworld_workflow.py b/trinity/common/workflows/envs/alfworld/alfworld_workflow.py
index f722e82ce6..2ae075a68c 100644
--- a/trinity/common/workflows/envs/alfworld/alfworld_workflow.py
+++ b/trinity/common/workflows/envs/alfworld/alfworld_workflow.py
@@ -98,6 +98,7 @@ class AlfworldWorkflow(MultiTurnWorkflow):
     """A workflow for alfworld task."""
 
     is_async: bool = True
+    can_repeat: bool = False
 
     def __init__(
         self,
@@ -120,39 +121,32 @@ async def get_model_response(self, messages):
     async def get_model_response_text(self, messages):
         return (await self.get_model_response(messages))[0].response_text
 
-    async def generate_env_inference_samples(self, env, rollout_num) -> List[Experience]:
-        # TODO: Make this parallel
-        print("Generating env inference samples...")
-        experience_list = []
-        for i in range(rollout_num):
-            observation, info = env.reset()
-            final_reward = -0.1
-            memory = []
-            memory.append({"role": "system", "content": AlfWORLD_SYSTEM_PROMPT})
-            for r in range(self.max_env_steps):
-                format_obs = format_observation(observation)
-                memory = memory + [{"role": "user", "content": format_obs}]
-                response_text = await self.get_model_response_text(memory)
-                memory.append({"role": "assistant", "content": response_text})
-                action = parse_action(response_text)
-                observation, reward, done, info = env.step(action)
-                if done:
-                    final_reward = reward
-                    break
-            experience = self.process_messages_to_experience(
-                memory, final_reward, {"env_rounds": r, "env_done": 1 if done else 0}
-            )
-            experience_list.append(experience)
+    async def generate_env_inference_samples(self, env) -> List[Experience]:
+        observation, info = env.reset()
+        final_reward = -0.1
+        memory = []
+        memory.append({"role": "system", "content": AlfWORLD_SYSTEM_PROMPT})
+        for r in range(self.max_env_steps):
+            format_obs = format_observation(observation)
+            memory = memory + [{"role": "user", "content": format_obs}]
+            response_text = await self.get_model_response_text(memory)
+            memory.append({"role": "assistant", "content": response_text})
+            action = parse_action(response_text)
+            observation, reward, done, info = env.step(action)
+            if done:
+                final_reward = reward
+                break
+        experience = self.process_messages_to_experience(
+            memory, final_reward, {"env_rounds": r, "env_done": 1 if done else 0}
+        )
         # Close the env to save cpu memory
         env.close()
-        return experience_list
+        return [experience]
 
     async def run_async(self) -> List[Experience]:
         # assume the task_description is the game_file_path generated.
         # see Trinity-RFT/examples/grpo_alfworld/get_alfworld_data.py
         game_file_path = self.task_desc
-        rollout_n = self.repeat_times
-        # TODO: Make parallel envs
         try:
             import textworld
             import textworld.gym
@@ -179,7 +173,7 @@ def create_environment(game_file):
             error_message = f"Error importing AlfworldTWEnv {str(e)}. Please make sure you have installed the alfworld package successfully, following the instructions in https://github.com/alfworld/alfworld"
             raise ImportError(error_message)
         env = create_environment(game_file_path)
-        return await self.generate_env_inference_samples(env, rollout_n)
+        return await self.generate_env_inference_samples(env)
 
 
 @WORKFLOWS.register_module("step_wise_alfworld_workflow")