Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/grpo_sciworld/sciworld.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ buffer:
storage_type: file
path: 'scripts/data_prepare/sciworld_data'
format:
prompt_key: 'game_file'
prompt_key: 'task_desc'
rollout_args:
repeat_times: 8
temperature: 1.0
Expand Down
42 changes: 24 additions & 18 deletions trinity/common/workflows/envs/webshop/webshop_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,9 +191,31 @@ def __init__(
model=model,
task=task,
)
self.max_env_steps = 15
self.reset(task)

# TODO: Make parallel envs
try:
import gym
from web_agent_site.envs import WebAgentTextEnv # noqa: F401
except Exception as e:
print("Please make sure you have installed the web_agent_site package.")
error_message = f"Error importing WebAgentTextEnv {str(e)}. Please make sure you have installed the web_agent_site package, following the instructions in https://github.com/princeton-nlp/WebShop"
raise ImportError(error_message)
print("Making GYM env")
# NOTE: Hosting the env requires ~15GB of CPU memory.
# If you want an easier env, you can set num_products to 1000 or 100000.
self.env = gym.make(
"WebAgentTextEnv-v0", observation_mode="text_rich", num_products=None, human_goals=True
)

@property
def resettable(self):
    """Tell callers whether this workflow supports ``reset``.

    Returns:
        bool: Always ``True`` for this workflow.
    """
    return True

def reset(self, task: Task):
    """Load the per-task state of this workflow from ``task``.

    Args:
        task: Object exposing ``task_desc`` and
            ``rollout_args.repeat_times``.
    """
    # A falsy description (None or "") falls back to the string "0".
    self.task_desc = task.task_desc or "0"
    self.repeat_times = task.rollout_args.repeat_times
    # Fixed step budget, re-applied on every reset.
    self.max_env_steps = 15

def get_model_response(self, messages):
responses = self.model.chat(messages, repeat_times=1)
Expand Down Expand Up @@ -242,26 +264,10 @@ def generate_env_inference_samples(self, env, session_id, rollout_num) -> List[E
{"session_id": session_id, "env_rounds": r, "env_done": 1 if done else 0},
)
experience_list.append(experience)
# Close the env to save cpu memory
env.close()
return experience_list

def run(self) -> List[Experience]:
# Converts the task description to an integer session id and hands it, together
# with the repeat count, to generate_env_inference_samples.
# NOTE(review): this span looks like a rendered diff with indentation stripped;
# it appears to contain both the pre-change lines (local gym import, gym.make,
# first return) and the post-change line (second return using self.env).
# Confirm against the real file before relying on it.
# Assume the task description is the generated session id.
session_id = int(self.task_desc)
rollout_n = self.repeat_times
# TODO: Make parallel envs
try:
import gym
from web_agent_site.envs import WebAgentTextEnv # noqa: F401
except Exception as e:
print("Please make sure you have installed the web_agent_site package.")
error_message = f"Error importing WebAgentTextEnv {str(e)}. Please make sure you have installed the web_agent_site package, following the instructions in https://github.com/princeton-nlp/WebShop"
raise ImportError(error_message)
print("Making GYM env")
# NOTE: Hosting the env requires ~15GB of CPU memory.
# If you want an easier env, you can set num_products to 1000 or 100000.
env = gym.make(
"WebAgentTextEnv-v0", observation_mode="text_rich", num_products=None, human_goals=True
)
return self.generate_env_inference_samples(env, session_id, rollout_n)
# NOTE(review): as written this line follows an unconditional return and is
# unreachable; presumably it is the post-change replacement for the line above.
return self.generate_env_inference_samples(self.env, session_id, rollout_n)
15 changes: 15 additions & 0 deletions trinity/common/workflows/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,14 @@ def __init__(
self.model = model
self.auxiliary_models = auxiliary_models

@property
def resettable(self):
    """Tell callers whether this workflow supports ``reset``.

    Returns:
        bool: Always ``False`` in this implementation.
    """
    return False

def reset(self, task: Task):
    """Reset the workflow.

    Args:
        task: The task the workflow should be reset to.

    Raises:
        NotImplementedError: Always, in this implementation.
    """
    raise NotImplementedError

@abstractmethod
def run(self) -> List[Experience]:
"""Run workflow and return a list of experiences."""
Expand Down Expand Up @@ -147,6 +155,13 @@ def __init__(
model=model,
task=task,
)
self.reset(task)

@property
def resettable(self):
    """Tell callers whether this workflow supports ``reset``.

    Returns:
        bool: Always ``True`` for this workflow.
    """
    return True

def reset(self, task: Task):
self.format_args = task.format_args
self.system_prompt = task.format_args.system_prompt
self.reply_prefix = task.format_args.reply_prefix
Expand Down
12 changes: 10 additions & 2 deletions trinity/explorer/workflow_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def __init__(
).get_openai_client()
self.auxiliary_models.append(api_client)
self.logger = get_logger(__name__)
self.workflow_instance = None

def is_alive(self):
    """Report liveness of this runner.

    Returns:
        bool: ``True`` unconditionally.
    """
    return True
Expand All @@ -63,8 +64,15 @@ def _run_task(self, task: Task) -> List[Experience]:
"""Init workflow from the task and run it."""
if task.workflow is None:
raise ValueError("Workflow is not set in the task.")
workflow = task.to_workflow(self.model_wrapper, self.auxiliary_models)
return workflow.run()
if (
self.workflow_instance is None
or not self.workflow_instance.__class__ == task.workflow
or not self.workflow_instance.resettable
):
self.workflow_instance = task.to_workflow(self.model_wrapper, self.auxiliary_models)
else:
self.workflow_instance.reset(task)
return self.workflow_instance.run()

def run_task(self, task: Task) -> Status:
"""Run the task and return the states."""
Expand Down