@@ -98,6 +98,7 @@ class AlfworldWorkflow(MultiTurnWorkflow):
9898 """A workflow for alfworld task."""
9999
100100 is_async : bool = True
101+ can_repeat : bool = False
101102
102103 def __init__ (
103104 self ,
@@ -120,39 +121,32 @@ async def get_model_response(self, messages):
     async def get_model_response_text(self, messages):
         return (await self.get_model_response(messages))[0].response_text

-    async def generate_env_inference_samples(self, env, rollout_num) -> List[Experience]:
-        # TODO: Make this parallel
-        print("Generating env inference samples...")
-        experience_list = []
-        for i in range(rollout_num):
-            observation, info = env.reset()
-            final_reward = -0.1
-            memory = []
-            memory.append({"role": "system", "content": AlfWORLD_SYSTEM_PROMPT})
-            for r in range(self.max_env_steps):
-                format_obs = format_observation(observation)
-                memory = memory + [{"role": "user", "content": format_obs}]
-                response_text = await self.get_model_response_text(memory)
-                memory.append({"role": "assistant", "content": response_text})
-                action = parse_action(response_text)
-                observation, reward, done, info = env.step(action)
-                if done:
-                    final_reward = reward
-                    break
-            experience = self.process_messages_to_experience(
-                memory, final_reward, {"env_rounds": r, "env_done": 1 if done else 0}
-            )
-            experience_list.append(experience)
+    async def generate_env_inference_samples(self, env) -> List[Experience]:
+        observation, info = env.reset()
+        final_reward = -0.1
+        memory = []
+        memory.append({"role": "system", "content": AlfWORLD_SYSTEM_PROMPT})
+        for r in range(self.max_env_steps):
+            format_obs = format_observation(observation)
+            memory = memory + [{"role": "user", "content": format_obs}]
+            response_text = await self.get_model_response_text(memory)
+            memory.append({"role": "assistant", "content": response_text})
+            action = parse_action(response_text)
+            observation, reward, done, info = env.step(action)
+            if done:
+                final_reward = reward
+                break
+        experience = self.process_messages_to_experience(
+            memory, final_reward, {"env_rounds": r, "env_done": 1 if done else 0}
+        )
         # Close the env to save cpu memory
         env.close()
-        return experience_list
+        return [experience]

     async def run_async(self) -> List[Experience]:
         # assume the task_description is the game_file_path generated.
         # see Trinity-RFT/examples/grpo_alfworld/get_alfworld_data.py
         game_file_path = self.task_desc
-        rollout_n = self.repeat_times
-        # TODO: Make parallel envs
         try:
             import textworld
             import textworld.gym
@@ -179,7 +173,7 @@ def create_environment(game_file):
             error_message = f"Error importing AlfworldTWEnv {str(e)}. Please make sure you have installed the alfworld package successfully, following the instructions in https://github.com/alfworld/alfworld"
             raise ImportError(error_message)
         env = create_environment(game_file_path)
-        return await self.generate_env_inference_samples(env, rollout_n)
+        return await self.generate_env_inference_samples(env)


 @WORKFLOWS.register_module("step_wise_alfworld_workflow")
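
With the rollout loop removed from generate_env_inference_samples, each workflow call now produces a single episode. Below is a minimal sketch (not part of this commit) of how a caller could still collect several rollouts concurrently; collect_rollouts and workflow_factory are hypothetical names, and asyncio.gather is just one possible concurrency mechanism.

import asyncio
from typing import Callable, List

async def collect_rollouts(
    workflow_factory: Callable[[str], "AlfworldWorkflow"],
    game_file_path: str,
    rollout_num: int,
) -> List["Experience"]:
    # One workflow (and therefore one env) per rollout, so episodes stay independent.
    workflows = [workflow_factory(game_file_path) for _ in range(rollout_num)]
    # Each run_async returns a one-element list of Experience; run them concurrently.
    results = await asyncio.gather(*(wf.run_async() for wf in workflows))
    # Flatten the per-episode lists into a single list of experiences.
    return [exp for exps in results for exp in exps]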