Skip to content

Commit 90088fb

Browse files
authored
Make Alfworld Rollout Parallel (#393)
1 parent 1de31a6 commit 90088fb

File tree

1 file changed

+21 -27 lines changed

1 file changed

+21 -27 lines changed

trinity/common/workflows/envs/alfworld/alfworld_workflow.py

Lines changed: 21 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -98,6 +98,7 @@ class AlfworldWorkflow(MultiTurnWorkflow):
9898
"""A workflow for alfworld task."""
9999

100100
is_async: bool = True
101+
can_repeat: bool = False
101102

102103
def __init__(
103104
self,
@@ -120,39 +121,32 @@ async def get_model_response(self, messages):
120121
async def get_model_response_text(self, messages):
121122
return (await self.get_model_response(messages))[0].response_text
122123

123-
async def generate_env_inference_samples(self, env, rollout_num) -> List[Experience]:
124-
# TODO: Make this parallel
125-
print("Generating env inference samples...")
126-
experience_list = []
127-
for i in range(rollout_num):
128-
observation, info = env.reset()
129-
final_reward = -0.1
130-
memory = []
131-
memory.append({"role": "system", "content": AlfWORLD_SYSTEM_PROMPT})
132-
for r in range(self.max_env_steps):
133-
format_obs = format_observation(observation)
134-
memory = memory + [{"role": "user", "content": format_obs}]
135-
response_text = await self.get_model_response_text(memory)
136-
memory.append({"role": "assistant", "content": response_text})
137-
action = parse_action(response_text)
138-
observation, reward, done, info = env.step(action)
139-
if done:
140-
final_reward = reward
141-
break
142-
experience = self.process_messages_to_experience(
143-
memory, final_reward, {"env_rounds": r, "env_done": 1 if done else 0}
144-
)
145-
experience_list.append(experience)
124+
async def generate_env_inference_samples(self, env) -> List[Experience]:
125+
observation, info = env.reset()
126+
final_reward = -0.1
127+
memory = []
128+
memory.append({"role": "system", "content": AlfWORLD_SYSTEM_PROMPT})
129+
for r in range(self.max_env_steps):
130+
format_obs = format_observation(observation)
131+
memory = memory + [{"role": "user", "content": format_obs}]
132+
response_text = await self.get_model_response_text(memory)
133+
memory.append({"role": "assistant", "content": response_text})
134+
action = parse_action(response_text)
135+
observation, reward, done, info = env.step(action)
136+
if done:
137+
final_reward = reward
138+
break
139+
experience = self.process_messages_to_experience(
140+
memory, final_reward, {"env_rounds": r, "env_done": 1 if done else 0}
141+
)
146142
# Close the env to save cpu memory
147143
env.close()
148-
return experience_list
144+
return [experience]
149145

150146
async def run_async(self) -> List[Experience]:
151147
# assume the task_description is the game_file_path generated.
152148
# see Trinity-RFT/examples/grpo_alfworld/get_alfworld_data.py
153149
game_file_path = self.task_desc
154-
rollout_n = self.repeat_times
155-
# TODO: Make parallel envs
156150
try:
157151
import textworld
158152
import textworld.gym
@@ -179,7 +173,7 @@ def create_environment(game_file):
179173
error_message = f"Error importing AlfworldTWEnv {str(e)}. Please make sure you have installed the alfworld package successfully, following the instructions in https://github.com/alfworld/alfworld"
180174
raise ImportError(error_message)
181175
env = create_environment(game_file_path)
182-
return await self.generate_env_inference_samples(env, rollout_n)
176+
return await self.generate_env_inference_samples(env)
183177

184178

185179
@WORKFLOWS.register_module("step_wise_alfworld_workflow")

0 commit comments

Comments
 (0)