139 changes: 139 additions & 0 deletions trinity/common/workflows/envs/sciworld/sciworld_workflow.py
@@ -0,0 +1,139 @@
# -*- coding: utf-8 -*-
import json
from typing import List

from trinity.common.experience import Experience
from trinity.common.models.model import ModelWrapper
from trinity.common.workflows.workflow import WORKFLOWS, MultiTurnWorkflow

SCIWORLD_SYSTEM_PROMPT = """
You are an agent. Your job is to carry out scientific experiments in a virtual text-based environment.

## Notes:
At each step, you should first think and then perform an action to fulfill the instruction. You should ALWAYS wrap your thinking with the <think> </think> tag and wrap your action with the <action> </action> tag.
You should ALWAYS take exactly one action each step.
DO NOT try to interact with the user at any time. Finish the task by yourself.

## Action Format:
Below are the available commands you can use:
open OBJ: open a container
close OBJ: close a container
activate OBJ: activate a device
deactivate OBJ: deactivate a device
connect OBJ to OBJ: connect electrical components
disconnect OBJ: disconnect electrical components
use OBJ [on OBJ]: use a device/item
look around: describe the current room
examine OBJ: describe an object in detail
look at OBJ: describe a container's contents
read OBJ: read a note or book
move OBJ to OBJ: move an object to a container
pick up OBJ: move an object to the inventory
pour OBJ into OBJ: pour a liquid into a container
mix OBJ: chemically mix a container
teleport to LOC: teleport to a specific room
focus on OBJ: signal intent on a task object
wait: take no action for 10 steps
wait1: take no action for a single step

For example, your output should look like this:
<think> Now I will check the bedroom ... </think><action>teleport to bedroom</action>
"""


def format_observation(observation: str):
return "Observation: \n" + observation


def parse_action(response):
try:
# parse the action within the <action> </action> tag
action = response.split("<action>")[1].split("</action>")[0].strip()
return action
except Exception as e:
print("Error parsing action:", e)
return ""


@WORKFLOWS.register_module("sciworld_workflow")
class SciWorldWorkflow(MultiTurnWorkflow):
"""A workflow for sciworld task."""

def __init__(self, model: ModelWrapper, **kwargs):
super().__init__(model)
        self.system_prompt = kwargs.get("system_prompt", None)  # Unused here
        self.task_desc: str = kwargs.get("task_desc")
        self.truth = kwargs.get("truth")  # Unused here
        self.reward_fn = kwargs.get("reward_fn")  # Unused here
self.repeat_times = kwargs.get("repeat_times", 1)
        self.max_env_steps = 30  # should stay below the env step limit (envStepLimit=100 in run())

def get_model_response(self, messages):
responses = self.model.chat(messages, repeat_times=1)
return responses

def get_model_response_text(self, messages):
return self.get_model_response(messages)[0].response_text

def generate_env_inference_samples(self, env, rollout_num) -> List[Experience]:
# TODO: Make this parallel
print("Generating env inference samples...")
golden_rounds = len(env.get_gold_action_sequence())
experience_list = []
for i in range(rollout_num):
observation, info = env.reset()
observation = (
"Task Description: " + str(env.get_task_description()) + "\n" + observation
)
final_reward = 0.0
current_reward = 0.0
memory = []
memory.append({"role": "system", "content": SCIWORLD_SYSTEM_PROMPT})
for r in range(self.max_env_steps):
format_obs = format_observation(observation)
memory = memory + [{"role": "user", "content": format_obs}]
response_text = self.get_model_response_text(memory)
memory.append({"role": "assistant", "content": response_text})
action = parse_action(response_text)
observation, reward, done, info = env.step(action)
current_reward += reward
final_reward = max(current_reward, final_reward)
if done:
break
            final_reward = final_reward / 100.0  # ScienceWorld scores are on a 0-100 scale; normalize to [0, 1]
experience = self.process_messages_to_experience(
memory,
final_reward,
{"env_rounds": r, "env_done": 1 if done else 0, "golden_rounds": golden_rounds},
)
experience_list.append(experience)
        # Close the env to free CPU memory
env.close()
return experience_list

def run(self) -> List[Experience]:
        # assume task_desc is a JSON string containing the task name, variation number, and jar path
        # see Trinity-RFT/script/data_prepare/get_scriworld_data.py
task_desc = self.task_desc
task_config = json.loads(task_desc)

rollout_n = self.repeat_times
# TODO: Make parallel envs
try:
from scienceworld import ScienceWorldEnv

def create_environment(task_config):
var_num = task_config["var_num"]
task_name = task_config["task_name"]
jar_path = task_config["jar_path"]
simplificationStr = "easy"
env = ScienceWorldEnv("", jar_path, envStepLimit=100)
env.load(task_name, var_num, simplificationStr, generateGoldPath=True)
return env

        except Exception as e:
            print("Please make sure you have installed the scienceworld package.")
            error_message = (
                f"Error importing ScienceWorldEnv: {str(e)}. "
                "Please make sure you have installed the scienceworld package successfully, "
                "following the instructions in https://github.com/allenai/ScienceWorld"
            )
            raise ImportError(error_message) from e
env = create_environment(task_config)
return self.generate_env_inference_samples(env, rollout_n)
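

# A sketch (assumption) of the task_desc JSON that run() expects, based on the keys read in
# create_environment above; the concrete values here are hypothetical:
#   {"task_name": "boil", "var_num": 0, "jar_path": "/path/to/ScienceWorld.jar"}
# run() parses this JSON, builds the ScienceWorldEnv, and rolls out `repeat_times` episodes,
# returning one Experience per episode.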