diff --git a/docs/sphinx_doc/source/tutorial/example_multi_turn.md b/docs/sphinx_doc/source/tutorial/example_multi_turn.md
index b3686dc3ee..43706f2fba 100644
--- a/docs/sphinx_doc/source/tutorial/example_multi_turn.md
+++ b/docs/sphinx_doc/source/tutorial/example_multi_turn.md
@@ -21,7 +21,7 @@ You may refer to their original environment to complete the setup.
### Data Preparation
Our dataset follows the format in Huggingface datasets library, so we should correspondingly convert our env dataset.
-Just run the following command.
+Just check the data preparation scripts and run the following command.
```bash
# For ALFworld env
python scripts/data_prepare/get_alfworld_data.py
@@ -53,18 +53,16 @@ We provide an easy way to allow you build your own environment pipeline by creat
See the `trinity/common/workflows/envs/alfworld/alfworld_workflow.py` as an example on how to construct a multi-round workflow.
-You can interact with environment using the messages format, and call the `self.process_batch_messages` function to transform the messages and rewards into the `experience` we need, and send them to buffer.
+You can interact with the environment using the messages format, and call the `self.process_messages_to_experience` function to transform the messages and reward into the `Experience` objects we need, which are then sent to the buffer.
```python
-class AlfworldWorkflow(Workflow):
+class AlfworldWorkflow(MultiTurnWorkflow):
"""A workflow for alfworld task."""
...
def generate_env_inference_samples(self, env, rollout_num) -> List[Experience]:
print("Generating env inference samples...")
- all_messages = []
- all_rewards = []
- all_infos = []
+ experience_list = []
for i in range(rollout_num):
observation, info = env.reset()
final_reward = -0.1
@@ -80,14 +78,13 @@ class AlfworldWorkflow(Workflow):
if done:
final_reward = reward
break
- all_infos.append(
- {"env_rounds": r, "env_done": 1 if done else 0}
+ experience = self.process_messages_to_experience(
+ memory, final_reward, {"env_rounds": r, "env_done": 1 if done else 0}
)
- all_messages.append(memory)
- all_rewards.append(final_reward)
+ experience_list.append(experience)
# Close the env to save cpu memory
env.close()
- return self.process_batch_messages(all_messages, all_rewards, all_infos=all_infos)
+ return experience_list
def run(self) -> List[Experience]:
@@ -102,7 +99,7 @@ class AlfworldWorkflow(Workflow):
Also, remember to register your workflow:
```python
@WORKFLOWS.register_module("alfworld_workflow")
-class AlfworldWorkflow(Workflow):
+class AlfworldWorkflow(MultiTurnWorkflow):
"""A workflow for alfworld task."""
...
```
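+
+After registration, select the workflow in your config through the `default_workflow_type` field. Below is a minimal sketch whose key names mirror `scripts/config/sciworld.yaml`; the dataset path is illustrative and should point to the directory produced by your data preparation script:
+```yaml
+data:
+  dataset_path: 'scripts/data_prepare/alfworld_data'  # illustrative path
+  default_workflow_type: 'alfworld_workflow'
+```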
diff --git a/scripts/config/sciworld.yaml b/scripts/config/sciworld.yaml
new file mode 100644
index 0000000000..62a9344cbd
--- /dev/null
+++ b/scripts/config/sciworld.yaml
@@ -0,0 +1,56 @@
+data:
+ total_epoch: 20
+ batch_size: 4
+ dataset_path: 'scripts/data_prepare/sciworld_data'
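+  # generated by scripts/data_prepare/get_sciworld_data.py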
+ default_workflow_type: 'sciworld_workflow'
+ train_split: 'train'
+ eval_split: ''
+ format_config:
+    prompt_key: 'task_desc'
+model:
+ model_path: '/PATH/TO/MODEL/CHECKPOINT/'
+ max_prompt_tokens: 4096
+ max_response_tokens: 16384
+ checkpoint_path: 'checkpoints/sciworld_RFT'
+cluster:
+ node_num: 1
+ gpu_per_node: 8
+buffer:
+ max_retry_times: 3
+ max_retry_interval: 1
+ train_dataset:
+ name: sciworld_buffer
+ storage_type: queue
+ algorithm_type: ppo
+ path: 'sqlite:///sciworld.db'
+explorer:
+ engine_type: vllm_async
+ engine_num: 2
+ runner_num: 32
+ tensor_parallel_size: 2
+ enable_prefix_caching: false
+ enforce_eager: true
+ dtype: bfloat16
+ temperature: 1.0
+ top_p: 1.0
+ top_k: -1
+ seed: 42
+ logprobs: 0
+ repeat_times: 8
+ use_ray: false
+ backend: 'nccl'
+ max_pending_requests: 32
+ max_waiting_steps: 4
+ gpu_memory_utilization: 0.7
+ enable_chunked_prefil: true
+synchronizer:
+ sync_method: 'online'
+ sync_iteration_interval: 8
+trainer:
+ trainer_type: 'verl'
+ algorithm_type: ppo
+ trainer_config_path: 'scripts/config/train_sciworld.yaml'
+monitor:
+ cache_root_dir: ""
+ project: "sciworld"
+ name: "sciworld_RFT"
diff --git a/scripts/config/train_sciworld.yaml b/scripts/config/train_sciworld.yaml
new file mode 100644
index 0000000000..a818dcb0c6
--- /dev/null
+++ b/scripts/config/train_sciworld.yaml
@@ -0,0 +1,183 @@
+data:
+ tokenizer: null
+ train_files: train_example.parquet
+ val_files: test_example.parquet
+ prompt_key: prompt
+ max_prompt_length: 4096
+ max_response_length: 16384
+ train_batch_size: 96
+ val_batch_size: null
+  return_raw_input_ids: False # This should be set to true when the policy and reward model use different tokenizers
+ return_raw_chat: False
+ shuffle: True
+  filter_overlong_prompts: False # for large-scale datasets, filtering overlong prompts could be time-consuming. You should disable this and set `truncation='left'`
+ truncation: error
+ image_key: images
+
+actor_rollout_ref:
+ hybrid_engine: True
+ model:
+ path: /PATH/TO/MODEL/CHECKPOINT/
+ external_lib: null
+ override_config: { }
+ enable_gradient_checkpointing: True
+ use_remove_padding: False
+ actor:
+ strategy: fsdp # This is for backward-compatibility
+ ppo_mini_batch_size: 1536
+ # ppo_micro_batch_size: 8 # will be deprecated, use ppo_micro_batch_size_per_gpu
+ ppo_micro_batch_size_per_gpu: 1
+ use_dynamic_bsz: False
+ ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length}
+ grad_clip: 1.0
+ clip_ratio: 0.2
+ entropy_coeff: 0.001
+ use_kl_loss: True # True for GRPO
+ kl_loss_coef: 0.001 # for grpo
+ kl_loss_type: low_var_kl # for grpo
+ ppo_epochs: 1
+ shuffle: False
+ ulysses_sequence_parallel_size: 1 # sp size
+ optim:
+ lr: 1e-6
+ lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
+ # min_lr_ratio: null # only useful for warmup with cosine
+ warmup_style: constant # select from constant/cosine
+      total_training_steps: -1 # must be overridden by the program
+ fsdp_config:
+ wrap_policy:
+ # transformer_layer_cls_to_wrap: None
+ min_num_params: 0
+ param_offload: False
+ optimizer_offload: False
+ fsdp_size: -1
+ ref:
+ fsdp_config:
+ param_offload: False
+ wrap_policy:
+ # transformer_layer_cls_to_wrap: None
+ min_num_params: 0
+ # log_prob_micro_batch_size: 4 # will be deprecated, use log_prob_micro_batch_size_per_gpu
+ log_prob_micro_batch_size_per_gpu: 1
+ log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
+ log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
+ ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size
+ rollout:
+ name: vllm
+ temperature: 1.0
+ top_k: -1 # 0 for hf rollout, -1 for vllm rollout
+ top_p: 1
+ use_fire_sampling: False # https://arxiv.org/abs/2410.21236
+    prompt_length: ${data.max_prompt_length} # not used for open-source rollout
+ response_length: ${data.max_response_length}
+ # for vllm rollout
+ dtype: bfloat16 # should align with FSDP
+ gpu_memory_utilization: 0.4
+ ignore_eos: False
+ enforce_eager: True
+ free_cache_engine: True
+ load_format: dummy_dtensor
+ tensor_model_parallel_size: 1
+ max_num_batched_tokens: 8192
+ max_model_len: null
+ max_num_seqs: 1024
+ # log_prob_micro_batch_size: 8 # will be deprecated, use log_prob_micro_batch_size_per_gpu
+ log_prob_micro_batch_size_per_gpu: 1
+ log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
+ log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
+ disable_log_stats: True
+ enable_chunked_prefill: True # could get higher throughput
+ # for hf rollout
+ do_sample: True
+ # number of responses (i.e. num sample times)
+    n: 8 # should be > 1 for grpo; currently an unused parameter
+
+critic:
+ strategy: fsdp
+ optim:
+ lr: 1e-5
+ lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
+ # min_lr_ratio: null # only useful for warmup with cosine
+ warmup_style: constant # select from constant/cosine
+    total_training_steps: -1 # must be overridden by the program
+ model:
+ path: /PATH/TO/MODEL/CHECKPOINT/
+ tokenizer_path: ${actor_rollout_ref.model.path}
+ override_config: { }
+ external_lib: ${actor_rollout_ref.model.external_lib}
+ enable_gradient_checkpointing: True
+ use_remove_padding: False
+ fsdp_config:
+ param_offload: False
+ optimizer_offload: False
+ wrap_policy:
+ # transformer_layer_cls_to_wrap: None
+ min_num_params: 0
+ fsdp_size: -1
+ ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}
+ # ppo_micro_batch_size: 8 # will be deprecated, use ppo_micro_batch_size_per_gpu
+ ppo_micro_batch_size_per_gpu: 1
+ forward_micro_batch_size: ${critic.ppo_micro_batch_size}
+ forward_micro_batch_size_per_gpu: ${critic.ppo_micro_batch_size_per_gpu}
+ use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
+ ppo_max_token_len_per_gpu: 16384 # (${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}) * 2
+ forward_max_token_len_per_gpu: ${critic.ppo_max_token_len_per_gpu}
+ ulysses_sequence_parallel_size: 1 # sp size
+ ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs}
+ shuffle: ${actor_rollout_ref.actor.shuffle}
+ grad_clip: 1.0
+ cliprange_value: 0.5
+
+reward_model:
+ enable: False
+ strategy: fsdp
+ model:
+ input_tokenizer: ${actor_rollout_ref.model.path} # set this to null if the chat template is identical
+ path: ~/models/FsfairX-LLaMA3-RM-v0.1
+ external_lib: ${actor_rollout_ref.model.external_lib}
+ use_remove_padding: False
+ fsdp_config:
+ min_num_params: 0
+ param_offload: False
+ fsdp_size: -1
+ # micro_batch_size: null # will be deprecated, use micro_batch_size_per_gpu
+ # micro_batch_size_per_gpu: 2 # set a number
+ # max_length: null
+ ulysses_sequence_parallel_size: 1 # sp size
+ use_dynamic_bsz: ${critic.use_dynamic_bsz}
+ forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu}
+
+custom_reward_function:
+ path: null
+ name: compute_score
+
+algorithm:
+ gamma: 1.0
+ lam: 1.0
+ adv_estimator: grpo
+ kl_penalty: kl # how to estimate kl divergence
+ kl_ctrl:
+ type: fixed
+ kl_coef: 0.001
+
+trainer:
+ balance_batch: True
+ total_epochs: 15
+ # total_training_steps: null
+ project_name: sciworld
+ experiment_name: sciworld_RFT
+ logger: [ 'wandb' ]
+ val_generations_to_log_to_wandb: 0
+ nnodes: 1
+ n_gpus_per_node: 2
+ save_freq: 1
+  # auto: find the last checkpoint and resume from it; if none is found, start from scratch
+  resume_mode: auto # or resume_path if resume_from_path is set
+ resume_from_path: False
+ test_freq: 100
+ critic_warmup: 0
+ default_hdfs_dir: null
+ remove_previous_ckpt_in_save: False
+ del_local_ckpt_after_load: False
+ default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}
+ val_before_train: False
diff --git a/scripts/data_prepare/get_alfworld_data.py b/scripts/data_prepare/get_alfworld_data.py
index 423bc46b2c..b55a04435a 100644
--- a/scripts/data_prepare/get_alfworld_data.py
+++ b/scripts/data_prepare/get_alfworld_data.py
@@ -39,7 +39,6 @@ def create_dataset_files(output_dir, train_size=1024, test_size=100):
# create dataset_dict
dataset_dict = {"train": train_data, "test": test_data}
- # 保存为jsonl格式
for split, data in dataset_dict.items():
output_file = os.path.join(output_dir, f"{split}.jsonl")
with open(output_file, "w") as f:
diff --git a/scripts/data_prepare/get_sciworld_data.py b/scripts/data_prepare/get_sciworld_data.py
new file mode 100644
index 0000000000..e94624c155
--- /dev/null
+++ b/scripts/data_prepare/get_sciworld_data.py
@@ -0,0 +1,132 @@
+"""
+We use this script to create the Hugging Face format dataset files for the ScienceWorld (sciworld) dataset.
+NOTE: You need to install the ScienceWorld environment first: https://github.com/allenai/ScienceWorld
+"""
+import json
+import os
+import random
+
+random.seed(42)
+
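+# Number of variations for each ScienceWorld task type; used below to enumerate
+# (task_name, var_num) pairs for the train/test splits.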
+task_variations = {
+ "boil": 30,
+ "melt": 30,
+ "freeze": 30,
+ "change-the-state-of-matter-of": 30,
+ "use-thermometer": 540,
+ "measure-melting-point-known-substance": 436,
+ "measure-melting-point-unknown-substance": 300,
+ "power-component": 20,
+ "power-component-renewable-vs-nonrenewable-energy": 20,
+ "test-conductivity": 900,
+ "test-conductivity-of-unknown-substances": 600,
+ "find-living-thing": 300,
+ "find-non-living-thing": 300,
+ "find-plant": 300,
+ "find-animal": 300,
+ "grow-plant": 126,
+ "grow-fruit": 126,
+ "chemistry-mix": 32,
+ "chemistry-mix-paint-secondary-color": 36,
+ "chemistry-mix-paint-tertiary-color": 36,
+ "lifespan-longest-lived": 125,
+ "lifespan-shortest-lived": 125,
+ "lifespan-longest-lived-then-shortest-lived": 125,
+ "identify-life-stages-1": 14,
+ "identify-life-stages-2": 10,
+ "inclined-plane-determine-angle": 168,
+ "inclined-plane-friction-named-surfaces": 1386,
+ "inclined-plane-friction-unnamed-surfaces": 162,
+ "mendelian-genetics-known-plant": 120,
+ "mendelian-genetics-unknown-plant": 480,
+}
+
+
+def create_dataset_files(output_dir, train_task_names, test_task_names, jar_path, percentage=0.6):
+ # make the output directory
+ os.makedirs(output_dir, exist_ok=True)
+
+ train_data = []
+ test_data = []
+
+ for task_name in train_task_names:
+ total_var = task_variations[task_name]
+ for i in range(int(total_var * percentage)):
+ task_config = {
+ "task_name": task_name,
+ "var_num": i,
+ "jar_path": jar_path,
+ }
+ task_desc = json.dumps(task_config)
+ train_data.append({"task_desc": task_desc, "targe": ""})
+
+ random.shuffle(train_data)
+
+ for task_name in test_task_names:
+ total_var = task_variations[task_name]
+ for i in range(int(total_var * percentage)):
+ task_config = {
+ "task_name": task_name,
+ "var_num": i,
+ "jar_path": jar_path,
+ }
+ task_desc = json.dumps(task_config)
+ test_data.append({"task_desc": task_desc, "targe": ""})
+
+ random.shuffle(test_data)
+
+ # create dataset_dict
+ dataset_dict = {"train": train_data, "test": test_data}
+
+ for split, data in dataset_dict.items():
+ output_file = os.path.join(output_dir, f"{split}.jsonl")
+ with open(output_file, "w") as f:
+ for item in data:
+ f.write(json.dumps(item) + "\n")
+
+ # create dataset_dict.json
+ dataset_info = {
+ "citation": "",
+ "description": "Custom dataset",
+ "splits": {
+ "train": {"name": "train", "num_examples": len(train_data)},
+ "test": {"name": "test", "num_examples": len(test_data)},
+ },
+ }
+
+ with open(os.path.join(output_dir, "dataset_dict.json"), "w") as f:
+ json.dump(dataset_info, f, indent=2)
+
+
+if __name__ == "__main__":
+    # NOTE: Manually set the jar path here.
+ jar_path = "/your/path/ScienceWorld/scienceworld/scienceworld.jar"
+ # Check if the jar file exists, raise an error if it doesn't exist.
+ if not os.path.exists(jar_path):
+ raise FileNotFoundError(
+            f"JAR file not found at {jar_path}, please set the jar path manually."
+ )
+
+ current_file_path = os.path.dirname(os.path.abspath(__file__))
+ output_dir = f"{current_file_path}/sciworld_data"
+ train_task_names = [
+ "boil",
+ "melt",
+ "change-the-state-of-matter-of",
+ "use-thermometer",
+ "measure-melting-point-known-substance",
+ "power-component",
+ "test-conductivity",
+ "find-living-thing",
+ "find-plant",
+ "grow-plant",
+ "chemistry-mix",
+ "chemistry-mix-paint-secondary-color",
+ "lifespan-shortest-lived",
+ "identify-life-stages-2",
+ "inclined-plane-determine-angle",
+ "inclined-plane-friction-named-surfaces",
+ "mendelian-genetics-known-plant",
+ ]
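+    # The remaining task types are held out as the test split.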
+ test_task_names = list(task_variations.keys() - set(train_task_names))
+ create_dataset_files(output_dir, train_task_names, test_task_names, jar_path, percentage=0.5)
diff --git a/scripts/data_prepare/get_webshop_data.py b/scripts/data_prepare/get_webshop_data.py
index 0cd86c1a56..7d9847b707 100644
--- a/scripts/data_prepare/get_webshop_data.py
+++ b/scripts/data_prepare/get_webshop_data.py
@@ -21,7 +21,6 @@ def create_dataset_files(output_dir, train_size=4096, test_size=100):
# create dataset_dict
dataset_dict = {"train": train_data, "test": test_data}
- # 保存为jsonl格式
for split, data in dataset_dict.items():
output_file = os.path.join(output_dir, f"{split}.jsonl")
with open(output_file, "w") as f:
diff --git a/trinity/common/workflows/__init__.py b/trinity/common/workflows/__init__.py
index 8fe1448d5f..a8bcd886a2 100644
--- a/trinity/common/workflows/__init__.py
+++ b/trinity/common/workflows/__init__.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
"""Workflow module"""
from .envs.alfworld.alfworld_workflow import AlfworldWorkflow
+from .envs.sciworld.sciworld_workflow import SciWorldWorkflow
from .envs.webshop.webshop_workflow import WebShopWorkflow
from .workflow import WORKFLOWS, MathWorkflow, SimpleWorkflow
@@ -10,4 +11,5 @@
"MathWorkflow",
"WebShopWorkflow",
"AlfworldWorkflow",
+ "SciWorldWorkflow",
]
diff --git a/trinity/common/workflows/envs/alfworld/alfworld_workflow.py b/trinity/common/workflows/envs/alfworld/alfworld_workflow.py
index 0fa90f38e0..170a1971e6 100644
--- a/trinity/common/workflows/envs/alfworld/alfworld_workflow.py
+++ b/trinity/common/workflows/envs/alfworld/alfworld_workflow.py
@@ -1,12 +1,9 @@
# -*- coding: utf-8 -*-
-import uuid
from typing import List
-import torch
-
from trinity.common.experience import Experience
from trinity.common.models.model import ModelWrapper
-from trinity.common.workflows.workflow import WORKFLOWS, Workflow
+from trinity.common.workflows.workflow import WORKFLOWS, MultiTurnWorkflow
EXAMPLE_PROMPT = """
Observation:
@@ -96,13 +93,13 @@ def parse_action(response):
@WORKFLOWS.register_module("alfworld_workflow")
-class AlfworldWorkflow(Workflow):
+class AlfworldWorkflow(MultiTurnWorkflow):
"""A workflow for alfworld task."""
def __init__(self, model: ModelWrapper, **kwargs):
super().__init__(model)
self.system_prompt = kwargs.get("system_prompt", None) # Unuse here
- self.task_desc: str = kwargs.get("task_desc", "0")
+ self.task_desc: str = kwargs.get("task_desc")
self.truth = kwargs.get("truth") # Unuse here
self.reward_fn = kwargs.get("reward_fn") # Unuse here
self.repeat_times = kwargs.get("repeat_times", 1)
@@ -118,9 +115,7 @@ def get_model_response_text(self, messages):
def generate_env_inference_samples(self, env, rollout_num) -> List[Experience]:
# TODO: Make this parallel
print("Generating env inference samples...")
- all_messages = []
- all_rewards = []
- all_infos = []
+ experience_list = []
for i in range(rollout_num):
observation, info = env.reset()
final_reward = -0.1
@@ -136,50 +131,12 @@ def generate_env_inference_samples(self, env, rollout_num) -> List[Experience]:
if done:
final_reward = reward
break
- all_infos.append({"env_rounds": r, "env_done": 1 if done else 0})
- all_messages.append(memory)
- all_rewards.append(final_reward)
- # Close the env to save cpu memory
- env.close()
- return self.process_batch_messages(all_messages, all_rewards, all_infos=all_infos)
-
- def process_batch_messages(self, all_messages, all_rewards, all_infos) -> List[Experience]:
- # TODO: What about the max_length for training here? Should we process here?
- new_uid = str(
- uuid.uuid4()
- ) # the new_uid might be redundant here, as we have set the run_id in workflow_runner.run_task()
- batch_tokens = []
- batch_generation_masks = []
- batch_log_probs_tensor = []
- for i, messages in enumerate(all_messages):
- converted_experience = self.model.convert_messages_to_experience(messages)
- tokens = converted_experience.tokens
- log_probs = converted_experience.logprobs
- generation_mask = converted_experience.action_mask
- assert log_probs.shape == generation_mask.shape
- log_probs = log_probs * generation_mask # type: ignore
-
- batch_tokens.append(tokens)
- batch_generation_masks.append(generation_mask)
- batch_log_probs_tensor.append(log_probs)
- # actually, we donot need to perform any padding here
- experience_list = []
- for i, (tokens, gen_mask, log_probs) in enumerate(
- zip(batch_tokens, batch_generation_masks, batch_log_probs_tensor)
- ):
- assert tokens.shape == log_probs.shape
- # set prompt length to the first 1 in the gen_mask
- prompt_length = torch.where(gen_mask == 1)[0][0].item()
- experience = Experience(
- run_id=new_uid,
- tokens=tokens,
- prompt_length=prompt_length,
- action_mask=gen_mask,
- reward=all_rewards[i],
- logprobs=log_probs,
- info=all_infos[i],
+ experience = self.process_messages_to_experience(
+ memory, final_reward, {"env_rounds": r, "env_done": 1 if done else 0}
)
experience_list.append(experience)
+ # Close the env to save cpu memory
+ env.close()
return experience_list
def run(self) -> List[Experience]:
diff --git a/trinity/common/workflows/envs/sciworld/sciworld_workflow.py b/trinity/common/workflows/envs/sciworld/sciworld_workflow.py
new file mode 100644
index 0000000000..4892abb97f
--- /dev/null
+++ b/trinity/common/workflows/envs/sciworld/sciworld_workflow.py
@@ -0,0 +1,139 @@
+# -*- coding: utf-8 -*-
+import json
+from typing import List
+
+from trinity.common.experience import Experience
+from trinity.common.models.model import ModelWrapper
+from trinity.common.workflows.workflow import WORKFLOWS, MultiTurnWorkflow
+
+SCIWORLD_SYSTEM_PROMPT = """
+You are an agent. Your job is to carry out scientific experiments in a virtual text-based environment.
+
+## Notes:
+At each step, you should first think and then perform an action to fulfill the instruction. You should ALWAYS wrap your thinking with the <think></think> tag and wrap your action with the <action></action> tag.
+You should ALWAYS take one action each step.
+DO NOT try to interact with the user at any time. Finish the task by yourself.
+
+## Action Format:
+Below are the available commands you can use:
+ open OBJ: open a container
+ close OBJ: close a container
+ activate OBJ: activate a device
+ deactivate OBJ: deactivate a device
+ connect OBJ to OBJ: connect electrical components
+ disconnect OBJ: disconnect electrical components
+ use OBJ [on OBJ]: use a device/item
+ look around: describe the current room
+ examine OBJ: describe an object in detail
+ look at OBJ: describe a container's contents
+ read OBJ: read a note or book
+ move OBJ to OBJ: move an object to a container
+ pick up OBJ: move an object to the inventory
+ pour OBJ into OBJ: pour a liquid into a container
+ mix OBJ: chemically mix a container
+ teleport to LOC: teleport to a specific room
+ focus on OBJ: signal intent on a task object
+    wait: take no action for 10 steps
+    wait1: take no action for a step
+
+For example, your output should look like this:
+    <think> Now I will check the bedroom ... </think> <action> teleport to bedroom </action>
+"""
+
+
+def format_observation(observation: str):
+ return "Observation: \n" + observation
+
+
+def parse_action(response):
+ try:
+        # parse the action wrapped in the <action></action> tag
+        action = response.split("<action>")[1].split("</action>")[0].strip()
+ return action
+ except Exception as e:
+ print("Error parsing action:", e)
+ return ""
+
+
+@WORKFLOWS.register_module("sciworld_workflow")
+class SciWorldWorkflow(MultiTurnWorkflow):
+ """A workflow for sciworld task."""
+
+ def __init__(self, model: ModelWrapper, **kwargs):
+ super().__init__(model)
+        self.system_prompt = kwargs.get("system_prompt", None)  # Unused here
+        self.task_desc: str = kwargs.get("task_desc")
+        self.truth = kwargs.get("truth")  # Unused here
+        self.reward_fn = kwargs.get("reward_fn")  # Unused here
+ self.repeat_times = kwargs.get("repeat_times", 1)
+ self.max_env_steps = 30 # should be less than 100
+
+ def get_model_response(self, messages):
+ responses = self.model.chat(messages, repeat_times=1)
+ return responses
+
+ def get_model_response_text(self, messages):
+ return self.get_model_response(messages)[0].response_text
+
+ def generate_env_inference_samples(self, env, rollout_num) -> List[Experience]:
+ # TODO: Make this parallel
+ print("Generating env inference samples...")
+ golden_rounds = len(env.get_gold_action_sequence())
+ experience_list = []
+ for i in range(rollout_num):
+ observation, info = env.reset()
+ observation = (
+ "Task Description: " + str(env.get_task_description()) + "\n" + observation
+ )
+ final_reward = 0.0
+ current_reward = 0.0
+ memory = []
+ memory.append({"role": "system", "content": SCIWORLD_SYSTEM_PROMPT})
+ for r in range(self.max_env_steps):
+ format_obs = format_observation(observation)
+ memory = memory + [{"role": "user", "content": format_obs}]
+ response_text = self.get_model_response_text(memory)
+ memory.append({"role": "assistant", "content": response_text})
+ action = parse_action(response_text)
+ observation, reward, done, info = env.step(action)
+ current_reward += reward
+ final_reward = max(current_reward, final_reward)
+ if done:
+ break
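+            # ScienceWorld scores range from 0 to 100, so normalize the best cumulative reward to [0, 1].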
+ final_reward = final_reward / 100.0
+ experience = self.process_messages_to_experience(
+ memory,
+ final_reward,
+ {"env_rounds": r, "env_done": 1 if done else 0, "golden_rounds": golden_rounds},
+ )
+ experience_list.append(experience)
+ # Close the env to save cpu memory
+ env.close()
+ return experience_list
+
+ def run(self) -> List[Experience]:
+        # assume task_desc is a JSON object containing the task name, variation number, and jar path
+        # see Trinity-RFT/scripts/data_prepare/get_sciworld_data.py
+ task_desc = self.task_desc
+ task_config = json.loads(task_desc)
+
+ rollout_n = self.repeat_times
+ # TODO: Make parallel envs
+ try:
+ from scienceworld import ScienceWorldEnv
+
+ def create_environment(task_config):
+ var_num = task_config["var_num"]
+ task_name = task_config["task_name"]
+ jar_path = task_config["jar_path"]
+ simplificationStr = "easy"
+ env = ScienceWorldEnv("", jar_path, envStepLimit=100)
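+                # generateGoldPath=True lets us query the gold action sequence for the golden_rounds metric.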
+ env.load(task_name, var_num, simplificationStr, generateGoldPath=True)
+ return env
+
+ except Exception as e:
+            print("Please make sure you have installed the scienceworld package.")
+            error_message = f"Error importing ScienceWorldEnv: {str(e)}. Please make sure you have installed the scienceworld package successfully, following the instructions in https://github.com/allenai/ScienceWorld"
+ raise ImportError(error_message)
+ env = create_environment(task_config)
+ return self.generate_env_inference_samples(env, rollout_n)
diff --git a/trinity/common/workflows/envs/webshop/webshop_workflow.py b/trinity/common/workflows/envs/webshop/webshop_workflow.py
index 0e4933fe98..5773f8e6e8 100644
--- a/trinity/common/workflows/envs/webshop/webshop_workflow.py
+++ b/trinity/common/workflows/envs/webshop/webshop_workflow.py
@@ -1,12 +1,9 @@
# -*- coding: utf-8 -*-
-import uuid
from typing import List
-import torch
-
from trinity.common.experience import Experience
from trinity.common.models.model import ModelWrapper
-from trinity.common.workflows.workflow import WORKFLOWS, Workflow
+from trinity.common.workflows.workflow import WORKFLOWS, MultiTurnWorkflow
SPARSE_REWARD = False
@@ -181,7 +178,7 @@ def validate_action(action, available_actions):
@WORKFLOWS.register_module("webshop_workflow")
-class WebShopWorkflow(Workflow):
+class WebShopWorkflow(MultiTurnWorkflow):
"""A workflow for webshop task."""
def __init__(self, model: ModelWrapper, **kwargs):
@@ -197,18 +194,15 @@ def get_model_response(self, messages):
responses = self.model.chat(messages, repeat_times=1)
return responses
- def get_model_response_text(self, messages, model="actor"):
+ def get_model_response_text(self, messages):
return self.get_model_response(messages)[0].response_text
def generate_env_inference_samples(self, env, session_id, rollout_num) -> List[Experience]:
# TODO: Make this parallel
print("Generating env inference samples...")
- all_messages = []
- all_rewards = []
- all_infos = []
+ experience_list = []
for i in range(rollout_num):
env.reset(session=session_id)
-
final_reward = -0.1
observation = env.observation
memory = []
@@ -230,10 +224,6 @@ def generate_env_inference_samples(self, env, session_id, rollout_num) -> List[E
if done:
final_reward = reward
break
- all_infos.append(
- {"session_id": session_id, "env_rounds": r, "env_done": 1 if done else 0}
- )
- all_messages.append(memory)
if SPARSE_REWARD:
if final_reward >= 0.99:
final_reward = 1
@@ -241,48 +231,14 @@ def generate_env_inference_samples(self, env, session_id, rollout_num) -> List[E
final_reward = 0
else:
final_reward = -0.1
- all_rewards.append(final_reward)
- # Close the env to save cpu memory
- env.close()
- return self.process_batch_messages(all_messages, all_rewards, all_infos=all_infos)
-
- def process_batch_messages(self, all_messages, all_rewards, all_infos) -> List[Experience]:
- # TODO: What about the max_length for training here? Should we process here?
- new_uid = str(
- uuid.uuid4()
- ) # the new_uid might be redundant here, as we have set the run_id in workflow_runner.run_task()
- batch_tokens = []
- batch_generation_masks = []
- batch_log_probs_tensor = []
- for i, messages in enumerate(all_messages):
- converted_experience = self.model.convert_messages_to_experience(messages)
- tokens = converted_experience.tokens
- log_probs = converted_experience.logprobs
- generation_mask = converted_experience.action_mask
- assert log_probs.shape == generation_mask.shape
- log_probs = log_probs * generation_mask # type: ignore
-
- batch_tokens.append(tokens)
- batch_generation_masks.append(generation_mask)
- batch_log_probs_tensor.append(log_probs)
- # actually, we donot need to perform any padding here
- experience_list = []
- for i, (tokens, gen_mask, log_probs) in enumerate(
- zip(batch_tokens, batch_generation_masks, batch_log_probs_tensor)
- ):
- assert tokens.shape == log_probs.shape
- # set prompt length to the first 1 in the gen_mask
- prompt_length = torch.where(gen_mask == 1)[0][0].item()
- experience = Experience(
- run_id=new_uid,
- tokens=tokens,
- prompt_length=prompt_length,
- action_mask=gen_mask,
- reward=all_rewards[i],
- logprobs=log_probs,
- info=all_infos[i],
+ experience = self.process_messages_to_experience(
+ memory,
+ final_reward,
+ {"session_id": session_id, "env_rounds": r, "env_done": 1 if done else 0},
)
experience_list.append(experience)
+ # Close the env to save cpu memory
+ env.close()
return experience_list
def run(self) -> List[Experience]:
diff --git a/trinity/common/workflows/workflow.py b/trinity/common/workflows/workflow.py
index f60412df13..1e82b2dfd3 100644
--- a/trinity/common/workflows/workflow.py
+++ b/trinity/common/workflows/workflow.py
@@ -4,6 +4,8 @@
from abc import ABC, abstractmethod
from typing import List
+import torch
+
from trinity.common.experience import Experience
from trinity.common.models.model import ModelWrapper
from trinity.common.rewards.reward_fn import MathRewardFn
@@ -30,6 +32,48 @@ def run(self) -> List[Experience]:
"""Run workflow and return a list of experiences."""
+class MultiTurnWorkflow(Workflow):
+ """
+ The base workflow class for multi-turn tasks.
+ """
+
+ def __init__(self, model: ModelWrapper, **kwargs):
+ super().__init__(model)
+
+ @abstractmethod
+ def run(self) -> List[Experience]:
+ """Run workflow and return a list of experiences."""
+
+ def process_messages_to_experience(self, messages, reward, info={}) -> Experience:
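+        """Convert a list of chat messages and the final reward into a single `Experience`."""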
+ converted_experience = self.model.convert_messages_to_experience(messages)
+
+ tokens = converted_experience.tokens
+ log_probs = converted_experience.logprobs
+ assert converted_experience.action_mask is not None
+ generation_mask = converted_experience.action_mask
+ log_probs = log_probs * generation_mask
+
+ assert tokens.shape == log_probs.shape
+ # set prompt length to the first 1 in the gen_mask
+ prompt_length = torch.where(generation_mask == 1)[0][0].item()
+
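+        # Keep only the numeric entries of `info` as scalar metrics for logging.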
+ metrics = {}
+ for k, v in info.items():
+ if isinstance(v, float) or isinstance(v, int):
+ metrics[k] = float(v)
+
+ experience = Experience(
+ tokens=tokens,
+ prompt_length=prompt_length,
+ action_mask=generation_mask,
+ reward=reward,
+ logprobs=log_probs,
+ info=info,
+ metrics=metrics,
+ )
+ return experience
+
+
@WORKFLOWS.register_module("simple_workflow")
class SimpleWorkflow(Workflow):
"""A workflow for simple single-round task."""