diff --git a/trinity/common/workflows/envs/sciworld/sciworld_workflow.py b/trinity/common/workflows/envs/sciworld/sciworld_workflow.py
new file mode 100644
index 0000000000..4892abb97f
--- /dev/null
+++ b/trinity/common/workflows/envs/sciworld/sciworld_workflow.py
@@ -0,0 +1,139 @@
+# -*- coding: utf-8 -*-
+import json
+from typing import List
+
+from trinity.common.experience import Experience
+from trinity.common.models.model import ModelWrapper
+from trinity.common.workflows.workflow import WORKFLOWS, MultiTurnWorkflow
+
+SCIWORLD_SYSTEM_PROMPT = """
+You are an agent. Your job is to carry out scientific experiments in a virtual text-based environment.
+
+## Notes:
+At each step, you should first think and then perform an action to fulfill the instruction. You should ALWAYS wrap your thinking with the <think> </think> tag and wrap your action with the <action> </action> tag.
+You should ALWAYS take exactly one action each step.
+DO NOT try to interact with the user at any time. Finish the task by yourself.
+
+## Action Format:
+Below are the available commands you can use:
+    open OBJ: open a container
+    close OBJ: close a container
+    activate OBJ: activate a device
+    deactivate OBJ: deactivate a device
+    connect OBJ to OBJ: connect electrical components
+    disconnect OBJ: disconnect electrical components
+    use OBJ [on OBJ]: use a device/item
+    look around: describe the current room
+    examine OBJ: describe an object in detail
+    look at OBJ: describe a container's contents
+    read OBJ: read a note or book
+    move OBJ to OBJ: move an object to a container
+    pick up OBJ: move an object to the inventory
+    pour OBJ into OBJ: pour a liquid into a container
+    mix OBJ: chemically mix a container
+    teleport to LOC: teleport to a specific room
+    focus on OBJ: signal intent on a task object
+    wait: take no action for 10 steps
+    wait1: take no action for a step
+
+For example, your output should look like this:
+<think> Now I will check the bedroom ... </think>
+<action> teleport to bedroom </action>
+"""
+
+
+def format_observation(observation: str):
+    return "Observation: \n" + observation
+
+
+def parse_action(response):
+    try:
+        # parse the action wrapped in the <action> </action> tag
+        action = response.split("<action>")[1].split("</action>")[0].strip()
+        return action
+    except Exception as e:
+        print("Error parsing action:", e)
+        return ""
+
+
+@WORKFLOWS.register_module("sciworld_workflow")
+class SciWorldWorkflow(MultiTurnWorkflow):
+    """A workflow for the SciWorld task."""
+
+    def __init__(self, model: ModelWrapper, **kwargs):
+        super().__init__(model)
+        self.system_prompt = kwargs.get("system_prompt", None)  # Unused here
+        self.task_desc: str = kwargs.get("task_desc")
+        self.truth = kwargs.get("truth")  # Unused here
+        self.reward_fn = kwargs.get("reward_fn")  # Unused here
+        self.repeat_times = kwargs.get("repeat_times", 1)
+        self.max_env_steps = 30  # should be less than 100
+
+    def get_model_response(self, messages):
+        responses = self.model.chat(messages, repeat_times=1)
+        return responses
+
+    def get_model_response_text(self, messages):
+        return self.get_model_response(messages)[0].response_text
+
+    def generate_env_inference_samples(self, env, rollout_num) -> List[Experience]:
+        # TODO: Make this parallel
+        print("Generating env inference samples...")
+        golden_rounds = len(env.get_gold_action_sequence())
+        experience_list = []
+        for i in range(rollout_num):
+            observation, info = env.reset()
+            observation = (
+                "Task Description: " + str(env.get_task_description()) + "\n" + observation
+            )
+            final_reward = 0.0
+            current_reward = 0.0
+            memory = []
+            memory.append({"role": "system", "content": SCIWORLD_SYSTEM_PROMPT})
+            for r in range(self.max_env_steps):
+                format_obs = format_observation(observation)
+                memory = memory + [{"role": "user", "content": format_obs}]
+                response_text = self.get_model_response_text(memory)
+                memory.append({"role": "assistant", "content": response_text})
+                action = parse_action(response_text)
+                observation, reward, done, info = env.step(action)
+                current_reward += reward
+                final_reward = max(current_reward, final_reward)
+                if done:
+                    break
+            final_reward = final_reward / 100.0
+            experience = self.process_messages_to_experience(
+                memory,
+                final_reward,
+                {"env_rounds": r, "env_done": 1 if done else 0, "golden_rounds": golden_rounds},
+            )
+            experience_list.append(experience)
+        # Close the env to save cpu memory
+        env.close()
+        return experience_list
+
+    def run(self) -> List[Experience]:
+        # assume task_desc is a JSON string containing the task name, variation number and jar path
+        # see Trinity-RFT/script/data_prepare/get_scriworld_data.py
+        task_desc = self.task_desc
+        task_config = json.loads(task_desc)
+
+        rollout_n = self.repeat_times
+        # TODO: Make parallel envs
+        try:
+            from scienceworld import ScienceWorldEnv
+
+            def create_environment(task_config):
+                var_num = task_config["var_num"]
+                task_name = task_config["task_name"]
+                jar_path = task_config["jar_path"]
+                simplificationStr = "easy"
+                env = ScienceWorldEnv("", jar_path, envStepLimit=100)
+                env.load(task_name, var_num, simplificationStr, generateGoldPath=True)
+                return env
+
+        except Exception as e:
+            print("Please make sure you have installed the scienceworld package.")
+            error_message = f"Error importing ScienceWorldEnv: {str(e)}. Please make sure you have installed the scienceworld package successfully, following the instructions in https://github.com/allenai/ScienceWorld"
+            raise ImportError(error_message)
+        env = create_environment(task_config)
+        return self.generate_env_inference_samples(env, rollout_n)
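For reference, `run()` expects `task_desc` to be a JSON string whose fields match the keys read in `create_environment`: `task_name`, `var_num`, and `jar_path`. Below is a minimal sketch of building such a record; the task name, variation index, and jar path are illustrative placeholders, not values taken from this repository or its data-preparation script.

import json

# Hypothetical example values: substitute a real ScienceWorld task name,
# variation index, and the path to your ScienceWorld jar.
task_desc = json.dumps(
    {
        "task_name": "boil",                      # illustrative task name
        "var_num": 0,                             # illustrative variation index
        "jar_path": "/path/to/scienceworld.jar",  # illustrative jar path
    }
)
# SciWorldWorkflow.run() json.loads() this string into task_config and
# passes the fields to ScienceWorldEnv / env.load().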