Skip to content

Commit 65f23ab

Browse files
authored
add missing workflow (#11)
1 parent ff919db commit 65f23ab

File tree

1 file changed

+139
-0
lines changed

1 file changed

+139
-0
lines changed
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
# -*- coding: utf-8 -*-
2+
import json
3+
from typing import List
4+
5+
from trinity.common.experience import Experience
6+
from trinity.common.models.model import ModelWrapper
7+
from trinity.common.workflows.workflow import WORKFLOWS, MultiTurnWorkflow
8+
9+
SCIWORLD_SYSTEM_PROMPT = """
10+
You are an agent, you job is to do some scientific experiment in a virtual test-based environments.
11+
12+
## Notes:
13+
At each step, you should first think then perform action to fulfill the instruction. You should ALWAYS wrap your thinking with the <think> </think> tag and wrap your action with the <action> </action> tag.
14+
You should ALWAYS take one action each step.
15+
DONOT try to interact with the user at anytime. Finish the task by yourself.
16+
17+
## Action Format:
18+
Below are the available commands you can use:
19+
open OBJ: open a container
20+
close OBJ: close a container
21+
activate OBJ: activate a device
22+
deactivate OBJ: deactivate a device
23+
connect OBJ to OBJ: connect electrical components
24+
disconnect OBJ: disconnect electrical components
25+
use OBJ [on OBJ]: use a device/item
26+
look around: describe the current room
27+
examine OBJ: describe an object in detail
28+
look at OBJ: describe a container's contents
29+
read OBJ: read a note or book
30+
move OBJ to OBJ: move an object to a container
31+
pick up OBJ: move an object to the inventory
32+
pour OBJ into OBJ: pour a liquid into a container
33+
mix OBJ: chemically mix a container
34+
teleport to LOC: teleport to a specific room
35+
focus on OBJ: signal intent on a task object
36+
wait: task no action for 10 steps
37+
wait1: task no action for a step
38+
39+
For example your output should be like this:
40+
<think> Now I will check the bedroom ... </think><action>teleport to bedroom</action>
41+
"""
42+
43+
44+
def format_observation(observation: str):
45+
return "Observation: \n" + observation
46+
47+
48+
def parse_action(response):
49+
try:
50+
# parse the action within the <action> </action> tag
51+
action = response.split("<action>")[1].split("</action>")[0].strip()
52+
return action
53+
except Exception as e:
54+
print("Error parsing action:", e)
55+
return ""
56+
57+
58+
@WORKFLOWS.register_module("sciworld_workflow")
59+
class SciWorldWorkflow(MultiTurnWorkflow):
60+
"""A workflow for sciworld task."""
61+
62+
def __init__(self, model: ModelWrapper, **kwargs):
63+
super().__init__(model)
64+
self.system_prompt = kwargs.get("system_prompt", None) # Unuse here
65+
self.task_desc: str = kwargs.get("task_desc")
66+
self.truth = kwargs.get("truth") # Unuse here
67+
self.reward_fn = kwargs.get("reward_fn") # Unuse here
68+
self.repeat_times = kwargs.get("repeat_times", 1)
69+
self.max_env_steps = 30 # should be less than 100
70+
71+
def get_model_response(self, messages):
72+
responses = self.model.chat(messages, repeat_times=1)
73+
return responses
74+
75+
def get_model_response_text(self, messages):
76+
return self.get_model_response(messages)[0].response_text
77+
78+
def generate_env_inference_samples(self, env, rollout_num) -> List[Experience]:
79+
# TODO: Make this parallel
80+
print("Generating env inference samples...")
81+
golden_rounds = len(env.get_gold_action_sequence())
82+
experience_list = []
83+
for i in range(rollout_num):
84+
observation, info = env.reset()
85+
observation = (
86+
"Task Description: " + str(env.get_task_description()) + "\n" + observation
87+
)
88+
final_reward = 0.0
89+
current_reward = 0.0
90+
memory = []
91+
memory.append({"role": "system", "content": SCIWORLD_SYSTEM_PROMPT})
92+
for r in range(self.max_env_steps):
93+
format_obs = format_observation(observation)
94+
memory = memory + [{"role": "user", "content": format_obs}]
95+
response_text = self.get_model_response_text(memory)
96+
memory.append({"role": "assistant", "content": response_text})
97+
action = parse_action(response_text)
98+
observation, reward, done, info = env.step(action)
99+
current_reward += reward
100+
final_reward = max(current_reward, final_reward)
101+
if done:
102+
break
103+
final_reward = final_reward / 100.0
104+
experience = self.process_messages_to_experience(
105+
memory,
106+
final_reward,
107+
{"env_rounds": r, "env_done": 1 if done else 0, "golden_rounds": golden_rounds},
108+
)
109+
experience_list.append(experience)
110+
# Close the env to save cpu memory
111+
env.close()
112+
return experience_list
113+
114+
def run(self) -> List[Experience]:
115+
# assume the task_description is the json object containing task index and the var_num
116+
# see Trinity-RFT/script/data_prepare/get_scriworld_data.py
117+
task_desc = self.task_desc
118+
task_config = json.loads(task_desc)
119+
120+
rollout_n = self.repeat_times
121+
# TODO: Make parallel envs
122+
try:
123+
from scienceworld import ScienceWorldEnv
124+
125+
def create_environment(task_config):
126+
var_num = task_config["var_num"]
127+
task_name = task_config["task_name"]
128+
jar_path = task_config["jar_path"]
129+
simplificationStr = "easy"
130+
env = ScienceWorldEnv("", jar_path, envStepLimit=100)
131+
env.load(task_name, var_num, simplificationStr, generateGoldPath=True)
132+
return env
133+
134+
except Exception as e:
135+
print("Please make sure you have installed the sciworld package.")
136+
error_message = f"Error importing SciWorldTWEnv {str(e)}. Please make sure you have installed the sciworld package successfully, following the instructions in https://github.com/allenai/ScienceWorld"
137+
raise ImportError(error_message)
138+
env = create_environment(task_config)
139+
return self.generate_env_inference_samples(env, rollout_n)

0 commit comments

Comments
 (0)