Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/grpo_sciworld/sciworld.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ buffer:
storage_type: file
path: 'scripts/data_prepare/sciworld_data'
format:
prompt_key: 'game_file'
prompt_key: 'task_desc'
rollout_args:
repeat_times: 8
temperature: 1.0
Expand Down
42 changes: 24 additions & 18 deletions trinity/common/workflows/envs/webshop/webshop_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,9 +191,31 @@ def __init__(
model=model,
task=task,
)
self.max_env_steps = 15
self.reset(task)

# TODO: Make parallel envs
try:
import gym
from web_agent_site.envs import WebAgentTextEnv # noqa: F401
except Exception as e:
print("Please make sure you have installed the web_agent_site package.")
error_message = f"Error importing WebAgentTextEnv {str(e)}. Please make sure you have installed the web_agent_site package, following the instructions in https://github.com/princeton-nlp/WebShop"
raise ImportError(error_message)
print("Making GYM env")
# NOTE: Hosting the env require ~15GB CPU memory.
# If you want easier env, you can set the num_products to 1000 or 100000.
self.env = gym.make(
"WebAgentTextEnv-v0", observation_mode="text_rich", num_products=None, human_goals=True
)

@property
def resettable(self):
return True

def reset(self, task: Task):
self.task_desc = task.task_desc or "0"
self.repeat_times = task.rollout_args.repeat_times
self.max_env_steps = 15

def get_model_response(self, messages):
responses = self.model.chat(messages, repeat_times=1)
Expand Down Expand Up @@ -242,26 +264,10 @@ def generate_env_inference_samples(self, env, session_id, rollout_num) -> List[E
{"session_id": session_id, "env_rounds": r, "env_done": 1 if done else 0},
)
experience_list.append(experience)
# Close the env to save cpu memory
env.close()
return experience_list

def run(self) -> List[Experience]:
# assume the task_description is the session_id generated.
session_id = int(self.task_desc)
rollout_n = self.repeat_times
# TODO: Make parallel envs
try:
import gym
from web_agent_site.envs import WebAgentTextEnv # noqa: F401
except Exception as e:
print("Please make sure you have installed the web_agent_site package.")
error_message = f"Error importing WebAgentTextEnv {str(e)}. Please make sure you have installed the web_agent_site package, following the instructions in https://github.com/princeton-nlp/WebShop"
raise ImportError(error_message)
print("Making GYM env")
# NOTE: Hosting the env require ~15GB CPU memory.
# If you want easier env, you can set the num_products to 1000 or 100000.
env = gym.make(
"WebAgentTextEnv-v0", observation_mode="text_rich", num_products=None, human_goals=True
)
return self.generate_env_inference_samples(env, session_id, rollout_n)
return self.generate_env_inference_samples(self.env, session_id, rollout_n)
15 changes: 15 additions & 0 deletions trinity/common/workflows/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,14 @@ def __init__(
self.model = model
self.auxiliary_models = auxiliary_models

@property
def resettable(self):
return False

@abstractmethod
def reset(self, task: Task):
"""Reset the workflow."""

@abstractmethod
def run(self) -> List[Experience]:
"""Run workflow and return a list of experiences."""
Expand Down Expand Up @@ -147,6 +155,13 @@ def __init__(
model=model,
task=task,
)
self.reset(task)

@property
def resettable(self):
return True

def reset(self, task: Task):
self.format_args = task.format_args
self.system_prompt = task.format_args.system_prompt
self.reply_prefix = task.format_args.reply_prefix
Expand Down
12 changes: 10 additions & 2 deletions trinity/explorer/workflow_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def __init__(
).get_openai_client()
self.auxiliary_models.append(api_client)
self.logger = get_logger(__name__)
self.workflow_instance = None

def is_alive(self):
return True
Expand All @@ -63,8 +64,15 @@ def _run_task(self, task: Task) -> List[Experience]:
"""Init workflow from the task and run it."""
if task.workflow is None:
raise ValueError("Workflow is not set in the task.")
workflow = task.to_workflow(self.model_wrapper, self.auxiliary_models)
return workflow.run()
if (
self.workflow_instance is None
or not self.workflow_instance.__class__ == task.workflow
or not self.workflow_instance.resettable
):
self.workflow_instance = task.to_workflow(self.model_wrapper, self.auxiliary_models)
else:
self.workflow_instance.reset(task)
return self.workflow_instance.run()

def run_task(self, task: Task) -> Status:
"""Run the task and return the states."""
Expand Down