|
| 1 | +# -*- coding: utf-8 -*- |
| 2 | +import json |
| 3 | +from typing import List |
| 4 | + |
| 5 | +from trinity.common.experience import Experience |
| 6 | +from trinity.common.models.model import ModelWrapper |
| 7 | +from trinity.common.workflows.workflow import WORKFLOWS, MultiTurnWorkflow |
| 8 | + |
# System prompt used as the first ("system") message of every rollout in
# SciWorldWorkflow.generate_env_inference_samples. It tells the model to wrap
# its reasoning in <think> </think> and exactly one command in
# <action> </action>; parse_action() extracts the command from the
# <action> tag of each reply.
# NOTE(review): this text is sent to the model verbatim, so wording changes
# (including typo fixes) alter runtime behavior — edit deliberately.
SCIWORLD_SYSTEM_PROMPT = """
You are an agent, you job is to do some scientific experiment in a virtual test-based environments.

## Notes:
At each step, you should first think then perform action to fulfill the instruction. You should ALWAYS wrap your thinking with the <think> </think> tag and wrap your action with the <action> </action> tag.
You should ALWAYS take one action each step.
DONOT try to interact with the user at anytime. Finish the task by yourself.

## Action Format:
Below are the available commands you can use:
 open OBJ: open a container
 close OBJ: close a container
 activate OBJ: activate a device
 deactivate OBJ: deactivate a device
 connect OBJ to OBJ: connect electrical components
 disconnect OBJ: disconnect electrical components
 use OBJ [on OBJ]: use a device/item
 look around: describe the current room
 examine OBJ: describe an object in detail
 look at OBJ: describe a container's contents
 read OBJ: read a note or book
 move OBJ to OBJ: move an object to a container
 pick up OBJ: move an object to the inventory
 pour OBJ into OBJ: pour a liquid into a container
 mix OBJ: chemically mix a container
 teleport to LOC: teleport to a specific room
 focus on OBJ: signal intent on a task object
 wait: task no action for 10 steps
 wait1: task no action for a step

For example your output should be like this:
<think> Now I will check the bedroom ... </think><action>teleport to bedroom</action>
"""
| 42 | + |
| 43 | + |
def format_observation(observation: str):
    """Return *observation* prefixed with the ``"Observation: \\n"`` header.

    The result is used as the content of the user message fed to the model
    at each environment step.
    """
    header = "Observation: \n"
    return header + observation
| 46 | + |
| 47 | + |
def parse_action(response):
    """Extract the command enclosed in the first ``<action>`` tag of *response*.

    Returns the stripped text between the first ``<action>`` and the next
    ``</action>``. If the closing tag is missing, everything after the
    opening tag is returned. On any failure (for example no opening tag, or
    a response that is not a string) the error is printed and an empty
    string is returned.
    """
    open_tag, close_tag = "<action>", "</action>"
    try:
        # Text following the first opening tag; IndexError if there is none.
        after_open = response.split(open_tag)[1]
        inner = after_open.split(close_tag)[0]
        return inner.strip()
    except Exception as err:
        print("Error parsing action:", err)
        return ""
