Skip to content

Commit 3830dd3

Browse files
authored
Add resettable workflow (#43)
1 parent e3e1ad1 commit 3830dd3

File tree

4 files changed

+50
-21
lines changed

4 files changed

+50
-21
lines changed

examples/grpo_sciworld/sciworld.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ buffer:
1919
storage_type: file
2020
path: 'scripts/data_prepare/sciworld_data'
2121
format:
22-
prompt_key: 'game_file'
22+
prompt_key: 'task_desc'
2323
rollout_args:
2424
repeat_times: 8
2525
temperature: 1.0

trinity/common/workflows/envs/webshop/webshop_workflow.py

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -191,9 +191,31 @@ def __init__(
191191
model=model,
192192
task=task,
193193
)
194+
self.max_env_steps = 15
195+
self.reset(task)
196+
197+
# TODO: Make parallel envs
198+
try:
199+
import gym
200+
from web_agent_site.envs import WebAgentTextEnv # noqa: F401
201+
except Exception as e:
202+
print("Please make sure you have installed the web_agent_site package.")
203+
error_message = f"Error importing WebAgentTextEnv {str(e)}. Please make sure you have installed the web_agent_site package, following the instructions in https://github.com/princeton-nlp/WebShop"
204+
raise ImportError(error_message)
205+
print("Making GYM env")
206+
# NOTE: Hosting the env require ~15GB CPU memory.
207+
# If you want easier env, you can set the num_products to 1000 or 100000.
208+
self.env = gym.make(
209+
"WebAgentTextEnv-v0", observation_mode="text_rich", num_products=None, human_goals=True
210+
)
211+
212+
@property
213+
def resettable(self):
214+
return True
215+
216+
def reset(self, task: Task):
194217
self.task_desc = task.task_desc or "0"
195218
self.repeat_times = task.rollout_args.repeat_times
196-
self.max_env_steps = 15
197219

198220
def get_model_response(self, messages):
199221
responses = self.model.chat(messages, repeat_times=1)
@@ -242,26 +264,10 @@ def generate_env_inference_samples(self, env, session_id, rollout_num) -> List[E
242264
{"session_id": session_id, "env_rounds": r, "env_done": 1 if done else 0},
243265
)
244266
experience_list.append(experience)
245-
# Close the env to save cpu memory
246-
env.close()
247267
return experience_list
248268

249269
def run(self) -> List[Experience]:
250270
# assume the task_description is the session_id generated.
251271
session_id = int(self.task_desc)
252272
rollout_n = self.repeat_times
253-
# TODO: Make parallel envs
254-
try:
255-
import gym
256-
from web_agent_site.envs import WebAgentTextEnv # noqa: F401
257-
except Exception as e:
258-
print("Please make sure you have installed the web_agent_site package.")
259-
error_message = f"Error importing WebAgentTextEnv {str(e)}. Please make sure you have installed the web_agent_site package, following the instructions in https://github.com/princeton-nlp/WebShop"
260-
raise ImportError(error_message)
261-
print("Making GYM env")
262-
# NOTE: Hosting the env require ~15GB CPU memory.
263-
# If you want easier env, you can set the num_products to 1000 or 100000.
264-
env = gym.make(
265-
"WebAgentTextEnv-v0", observation_mode="text_rich", num_products=None, human_goals=True
266-
)
267-
return self.generate_env_inference_samples(env, session_id, rollout_n)
273+
return self.generate_env_inference_samples(self.env, session_id, rollout_n)

trinity/common/workflows/workflow.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,14 @@ def __init__(
7777
self.model = model
7878
self.auxiliary_models = auxiliary_models
7979

80+
@property
81+
def resettable(self):
82+
return False
83+
84+
def reset(self, task: Task):
85+
"""Reset the workflow."""
86+
raise NotImplementedError
87+
8088
@abstractmethod
8189
def run(self) -> List[Experience]:
8290
"""Run workflow and return a list of experiences."""
@@ -147,6 +155,13 @@ def __init__(
147155
model=model,
148156
task=task,
149157
)
158+
self.reset(task)
159+
160+
@property
161+
def resettable(self):
162+
return True
163+
164+
def reset(self, task: Task):
150165
self.format_args = task.format_args
151166
self.system_prompt = task.format_args.system_prompt
152167
self.reply_prefix = task.format_args.reply_prefix

trinity/explorer/workflow_runner.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ def __init__(
5555
).get_openai_client()
5656
self.auxiliary_models.append(api_client)
5757
self.logger = get_logger(__name__)
58+
self.workflow_instance = None
5859

5960
def is_alive(self):
6061
return True
@@ -63,8 +64,15 @@ def _run_task(self, task: Task) -> List[Experience]:
6364
"""Init workflow from the task and run it."""
6465
if task.workflow is None:
6566
raise ValueError("Workflow is not set in the task.")
66-
workflow = task.to_workflow(self.model_wrapper, self.auxiliary_models)
67-
return workflow.run()
67+
if (
68+
self.workflow_instance is None
69+
or not self.workflow_instance.__class__ == task.workflow
70+
or not self.workflow_instance.resettable
71+
):
72+
self.workflow_instance = task.to_workflow(self.model_wrapper, self.auxiliary_models)
73+
else:
74+
self.workflow_instance.reset(task)
75+
return self.workflow_instance.run()
6876

6977
def run_task(self, task: Task) -> Status:
7078
"""Run the task and return the states."""

0 commit comments

Comments
 (0)