diff --git a/examples/grpo_sciworld/sciworld.yaml b/examples/grpo_sciworld/sciworld.yaml index 350b37ba80..43f7a13af3 100644 --- a/examples/grpo_sciworld/sciworld.yaml +++ b/examples/grpo_sciworld/sciworld.yaml @@ -19,7 +19,7 @@ buffer: storage_type: file path: 'scripts/data_prepare/sciworld_data' format: - prompt_key: 'game_file' + prompt_key: 'task_desc' rollout_args: repeat_times: 8 temperature: 1.0 diff --git a/trinity/common/workflows/envs/webshop/webshop_workflow.py b/trinity/common/workflows/envs/webshop/webshop_workflow.py index 6b961116d0..0003217e35 100644 --- a/trinity/common/workflows/envs/webshop/webshop_workflow.py +++ b/trinity/common/workflows/envs/webshop/webshop_workflow.py @@ -191,9 +191,31 @@ def __init__( model=model, task=task, ) + self.max_env_steps = 15 + self.reset(task) + + # TODO: Make parallel envs + try: + import gym + from web_agent_site.envs import WebAgentTextEnv # noqa: F401 + except Exception as e: + print("Please make sure you have installed the web_agent_site package.") + error_message = f"Error importing WebAgentTextEnv {str(e)}. Please make sure you have installed the web_agent_site package, following the instructions in https://github.com/princeton-nlp/WebShop" + raise ImportError(error_message) + print("Making GYM env") + # NOTE: Hosting the env require ~15GB CPU memory. + # If you want easier env, you can set the num_products to 1000 or 100000. + self.env = gym.make( + "WebAgentTextEnv-v0", observation_mode="text_rich", num_products=None, human_goals=True + ) + + @property + def resettable(self): + return True + + def reset(self, task: Task): self.task_desc = task.task_desc or "0" self.repeat_times = task.rollout_args.repeat_times - self.max_env_steps = 15 def get_model_response(self, messages): responses = self.model.chat(messages, repeat_times=1) @@ -242,26 +264,10 @@ def generate_env_inference_samples(self, env, session_id, rollout_num) -> List[E {"session_id": session_id, "env_rounds": r, "env_done": 1 if done else 0}, ) experience_list.append(experience) - # Close the env to save cpu memory - env.close() return experience_list def run(self) -> List[Experience]: # assume the task_description is the session_id generated. session_id = int(self.task_desc) rollout_n = self.repeat_times - # TODO: Make parallel envs - try: - import gym - from web_agent_site.envs import WebAgentTextEnv # noqa: F401 - except Exception as e: - print("Please make sure you have installed the web_agent_site package.") - error_message = f"Error importing WebAgentTextEnv {str(e)}. Please make sure you have installed the web_agent_site package, following the instructions in https://github.com/princeton-nlp/WebShop" - raise ImportError(error_message) - print("Making GYM env") - # NOTE: Hosting the env require ~15GB CPU memory. - # If you want easier env, you can set the num_products to 1000 or 100000. - env = gym.make( - "WebAgentTextEnv-v0", observation_mode="text_rich", num_products=None, human_goals=True - ) - return self.generate_env_inference_samples(env, session_id, rollout_n) + return self.generate_env_inference_samples(self.env, session_id, rollout_n) diff --git a/trinity/common/workflows/workflow.py b/trinity/common/workflows/workflow.py index d44d9e9813..603bd1ced4 100644 --- a/trinity/common/workflows/workflow.py +++ b/trinity/common/workflows/workflow.py @@ -77,6 +77,14 @@ def __init__( self.model = model self.auxiliary_models = auxiliary_models + @property + def resettable(self): + return False + + def reset(self, task: Task): + """Reset the workflow.""" + raise NotImplementedError + @abstractmethod def run(self) -> List[Experience]: """Run workflow and return a list of experiences.""" @@ -147,6 +155,13 @@ def __init__( model=model, task=task, ) + self.reset(task) + + @property + def resettable(self): + return True + + def reset(self, task: Task): self.format_args = task.format_args self.system_prompt = task.format_args.system_prompt self.reply_prefix = task.format_args.reply_prefix diff --git a/trinity/explorer/workflow_runner.py b/trinity/explorer/workflow_runner.py index e60821347a..f5a1c2dc6a 100644 --- a/trinity/explorer/workflow_runner.py +++ b/trinity/explorer/workflow_runner.py @@ -55,6 +55,7 @@ def __init__( ).get_openai_client() self.auxiliary_models.append(api_client) self.logger = get_logger(__name__) + self.workflow_instance = None def is_alive(self): return True @@ -63,8 +64,15 @@ def _run_task(self, task: Task) -> List[Experience]: """Init workflow from the task and run it.""" if task.workflow is None: raise ValueError("Workflow is not set in the task.") - workflow = task.to_workflow(self.model_wrapper, self.auxiliary_models) - return workflow.run() + if ( + self.workflow_instance is None + or not self.workflow_instance.__class__ == task.workflow + or not self.workflow_instance.resettable + ): + self.workflow_instance = task.to_workflow(self.model_wrapper, self.auxiliary_models) + else: + self.workflow_instance.reset(task) + return self.workflow_instance.run() def run_task(self, task: Task) -> Status: """Run the task and return the states."""