| 56 | + |
| 57 | + |
@WORKFLOWS.register_module("sciworld_workflow")
class SciWorldWorkflow(MultiTurnWorkflow):
    """A workflow for sciworld task.

    Registered under the name ``"sciworld_workflow"``. ``run`` parses
    ``task_desc`` as JSON, creates a ScienceWorld environment from it and
    performs ``repeat_times`` rollouts, returning one ``Experience`` per
    rollout.
    """

    def __init__(self, model: ModelWrapper, **kwargs):
        """Store the task configuration.

        Args:
            model: Model wrapper used to generate the agent's replies.
            **kwargs: Recognized keys are ``system_prompt``, ``task_desc``,
                ``truth``, ``reward_fn`` and ``repeat_times`` (default 1).
                ``task_desc`` is expected to be a JSON string (see ``run``).
        """
        super().__init__(model)
        self.system_prompt = kwargs.get("system_prompt", None)  # Unused here
        self.task_desc: str = kwargs.get("task_desc")
        self.truth = kwargs.get("truth")  # Unused here
        self.reward_fn = kwargs.get("reward_fn")  # Unused here
        self.repeat_times = kwargs.get("repeat_times", 1)
        self.max_env_steps = 30  # should be less than 100

    def get_model_response(self, messages):
        """Query the model once with *messages* and return its responses."""
        responses = self.model.chat(messages, repeat_times=1)
        return responses

    def get_model_response_text(self, messages):
        """Return the ``response_text`` of the first model response."""
        return self.get_model_response(messages)[0].response_text

    def generate_env_inference_samples(self, env, rollout_num) -> List[Experience]:
        """Run *rollout_num* episodes in *env* and convert each to an Experience.

        Each episode resets the environment, then alternates user
        observations and model replies for at most ``self.max_env_steps``
        steps or until the environment reports ``done``. The episode reward
        is the highest cumulative step reward reached (never below 0),
        divided by 100.

        The environment is always closed before returning, including when a
        rollout raises.

        Args:
            env: A loaded environment exposing ``reset``, ``step``,
                ``get_task_description``, ``get_gold_action_sequence`` and
                ``close``.
            rollout_num: Number of episodes to run.

        Returns:
            One Experience per episode, built by
            ``self.process_messages_to_experience`` (not defined in this
            class; presumably inherited — confirm against the base class).
        """
        # TODO: Make this parallel
        print("Generating env inference samples...")
        experience_list = []
        try:
            golden_rounds = len(env.get_gold_action_sequence())
            for _ in range(rollout_num):
                observation, info = env.reset()
                observation = (
                    "Task Description: " + str(env.get_task_description()) + "\n" + observation
                )
                final_reward = 0.0
                current_reward = 0.0
                memory = [{"role": "system", "content": SCIWORLD_SYSTEM_PROMPT}]
                # Pre-bind so the metadata below is defined even if the
                # step loop body never executes.
                r = 0
                done = False
                for r in range(self.max_env_steps):
                    format_obs = format_observation(observation)
                    memory.append({"role": "user", "content": format_obs})
                    response_text = self.get_model_response_text(memory)
                    memory.append({"role": "assistant", "content": response_text})
                    action = parse_action(response_text)
                    observation, reward, done, info = env.step(action)
                    current_reward += reward
                    # Keep the best cumulative score seen during the episode.
                    final_reward = max(current_reward, final_reward)
                    if done:
                        break
                final_reward = final_reward / 100.0
                experience = self.process_messages_to_experience(
                    memory,
                    final_reward,
                    {"env_rounds": r, "env_done": 1 if done else 0, "golden_rounds": golden_rounds},
                )
                experience_list.append(experience)
        finally:
            # Close the env to save cpu memory, even if a rollout failed.
            env.close()
        return experience_list

    def run(self) -> List[Experience]:
        """Create the ScienceWorld environment for this task and roll it out.

        ``self.task_desc`` must be a JSON object with the keys ``var_num``,
        ``task_name`` and ``jar_path``.

        Returns:
            The experiences produced by ``generate_env_inference_samples``
            for ``self.repeat_times`` rollouts.

        Raises:
            ImportError: If the ``scienceworld`` package cannot be imported.
        """
        # assume the task_description is the json object containing task index and the var_num
        # see Trinity-RFT/script/data_prepare/get_scriworld_data.py
        task_config = json.loads(self.task_desc)

        rollout_n = self.repeat_times
        # TODO: Make parallel envs
        try:
            from scienceworld import ScienceWorldEnv
        except Exception as e:
            print("Please make sure you have installed the sciworld package.")
            error_message = f"Error importing SciWorldTWEnv {str(e)}. Please make sure you have installed the sciworld package successfully, following the instructions in https://github.com/allenai/ScienceWorld"
            # Chain the original failure so its traceback is preserved.
            raise ImportError(error_message) from e

        var_num = task_config["var_num"]
        task_name = task_config["task_name"]
        jar_path = task_config["jar_path"]
        simplification_str = "easy"
        env = ScienceWorldEnv("", jar_path, envStepLimit=100)
        try:
            env.load(task_name, var_num, simplification_str, generateGoldPath=True)
        except Exception:
            # Do not leak the environment when the task fails to load.
            env.close()
            raise
        return self.generate_env_inference_samples(env, rollout_n)
0 commit comments