diff --git a/docs/sphinx_doc/source/tutorial/example_multi_turn.md b/docs/sphinx_doc/source/tutorial/example_multi_turn.md
index b3686dc3ee..43706f2fba 100644
--- a/docs/sphinx_doc/source/tutorial/example_multi_turn.md
+++ b/docs/sphinx_doc/source/tutorial/example_multi_turn.md
@@ -21,7 +21,7 @@ You may refer to their original environment to complete the setup.
 ### Data Preparation
 
 Our dataset follows the format in Huggingface datasets library, so we should correspondingly convert our env dataset.
-Just run the following command.
+Just check the data preparation scripts and run the following command.
 ```bash
 # For ALFworld env
 python scripts/data_prepare/get_alfworld_data.py
@@ -53,18 +53,16 @@ We provide an easy way to allow you build your own environment pipeline by creat
 See the `trinity/common/workflows/envs/alfworld/alfworld_workflow.py` as an example on how to construct a multi-round workflow.
 
-You can interact with environment using the messages format, and call the `self.process_batch_messages` function to transform the messages and rewards into the `experience` we need, and send them to buffer.
+You can interact with the environment using the messages format, then call the `self.process_messages_to_experience` function to transform the messages and rewards into the `Experience` objects we need and send them to the buffer.
 
 ```python
-class AlfworldWorkflow(Workflow):
+class AlfworldWorkflow(MultiTurnWorkflow):
     """A workflow for alfworld task."""
 
 ...
 
     def generate_env_inference_samples(self, env, rollout_num) -> List[Experience]:
         print("Generating env inference samples...")
-        all_messages = []
-        all_rewards = []
-        all_infos = []
+        experience_list = []
         for i in range(rollout_num):
             observation, info = env.reset()
             final_reward = -0.1
@@ -80,14 +78,13 @@ class AlfworldWorkflow(Workflow):
             if done:
                 final_reward = reward
                 break
-            all_infos.append(
-                {"env_rounds": r, "env_done": 1 if done else 0}
+            experience = self.process_messages_to_experience(
+                memory, final_reward, {"env_rounds": r, "env_done": 1 if done else 0}
             )
-            all_messages.append(memory)
-            all_rewards.append(final_reward)
+            experience_list.append(experience)
         # Close the env to save cpu memory
         env.close()
-        return self.process_batch_messages(all_messages, all_rewards, all_infos=all_infos)
+        return experience_list
 
 
     def run(self) -> List[Experience]:
@@ -102,7 +99,7 @@ class AlfworldWorkflow(Workflow):
 Also, remember to register your workflow:
 ```python
 @WORKFLOWS.register_module("alfworld_workflow")
-class AlfworldWorkflow(Workflow):
+class AlfworldWorkflow(MultiTurnWorkflow):
     """A workflow for alfworld task."""
 ...
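+
+# Illustrative sketch only (not part of the ALFworld example): your own
+# multi-turn workflow follows the same pattern. The registry name
+# "my_env_workflow", the MyEnvWorkflow class and the make_my_env() helper
+# below are hypothetical placeholders.
+@WORKFLOWS.register_module("my_env_workflow")
+class MyEnvWorkflow(MultiTurnWorkflow):
+    """A workflow for your own multi-turn environment."""
+
+    def run(self) -> List[Experience]:
+        env = make_my_env()  # hypothetical: build your environment here
+        memory = [{"role": "system", "content": "..."}]
+        # ... interact with env, appending user/assistant messages to memory ...
+        final_reward = 1.0  # scalar reward for the whole rollout
+        experience = self.process_messages_to_experience(
+            memory, final_reward, {"env_rounds": 1, "env_done": 1}
+        )
+        return [experience]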
``` diff --git a/scripts/config/sciworld.yaml b/scripts/config/sciworld.yaml new file mode 100644 index 0000000000..62a9344cbd --- /dev/null +++ b/scripts/config/sciworld.yaml @@ -0,0 +1,56 @@ +data: + total_epoch: 20 + batch_size: 4 + dataset_path: 'scripts/data_prepare/sciworld_data' + default_workflow_type: 'sciworld_workflow' + train_split: 'train' + eval_split: '' + format_config: + prompt_key: 'game_file' +model: + model_path: '/PATH/TO/MODEL/CHECKPOINT/' + max_prompt_tokens: 4096 + max_response_tokens: 16384 + checkpoint_path: 'checkpoints/sciworld_RFT' +cluster: + node_num: 1 + gpu_per_node: 8 +buffer: + max_retry_times: 3 + max_retry_interval: 1 + train_dataset: + name: sciworld_buffer + storage_type: queue + algorithm_type: ppo + path: 'sqlite:///sciworld.db' +explorer: + engine_type: vllm_async + engine_num: 2 + runner_num: 32 + tensor_parallel_size: 2 + enable_prefix_caching: false + enforce_eager: true + dtype: bfloat16 + temperature: 1.0 + top_p: 1.0 + top_k: -1 + seed: 42 + logprobs: 0 + repeat_times: 8 + use_ray: false + backend: 'nccl' + max_pending_requests: 32 + max_waiting_steps: 4 + gpu_memory_utilization: 0.7 + enable_chunked_prefil: true +synchronizer: + sync_method: 'online' + sync_iteration_interval: 8 +trainer: + trainer_type: 'verl' + algorithm_type: ppo + trainer_config_path: 'scripts/config/train_sciworld.yaml' +monitor: + cache_root_dir: "" + project: "sciworld" + name: "sciworld_RFT" diff --git a/scripts/config/train_sciworld.yaml b/scripts/config/train_sciworld.yaml new file mode 100644 index 0000000000..a818dcb0c6 --- /dev/null +++ b/scripts/config/train_sciworld.yaml @@ -0,0 +1,183 @@ +data: + tokenizer: null + train_files: train_example.parquet + val_files: test_example.parquet + prompt_key: prompt + max_prompt_length: 4096 + max_response_length: 16384 + train_batch_size: 96 + val_batch_size: null + return_raw_input_ids: False # This should be set to true when the tokenizer between policy and rm differs + return_raw_chat: False + shuffle: True + filter_overlong_prompts: False # for large-scale dataset, filtering overlong prompts could be timeconsuming. You should disable this and set `truncation='left' + truncation: error + image_key: images + +actor_rollout_ref: + hybrid_engine: True + model: + path: /PATH/TO/MODEL/CHECKPOINT/ + external_lib: null + override_config: { } + enable_gradient_checkpointing: True + use_remove_padding: False + actor: + strategy: fsdp # This is for backward-compatibility + ppo_mini_batch_size: 1536 + # ppo_micro_batch_size: 8 # will be deprecated, use ppo_micro_batch_size_per_gpu + ppo_micro_batch_size_per_gpu: 1 + use_dynamic_bsz: False + ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length} + grad_clip: 1.0 + clip_ratio: 0.2 + entropy_coeff: 0.001 + use_kl_loss: True # True for GRPO + kl_loss_coef: 0.001 # for grpo + kl_loss_type: low_var_kl # for grpo + ppo_epochs: 1 + shuffle: False + ulysses_sequence_parallel_size: 1 # sp size + optim: + lr: 1e-6 + lr_warmup_steps_ratio: 0. 
# the total steps will be injected during runtime + # min_lr_ratio: null # only useful for warmup with cosine + warmup_style: constant # select from constant/cosine + total_training_steps: -1 # must be override by program + fsdp_config: + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + param_offload: False + optimizer_offload: False + fsdp_size: -1 + ref: + fsdp_config: + param_offload: False + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + # log_prob_micro_batch_size: 4 # will be deprecated, use log_prob_micro_batch_size_per_gpu + log_prob_micro_batch_size_per_gpu: 1 + log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} + log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} + ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size + rollout: + name: vllm + temperature: 1.0 + top_k: -1 # 0 for hf rollout, -1 for vllm rollout + top_p: 1 + use_fire_sampling: False # https://arxiv.org/abs/2410.21236 + prompt_length: ${data.max_prompt_length} # not use for opensource + response_length: ${data.max_response_length} + # for vllm rollout + dtype: bfloat16 # should align with FSDP + gpu_memory_utilization: 0.4 + ignore_eos: False + enforce_eager: True + free_cache_engine: True + load_format: dummy_dtensor + tensor_model_parallel_size: 1 + max_num_batched_tokens: 8192 + max_model_len: null + max_num_seqs: 1024 + # log_prob_micro_batch_size: 8 # will be deprecated, use log_prob_micro_batch_size_per_gpu + log_prob_micro_batch_size_per_gpu: 1 + log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} + log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} + disable_log_stats: True + enable_chunked_prefill: True # could get higher throughput + # for hf rollout + do_sample: True + # number of responses (i.e. num sample times) + n: 8 # should be > 1 for grpo; Currently is unused parameter + +critic: + strategy: fsdp + optim: + lr: 1e-5 + lr_warmup_steps_ratio: 0. 
# the total steps will be injected during runtime + # min_lr_ratio: null # only useful for warmup with cosine + warmup_style: constant # select from constant/cosine + total_training_steps: -1 # must be override by program + model: + path: /PATH/TO/MODEL/CHECKPOINT/ + tokenizer_path: ${actor_rollout_ref.model.path} + override_config: { } + external_lib: ${actor_rollout_ref.model.external_lib} + enable_gradient_checkpointing: True + use_remove_padding: False + fsdp_config: + param_offload: False + optimizer_offload: False + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + fsdp_size: -1 + ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size} + # ppo_micro_batch_size: 8 # will be deprecated, use ppo_micro_batch_size_per_gpu + ppo_micro_batch_size_per_gpu: 1 + forward_micro_batch_size: ${critic.ppo_micro_batch_size} + forward_micro_batch_size_per_gpu: ${critic.ppo_micro_batch_size_per_gpu} + use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} + ppo_max_token_len_per_gpu: 16384 # (${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}) * 2 + forward_max_token_len_per_gpu: ${critic.ppo_max_token_len_per_gpu} + ulysses_sequence_parallel_size: 1 # sp size + ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs} + shuffle: ${actor_rollout_ref.actor.shuffle} + grad_clip: 1.0 + cliprange_value: 0.5 + +reward_model: + enable: False + strategy: fsdp + model: + input_tokenizer: ${actor_rollout_ref.model.path} # set this to null if the chat template is identical + path: ~/models/FsfairX-LLaMA3-RM-v0.1 + external_lib: ${actor_rollout_ref.model.external_lib} + use_remove_padding: False + fsdp_config: + min_num_params: 0 + param_offload: False + fsdp_size: -1 + # micro_batch_size: null # will be deprecated, use micro_batch_size_per_gpu + # micro_batch_size_per_gpu: 2 # set a number + # max_length: null + ulysses_sequence_parallel_size: 1 # sp size + use_dynamic_bsz: ${critic.use_dynamic_bsz} + forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu} + +custom_reward_function: + path: null + name: compute_score + +algorithm: + gamma: 1.0 + lam: 1.0 + adv_estimator: grpo + kl_penalty: kl # how to estimate kl divergence + kl_ctrl: + type: fixed + kl_coef: 0.001 + +trainer: + balance_batch: True + total_epochs: 15 + # total_training_steps: null + project_name: sciworld + experiment_name: sciworld_RFT + logger: [ 'wandb' ] + val_generations_to_log_to_wandb: 0 + nnodes: 1 + n_gpus_per_node: 2 + save_freq: 1 + # auto: find the last ckpt to resume. 
If can't find, start from scratch + resume_mode: auto # or auto or resume_path if + resume_from_path: False + test_freq: 100 + critic_warmup: 0 + default_hdfs_dir: null + remove_previous_ckpt_in_save: False + del_local_ckpt_after_load: False + default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name} + val_before_train: False diff --git a/scripts/data_prepare/get_alfworld_data.py b/scripts/data_prepare/get_alfworld_data.py index 423bc46b2c..b55a04435a 100644 --- a/scripts/data_prepare/get_alfworld_data.py +++ b/scripts/data_prepare/get_alfworld_data.py @@ -39,7 +39,6 @@ def create_dataset_files(output_dir, train_size=1024, test_size=100): # create dataset_dict dataset_dict = {"train": train_data, "test": test_data} - # 保存为jsonl格式 for split, data in dataset_dict.items(): output_file = os.path.join(output_dir, f"{split}.jsonl") with open(output_file, "w") as f: diff --git a/scripts/data_prepare/get_sciworld_data.py b/scripts/data_prepare/get_sciworld_data.py new file mode 100644 index 0000000000..e94624c155 --- /dev/null +++ b/scripts/data_prepare/get_sciworld_data.py @@ -0,0 +1,132 @@ +""" +We use this script to create the huggingface format dataset files for the sciworld dataset. +NOTE: You need to install the ScienceWorld dataset first: https://github.com/allenai/ScienceWorld +""" +import json +import os +import random + +random.seed(42) + +task_variations = { + "boil": 30, + "melt": 30, + "freeze": 30, + "change-the-state-of-matter-of": 30, + "use-thermometer": 540, + "measure-melting-point-known-substance": 436, + "measure-melting-point-unknown-substance": 300, + "power-component": 20, + "power-component-renewable-vs-nonrenewable-energy": 20, + "test-conductivity": 900, + "test-conductivity-of-unknown-substances": 600, + "find-living-thing": 300, + "find-non-living-thing": 300, + "find-plant": 300, + "find-animal": 300, + "grow-plant": 126, + "grow-fruit": 126, + "chemistry-mix": 32, + "chemistry-mix-paint-secondary-color": 36, + "chemistry-mix-paint-tertiary-color": 36, + "lifespan-longest-lived": 125, + "lifespan-shortest-lived": 125, + "lifespan-longest-lived-then-shortest-lived": 125, + "identify-life-stages-1": 14, + "identify-life-stages-2": 10, + "inclined-plane-determine-angle": 168, + "inclined-plane-friction-named-surfaces": 1386, + "inclined-plane-friction-unnamed-surfaces": 162, + "mendelian-genetics-known-plant": 120, + "mendelian-genetics-unknown-plant": 480, +} + + +def create_dataset_files(output_dir, train_task_names, test_task_names, jar_path, percentage=0.6): + # make the output directory + os.makedirs(output_dir, exist_ok=True) + + train_data = [] + test_data = [] + + for task_name in train_task_names: + total_var = task_variations[task_name] + for i in range(int(total_var * percentage)): + task_config = { + "task_name": task_name, + "var_num": i, + "jar_path": jar_path, + } + task_desc = json.dumps(task_config) + train_data.append({"task_desc": task_desc, "targe": ""}) + + random.shuffle(train_data) + + for task_name in test_task_names: + total_var = task_variations[task_name] + for i in range(int(total_var * percentage)): + task_config = { + "task_name": task_name, + "var_num": i, + "jar_path": jar_path, + } + task_desc = json.dumps(task_config) + test_data.append({"task_desc": task_desc, "targe": ""}) + + random.shuffle(test_data) + + # create dataset_dict + dataset_dict = {"train": train_data, "test": test_data} + + for split, data in dataset_dict.items(): + output_file = os.path.join(output_dir, f"{split}.jsonl") + with open(output_file, 
"w") as f: + for item in data: + f.write(json.dumps(item) + "\n") + + # create dataset_dict.json + dataset_info = { + "citation": "", + "description": "Custom dataset", + "splits": { + "train": {"name": "train", "num_examples": len(train_data)}, + "test": {"name": "test", "num_examples": len(test_data)}, + }, + } + + with open(os.path.join(output_dir, "dataset_dict.json"), "w") as f: + json.dump(dataset_info, f, indent=2) + + +if __name__ == "__main__": + # NOTE: Mannually set the jar path here. + jar_path = "/your/path/ScienceWorld/scienceworld/scienceworld.jar" + # Check if the jar file exists, raise an error if it doesn't exist. + if not os.path.exists(jar_path): + raise FileNotFoundError( + f"JAR file not found at {jar_path}, please set the jar path mannually." + ) + + current_file_path = os.path.dirname(os.path.abspath(__file__)) + output_dir = f"{current_file_path}/sciworld_data" + train_task_names = [ + "boil", + "melt", + "change-the-state-of-matter-of", + "use-thermometer", + "measure-melting-point-known-substance", + "power-component", + "test-conductivity", + "find-living-thing", + "find-plant", + "grow-plant", + "chemistry-mix", + "chemistry-mix-paint-secondary-color", + "lifespan-shortest-lived", + "identify-life-stages-2", + "inclined-plane-determine-angle", + "inclined-plane-friction-named-surfaces", + "mendelian-genetics-known-plant", + ] + test_task_names = list(task_variations.keys() - set(train_task_names)) + create_dataset_files(output_dir, train_task_names, test_task_names, jar_path, percentage=0.5) diff --git a/scripts/data_prepare/get_webshop_data.py b/scripts/data_prepare/get_webshop_data.py index 0cd86c1a56..7d9847b707 100644 --- a/scripts/data_prepare/get_webshop_data.py +++ b/scripts/data_prepare/get_webshop_data.py @@ -21,7 +21,6 @@ def create_dataset_files(output_dir, train_size=4096, test_size=100): # create dataset_dict dataset_dict = {"train": train_data, "test": test_data} - # 保存为jsonl格式 for split, data in dataset_dict.items(): output_file = os.path.join(output_dir, f"{split}.jsonl") with open(output_file, "w") as f: diff --git a/trinity/common/workflows/__init__.py b/trinity/common/workflows/__init__.py index 8fe1448d5f..a8bcd886a2 100644 --- a/trinity/common/workflows/__init__.py +++ b/trinity/common/workflows/__init__.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- """Workflow module""" from .envs.alfworld.alfworld_workflow import AlfworldWorkflow +from .envs.sciworld.sciworld_workflow import SciWorldWorkflow from .envs.webshop.webshop_workflow import WebShopWorkflow from .workflow import WORKFLOWS, MathWorkflow, SimpleWorkflow @@ -10,4 +11,5 @@ "MathWorkflow", "WebShopWorkflow", "AlfworldWorkflow", + "SciWorldWorkflow", ] diff --git a/trinity/common/workflows/envs/alfworld/alfworld_workflow.py b/trinity/common/workflows/envs/alfworld/alfworld_workflow.py index 0fa90f38e0..170a1971e6 100644 --- a/trinity/common/workflows/envs/alfworld/alfworld_workflow.py +++ b/trinity/common/workflows/envs/alfworld/alfworld_workflow.py @@ -1,12 +1,9 @@ # -*- coding: utf-8 -*- -import uuid from typing import List -import torch - from trinity.common.experience import Experience from trinity.common.models.model import ModelWrapper -from trinity.common.workflows.workflow import WORKFLOWS, Workflow +from trinity.common.workflows.workflow import WORKFLOWS, MultiTurnWorkflow EXAMPLE_PROMPT = """ Observation: @@ -96,13 +93,13 @@ def parse_action(response): @WORKFLOWS.register_module("alfworld_workflow") -class AlfworldWorkflow(Workflow): +class AlfworldWorkflow(MultiTurnWorkflow): 
"""A workflow for alfworld task.""" def __init__(self, model: ModelWrapper, **kwargs): super().__init__(model) self.system_prompt = kwargs.get("system_prompt", None) # Unuse here - self.task_desc: str = kwargs.get("task_desc", "0") + self.task_desc: str = kwargs.get("task_desc") self.truth = kwargs.get("truth") # Unuse here self.reward_fn = kwargs.get("reward_fn") # Unuse here self.repeat_times = kwargs.get("repeat_times", 1) @@ -118,9 +115,7 @@ def get_model_response_text(self, messages): def generate_env_inference_samples(self, env, rollout_num) -> List[Experience]: # TODO: Make this parallel print("Generating env inference samples...") - all_messages = [] - all_rewards = [] - all_infos = [] + experience_list = [] for i in range(rollout_num): observation, info = env.reset() final_reward = -0.1 @@ -136,50 +131,12 @@ def generate_env_inference_samples(self, env, rollout_num) -> List[Experience]: if done: final_reward = reward break - all_infos.append({"env_rounds": r, "env_done": 1 if done else 0}) - all_messages.append(memory) - all_rewards.append(final_reward) - # Close the env to save cpu memory - env.close() - return self.process_batch_messages(all_messages, all_rewards, all_infos=all_infos) - - def process_batch_messages(self, all_messages, all_rewards, all_infos) -> List[Experience]: - # TODO: What about the max_length for training here? Should we process here? - new_uid = str( - uuid.uuid4() - ) # the new_uid might be redundant here, as we have set the run_id in workflow_runner.run_task() - batch_tokens = [] - batch_generation_masks = [] - batch_log_probs_tensor = [] - for i, messages in enumerate(all_messages): - converted_experience = self.model.convert_messages_to_experience(messages) - tokens = converted_experience.tokens - log_probs = converted_experience.logprobs - generation_mask = converted_experience.action_mask - assert log_probs.shape == generation_mask.shape - log_probs = log_probs * generation_mask # type: ignore - - batch_tokens.append(tokens) - batch_generation_masks.append(generation_mask) - batch_log_probs_tensor.append(log_probs) - # actually, we donot need to perform any padding here - experience_list = [] - for i, (tokens, gen_mask, log_probs) in enumerate( - zip(batch_tokens, batch_generation_masks, batch_log_probs_tensor) - ): - assert tokens.shape == log_probs.shape - # set prompt length to the first 1 in the gen_mask - prompt_length = torch.where(gen_mask == 1)[0][0].item() - experience = Experience( - run_id=new_uid, - tokens=tokens, - prompt_length=prompt_length, - action_mask=gen_mask, - reward=all_rewards[i], - logprobs=log_probs, - info=all_infos[i], + experience = self.process_messages_to_experience( + memory, final_reward, {"env_rounds": r, "env_done": 1 if done else 0} ) experience_list.append(experience) + # Close the env to save cpu memory + env.close() return experience_list def run(self) -> List[Experience]: diff --git a/trinity/common/workflows/envs/sciworld/sciworld_workflow.py b/trinity/common/workflows/envs/sciworld/sciworld_workflow.py new file mode 100644 index 0000000000..4892abb97f --- /dev/null +++ b/trinity/common/workflows/envs/sciworld/sciworld_workflow.py @@ -0,0 +1,139 @@ +# -*- coding: utf-8 -*- +import json +from typing import List + +from trinity.common.experience import Experience +from trinity.common.models.model import ModelWrapper +from trinity.common.workflows.workflow import WORKFLOWS, MultiTurnWorkflow + +SCIWORLD_SYSTEM_PROMPT = """ +You are an agent, you job is to do some scientific experiment in a virtual test-based 
+
+## Notes:
+At each step, you should first think and then perform an action to fulfill the instruction. You should ALWAYS wrap your thinking with the <think> </think> tags and wrap your action with the <action> </action> tags.
+You should ALWAYS take exactly one action each step.
+DO NOT try to interact with the user at any time. Finish the task by yourself.
+
+## Action Format:
+Below are the available commands you can use:
+  open OBJ: open a container
+  close OBJ: close a container
+  activate OBJ: activate a device
+  deactivate OBJ: deactivate a device
+  connect OBJ to OBJ: connect electrical components
+  disconnect OBJ: disconnect electrical components
+  use OBJ [on OBJ]: use a device/item
+  look around: describe the current room
+  examine OBJ: describe an object in detail
+  look at OBJ: describe a container's contents
+  read OBJ: read a note or book
+  move OBJ to OBJ: move an object to a container
+  pick up OBJ: move an object to the inventory
+  pour OBJ into OBJ: pour a liquid into a container
+  mix OBJ: chemically mix a container
+  teleport to LOC: teleport to a specific room
+  focus on OBJ: signal intent on a task object
+  wait: take no action for 10 steps
+  wait1: take no action for a single step
+
+For example, your output should be like this:
+ <think> Now I will check the bedroom ... </think> <action> teleport to bedroom </action>
+"""
+
+
+def format_observation(observation: str):
+    return "Observation: \n" + observation
+
+
+def parse_action(response):
+    try:
+        # parse the action within the <action> </action> tags
+        action = response.split("<action>")[1].split("</action>")[0].strip()
+        return action
+    except Exception as e:
+        print("Error parsing action:", e)
+        return ""
+
+
+@WORKFLOWS.register_module("sciworld_workflow")
+class SciWorldWorkflow(MultiTurnWorkflow):
+    """A workflow for sciworld task."""
+
+    def __init__(self, model: ModelWrapper, **kwargs):
+        super().__init__(model)
+        self.system_prompt = kwargs.get("system_prompt", None)  # Unused here
+        self.task_desc: str = kwargs.get("task_desc")
+        self.truth = kwargs.get("truth")  # Unused here
+        self.reward_fn = kwargs.get("reward_fn")  # Unused here
+        self.repeat_times = kwargs.get("repeat_times", 1)
+        self.max_env_steps = 30  # should be less than 100
+
+    def get_model_response(self, messages):
+        responses = self.model.chat(messages, repeat_times=1)
+        return responses
+
+    def get_model_response_text(self, messages):
+        return self.get_model_response(messages)[0].response_text
+
+    def generate_env_inference_samples(self, env, rollout_num) -> List[Experience]:
+        # TODO: Make this parallel
+        print("Generating env inference samples...")
+        golden_rounds = len(env.get_gold_action_sequence())
+        experience_list = []
+        for i in range(rollout_num):
+            observation, info = env.reset()
+            observation = (
+                "Task Description: " + str(env.get_task_description()) + "\n" + observation
+            )
+            final_reward = 0.0
+            current_reward = 0.0
+            memory = []
+            memory.append({"role": "system", "content": SCIWORLD_SYSTEM_PROMPT})
+            for r in range(self.max_env_steps):
+                format_obs = format_observation(observation)
+                memory = memory + [{"role": "user", "content": format_obs}]
+                response_text = self.get_model_response_text(memory)
+                memory.append({"role": "assistant", "content": response_text})
+                action = parse_action(response_text)
+                observation, reward, done, info = env.step(action)
+                current_reward += reward
+                final_reward = max(current_reward, final_reward)
+                if done:
+                    break
+            final_reward = final_reward / 100.0
+            experience = self.process_messages_to_experience(
+                memory,
+                final_reward,
+                {"env_rounds": r, "env_done": 1 if done else 0, "golden_rounds": golden_rounds},
+            )
+            experience_list.append(experience)
+        # Close the env to save cpu memory
+        env.close()
+        return experience_list
+
+    def run(self) -> List[Experience]:
+        # assume task_desc is a JSON object containing the task name, variation number, and jar path
+        # see Trinity-RFT/scripts/data_prepare/get_sciworld_data.py
+        task_desc = self.task_desc
+        task_config = json.loads(task_desc)
+
+        rollout_n = self.repeat_times
+        # TODO: Make parallel envs
+        try:
+            from scienceworld import ScienceWorldEnv
+
+            def create_environment(task_config):
+                var_num = task_config["var_num"]
+                task_name = task_config["task_name"]
+                jar_path = task_config["jar_path"]
+                simplificationStr = "easy"
+                env = ScienceWorldEnv("", jar_path, envStepLimit=100)
+                env.load(task_name, var_num, simplificationStr, generateGoldPath=True)
+                return env
+
+        except Exception as e:
+            print("Please make sure you have installed the scienceworld package.")
+            error_message = f"Error importing ScienceWorldEnv: {str(e)}. Please make sure you have installed the scienceworld package successfully, following the instructions in https://github.com/allenai/ScienceWorld"
+            raise ImportError(error_message)
+        env = create_environment(task_config)
+        return self.generate_env_inference_samples(env, rollout_n)
diff --git a/trinity/common/workflows/envs/webshop/webshop_workflow.py b/trinity/common/workflows/envs/webshop/webshop_workflow.py
index 0e4933fe98..5773f8e6e8 100644
--- a/trinity/common/workflows/envs/webshop/webshop_workflow.py
+++ b/trinity/common/workflows/envs/webshop/webshop_workflow.py
@@ -1,12 +1,9 @@
 # -*- coding: utf-8 -*-
-import uuid
 from typing import List
 
-import torch
-
 from trinity.common.experience import Experience
 from trinity.common.models.model import ModelWrapper
-from trinity.common.workflows.workflow import WORKFLOWS, Workflow
+from trinity.common.workflows.workflow import WORKFLOWS, MultiTurnWorkflow
 
 SPARSE_REWARD = False
 
@@ -181,7 +178,7 @@ def validate_action(action, available_actions):
 
 
 @WORKFLOWS.register_module("webshop_workflow")
-class WebShopWorkflow(Workflow):
+class WebShopWorkflow(MultiTurnWorkflow):
     """A workflow for webshop task."""
 
     def __init__(self, model: ModelWrapper, **kwargs):
@@ -197,18 +194,15 @@ def get_model_response(self, messages):
         responses = self.model.chat(messages, repeat_times=1)
         return responses
 
-    def get_model_response_text(self, messages, model="actor"):
+    def get_model_response_text(self, messages):
        return self.get_model_response(messages)[0].response_text
 
     def generate_env_inference_samples(self, env, session_id, rollout_num) -> List[Experience]:
         # TODO: Make this parallel
         print("Generating env inference samples...")
-        all_messages = []
-        all_rewards = []
-        all_infos = []
+        experience_list = []
         for i in range(rollout_num):
             env.reset(session=session_id)
-            final_reward = -0.1
 
             observation = env.observation
             memory = []
@@ -230,10 +224,6 @@
                 if done:
                     final_reward = reward
                     break
-            all_infos.append(
-                {"session_id": session_id, "env_rounds": r, "env_done": 1 if done else 0}
-            )
-            all_messages.append(memory)
             if SPARSE_REWARD:
                 if final_reward >= 0.99:
                     final_reward = 1
@@ -241,48 +231,14 @@
                     final_reward = 0
             else:
                 final_reward = -0.1
-            all_rewards.append(final_reward)
-        # Close the env to save cpu memory
-        env.close()
-        return self.process_batch_messages(all_messages, all_rewards, all_infos=all_infos)
-
-    def process_batch_messages(self, all_messages, all_rewards,
all_infos) -> List[Experience]: - # TODO: What about the max_length for training here? Should we process here? - new_uid = str( - uuid.uuid4() - ) # the new_uid might be redundant here, as we have set the run_id in workflow_runner.run_task() - batch_tokens = [] - batch_generation_masks = [] - batch_log_probs_tensor = [] - for i, messages in enumerate(all_messages): - converted_experience = self.model.convert_messages_to_experience(messages) - tokens = converted_experience.tokens - log_probs = converted_experience.logprobs - generation_mask = converted_experience.action_mask - assert log_probs.shape == generation_mask.shape - log_probs = log_probs * generation_mask # type: ignore - - batch_tokens.append(tokens) - batch_generation_masks.append(generation_mask) - batch_log_probs_tensor.append(log_probs) - # actually, we donot need to perform any padding here - experience_list = [] - for i, (tokens, gen_mask, log_probs) in enumerate( - zip(batch_tokens, batch_generation_masks, batch_log_probs_tensor) - ): - assert tokens.shape == log_probs.shape - # set prompt length to the first 1 in the gen_mask - prompt_length = torch.where(gen_mask == 1)[0][0].item() - experience = Experience( - run_id=new_uid, - tokens=tokens, - prompt_length=prompt_length, - action_mask=gen_mask, - reward=all_rewards[i], - logprobs=log_probs, - info=all_infos[i], + experience = self.process_messages_to_experience( + memory, + final_reward, + {"session_id": session_id, "env_rounds": r, "env_done": 1 if done else 0}, ) experience_list.append(experience) + # Close the env to save cpu memory + env.close() return experience_list def run(self) -> List[Experience]: diff --git a/trinity/common/workflows/workflow.py b/trinity/common/workflows/workflow.py index f60412df13..1e82b2dfd3 100644 --- a/trinity/common/workflows/workflow.py +++ b/trinity/common/workflows/workflow.py @@ -4,6 +4,8 @@ from abc import ABC, abstractmethod from typing import List +import torch + from trinity.common.experience import Experience from trinity.common.models.model import ModelWrapper from trinity.common.rewards.reward_fn import MathRewardFn @@ -30,6 +32,48 @@ def run(self) -> List[Experience]: """Run workflow and return a list of experiences.""" +class MultiTurnWorkflow(Workflow): + """ + The base workflow class for multi-turn tasks. + """ + + def __init__(self, model: ModelWrapper, **kwargs): + super().__init__(model) + + @abstractmethod + def run(self) -> List[Experience]: + """Run workflow and return a list of experiences.""" + + def process_messages_to_experience(self, messages, reward, info={}) -> Experience: + converted_experience = self.model.convert_messages_to_experience(messages) + + tokens = converted_experience.tokens + log_probs = converted_experience.logprobs + assert converted_experience.action_mask is not None + generation_mask = converted_experience.action_mask + log_probs = log_probs * generation_mask + + assert tokens.shape == log_probs.shape + # set prompt length to the first 1 in the gen_mask + prompt_length = torch.where(generation_mask == 1)[0][0].item() + + metrics = {} + for k, v in info.items(): + if isinstance(v, float) or isinstance(v, int): + metrics[k] = float(v) + + experience = Experience( + tokens=tokens, + prompt_length=prompt_length, + action_mask=generation_mask, + reward=reward, + logprobs=log_probs, + info=info, + metrics=metrics, + ) + return experience + + @WORKFLOWS.register_module("simple_workflow") class SimpleWorkflow(Workflow): """A workflow for simple single-round task."""
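As a quick illustration of the `prompt_length` and log-prob masking logic in `process_messages_to_experience` above (a standalone sketch with made-up tensors, not code from this PR):

```python
import torch

# Toy action mask: 0 marks prompt tokens, 1 marks model-generated tokens.
action_mask = torch.tensor([0, 0, 0, 1, 1, 1, 0, 1])
logprobs = torch.randn(8)

# Mirrors MultiTurnWorkflow.process_messages_to_experience:
# the prompt length is the index of the first generated token,
# and log-probs outside the generation mask are zeroed out.
prompt_length = torch.where(action_mask == 1)[0][0].item()  # -> 3
masked_logprobs = logprobs * action_mask

print(prompt_length, masked_logprobs)
```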