4 changes: 2 additions & 2 deletions scripts/config/alfworld.yaml
@@ -3,8 +3,8 @@ data:
batch_size: 4
dataset_path: 'scripts/data_prepare/alfworld_data'
default_workflow_type: 'alfworld_workflow'
dataset_config:
split: 'train'
train_split: 'train'
eval_split: ''
format_config:
prompt_key: 'game_file'
model:
4 changes: 2 additions & 2 deletions scripts/config/countdown.yaml
@@ -3,8 +3,8 @@ data:
batch_size: 96
dataset_path: 'countdown_dataset/oneshot-split'
default_workflow_type: 'math_workflow'
dataset_config:
split: 'train'
train_split: 'train'
eval_split: ''
default_reward_fn_type: 'countdown_reward'
format_config:
prompt_key: 'question'
5 changes: 3 additions & 2 deletions scripts/config/gsm8k.yaml
@@ -1,8 +1,8 @@
data:
# basic info
dataset_path: '/PATH/TO/DATASET/'
dataset_config:
split: 'train'
train_split: 'train'
eval_split: ''
format_config:
prompt_key: 'question'
response_key: 'answer'
@@ -70,6 +70,7 @@ trainer:
algorithm_type: ppo
trainer_config_path: 'scripts/config/train_gsm8k.yaml'
sft_warmup_iteration: 0 # Set to integer to enable sft warmup
eval_interval: 50
monitor:
cache_root_dir: ""
project: "Trinity-RFT-gsm8k"
4 changes: 2 additions & 2 deletions scripts/config/webshop.yaml
@@ -3,8 +3,8 @@ data:
batch_size: 4
dataset_path: 'scripts/data_prepare/webshop_data'
default_workflow_type: 'webshop_workflow'
dataset_config:
split: 'train'
train_split: 'train'
eval_split: ''
format_config:
prompt_key: 'task_id'
model:
3 changes: 2 additions & 1 deletion tests/common/tmp/template_config.yaml
@@ -3,7 +3,8 @@ data:
dataset_path: ''
total_epoch: 1
batch_size: 1
split: train
train_split: 'train'
eval_split: ''
default_workflow_type: ''
default_reward_fn_type: ''
dataset_config: {}
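Across these config files, the nested `dataset_config.split` key is replaced by top-level `train_split` and `eval_split` keys, with an empty `eval_split` meaning no evaluation data. A minimal sketch of how such keys might be consumed when building datasets; the `load_dataset` call is the Hugging Face `datasets` API, the loader function itself is an assumption, not code from this PR:

```python
from datasets import load_dataset

def load_train_and_eval(data_cfg):
    # Load the training split named by the new `train_split` key.
    train_set = load_dataset(data_cfg.dataset_path, split=data_cfg.train_split)
    # An empty `eval_split` string means evaluation data is disabled.
    eval_set = None
    if data_cfg.eval_split:
        eval_set = load_dataset(data_cfg.dataset_path, split=data_cfg.eval_split)
    return train_set, eval_set
```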
31 changes: 17 additions & 14 deletions trinity/cli/launcher.py
@@ -19,8 +19,7 @@ def explore(config: Config) -> None:
try:
ray.get(explorer.prepare.remote())
ray.get(explorer.sync_weight.remote())
ref, _ = ray.wait([explorer.explore.remote()])
ray.get(ref)
ray.get(explorer.explore.remote())
logger.info("Explore finished.")
except Exception as e:
logger.error(f"Explore failed: {e}")
@@ -34,8 +33,7 @@ def train(config: Config) -> None:
trainer = Trainer.remote(config)
try:
ray.get(trainer.prepare.remote())
ref, _ = ray.wait([trainer.train.remote(algo_type)])
ray.get(ref)
ray.get(trainer.train.remote(algo_type))
logger.info("Train finished.")
except Exception as e:
logger.error(f"Train failed {e}.")
@@ -67,20 +65,21 @@ def both(config: Config) -> None:

if config.trainer.sft_warmup_iteration > 0:
for step in range(config.trainer.sft_warmup_iteration):
ray.get([trainer.train_step.remote(AlgorithmType.SFT)])
ray.get(trainer.train_step.remote(AlgorithmType.SFT))
logger.info(f"SFT warmup step {step} finished.")
ray.get([explorer.sync_weight.remote(), trainer.sync_weight.remote()])

algo_type = config.trainer.algorithm_type
global_iter_num = 0
while True:
try:
explore_continue = explorer.explore_step.remote()
train_continue = trainer.train_step.remote(algo_type)
if not ray.get(explore_continue):
ref_explore = explorer.explore_step.remote()
ref_train = trainer.train_step.remote(algo_type)
explore_continue, _ = ray.get(ref_explore)
train_continue, train_iter_num = ray.get(ref_train)
if not explore_continue:
logger.info("Explorer finished, stopping...")
break
if not ray.get(train_continue):
if not train_continue:
logger.info("Trainer finished, stopping...")
break
ray.get([explorer.sync_weight.remote(), trainer.sync_weight.remote()])
@@ -89,10 +88,14 @@ def both(config: Config) -> None:
logger.error(e)
logger.error("Training stopped due to exception.")
raise e
global_iter_num += 1
if global_iter_num % config.trainer.eval_interval == 0:
ray.wait([explorer.eval.remote()])
logger.info("Eval step finished.")
if (train_iter_num - 1) % config.trainer.eval_interval == 0:
try:
ray.get(explorer.eval.remote(train_iter_num))
logger.info("Evaluation finished.")
except Exception as e:
logger.error(e)
logger.error("Evaluation failed.")
raise e


def main() -> None:
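The `both` loop now keys evaluation off the trainer's iteration counter instead of a separate `global_iter_num`. A small sanity check of the `(train_iter_num - 1) % eval_interval == 0` trigger, assuming `train_iter_num` counts from 1 (illustrative values, not from this PR):

```python
# With eval_interval = 50, evaluation fires after training iterations 1, 51, 101, ...
eval_interval = 50
triggered = [i for i in range(1, 200) if (i - 1) % eval_interval == 0]
assert triggered == [1, 51, 101, 151]
```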
10 changes: 10 additions & 0 deletions trinity/common/config.py
@@ -306,6 +306,16 @@ def check_and_update(self) -> None:
self.synchronizer.backend = self.explorer.backend
if self.synchronizer.sync_method == "online" and self.mode != "both":
raise ValueError("Online synchronization is only supported in both mode")

# check eval_interval
if self.trainer.eval_interval % self.synchronizer.sync_iteration_interval != 0:
self.trainer.eval_interval = (
self.trainer.eval_interval // self.synchronizer.sync_iteration_interval
) * self.synchronizer.sync_iteration_interval
print(
f"Warning: eval_interval is not a multiple of sync_iteration_interval; adjusted to the nearest integer={self.trainer.eval_interval}."
)

# check monitor
if not self.monitor.cache_root_dir:
# create a cache dir in <checkpoint_path>/.cache
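The new check rounds `eval_interval` down to a multiple of `sync_iteration_interval` so that evaluation always lands on a synchronization boundary. A worked example of the arithmetic above, with illustrative values:

```python
# Illustrative values only; mirrors the adjustment in check_and_update().
eval_interval = 50
sync_iteration_interval = 4
adjusted = (eval_interval // sync_iteration_interval) * sync_iteration_interval
assert adjusted == 48  # largest multiple of 4 that does not exceed 50
```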
9 changes: 5 additions & 4 deletions trinity/common/workflows/workflow.py
@@ -58,7 +58,8 @@ def run(self) -> List[Experience]:
else:
messages = [{"role": "user", "content": self.task_desc}]
logger.debug("start chat")
responses = self.model.chat(messages, n=self.repeat_times)
n = 1 if self.is_eval else self.repeat_times
responses = self.model.chat(messages, n=n)
for response in responses:
reward = self.reward_fn( # type: ignore [misc]
response=response.response_text, # type: ignore [arg-type]
@@ -69,9 +70,9 @@
f"self.task_desc: {self.task_desc}, messages: {messages}, response: {response.response_text}, reward: {reward}"
)
if isinstance(reward, dict):
if response.info is None:
response.info = {}
response.info.update(reward)
if response.metrics is None:
response.metrics = {}
response.metrics.update(reward)
reward = sum(reward.values())
response.reward = reward
return responses
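With this change, a dict-valued reward is recorded per component in `response.metrics` and collapsed to a scalar via `sum(reward.values())`. A hedged sketch of a reward function that would exercise this path; the name, signature, and scoring are hypothetical, not taken from this PR:

```python
# Hypothetical reward_fn: returns per-component scores as a dict.
# The workflow above copies each component into response.metrics and uses
# the sum of the values as the scalar training reward.
def example_reward_fn(response: str, truth: str) -> dict:
    accuracy = 1.0 if response.strip() == truth.strip() else 0.0
    brevity = 0.1 if len(response) < 512 else 0.0
    return {"accuracy": accuracy, "brevity": brevity}
```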
24 changes: 13 additions & 11 deletions trinity/explorer/explorer.py
@@ -3,7 +3,7 @@
import os
import time
from collections import defaultdict
from typing import List, Optional
from typing import List, Optional, Tuple

import ray
import torch
@@ -149,16 +149,20 @@ def get_weight(self, name: str) -> torch.Tensor:

def explore(self) -> None:
"""Explore the entire dataset."""
while self.explore_step():
explore_status, _ = self.explore_step()
while explore_status:
self.sync_weight()
self.logger.info("Explorer finished.")

def explore_step(self) -> bool:
def explore_step(self) -> Tuple[bool, int]:
"""Explore for one step.

Different from `explore()` which consumes all tasks in the task set,
`explore_step()` only consumes `sync_iteration_interval * batch_size`
number of tasks.
Returns:
explore_status: whether there are more tasks to explore.
explore_iter_num: the number of explore iterations.
"""
if self.task_iter is None:
self.task_iter = iter(self.taskset)
@@ -175,7 +179,7 @@ def explore_step(self) -> bool:
self.runner_pool.run_tasks(tasks)
except StopIteration:
self.logger.warning("No more tasks in the task set. Stop exploring.")
return False
return False, self.iteration

# wait for all tasks of this step to finish
while self.runner_pool.has_next():
@@ -190,7 +194,7 @@ def explore_step(self) -> bool:
self.runner_pool.run_tasks(next(self.task_iter)) # type: ignore
except StopIteration:
self.logger.warning("No more tasks in the task set. Stop exploring.")
return False
return False, self.iteration
else:
for metric_name, metric_value in status.metric.items():
all_metrics[metric_name].append(metric_value)
@@ -208,11 +212,11 @@
)

self.logger.info("Explore step finished.")
return True
return True, self.iteration

def eval(self) -> bool:
def eval(self, step) -> bool:
"""Evaluation on all evaluation data samples."""
self.logger.info("\n\nEvaluation started.\n\n")
self.logger.info("Evaluation started.")
st = time.time()
all_metrics = defaultdict(list)

@@ -231,11 +235,9 @@ def eval(self) -> bool:
for metric_name, metric_value in status.metric.items():
all_metrics[metric_name].append(metric_value)

self.logger.info("Evaluation finished.")

log_metrics = self.monitor.calculate_metrics(all_metrics, prefix="eval") # type: ignore
log_metrics["eval/total_time"] = time.time() - st
self.monitor.log(log_metrics, step=self.iteration) # type: ignore
self.monitor.log(log_metrics, step=step) # type: ignore
return True

def sync_weight(self) -> None:
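`explore_step()` now returns both a continue flag and the current iteration number, and `eval()` takes the step to log against. A minimal driver-side sketch of consuming the tuple through Ray, assuming `explorer` is an `Explorer` actor handle as in `launcher.py`:

```python
import ray

# `explorer` is assumed to be an Explorer actor handle (see launcher.py).
explore_continue, iter_num = ray.get(explorer.explore_step.remote())
while explore_continue:
    ray.get(explorer.sync_weight.remote())
    explore_continue, iter_num = ray.get(explorer.explore_step.remote())
print(f"Exploration stopped at iteration {iter_num}")
```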
26 changes: 16 additions & 10 deletions trinity/trainer/trainer.py
@@ -7,6 +7,7 @@
Note that we don't combine the main with ray_trainer as ray_trainer is used by other main.
"""
from abc import ABC, abstractmethod
from typing import Tuple

import ray

@@ -45,18 +46,23 @@ def prepare(self) -> None:
def train(self, algo_type: AlgorithmType = AlgorithmType.PPO):
"""Train the model."""
while True:
if not self.train_iteration(algo_type):
train_status, _ = self.train_iteration(algo_type)
if not train_status:
break

def train_step(self, algo_type: AlgorithmType = AlgorithmType.PPO) -> bool:
"""Train one step. Each step contains `sync_iteration_interval` iteration."""
def train_step(self, algo_type: AlgorithmType = AlgorithmType.PPO) -> Tuple[bool, int]:
"""Train one step. Each step contains `sync_iteration_interval` iteration.
Returns:
train_status: Whether to continue training.
train_iter_num: The number of training iterations"""
for _ in range(self.config.synchronizer.sync_iteration_interval):
if not self.train_iteration(algo_type):
return False
train_status, train_iter_num = self.train_iteration(algo_type)
if not train_status:
return False, train_iter_num
self.logger.info("Trainer finished.")
return True
return True, train_iter_num

def train_iteration(self, algo_type: AlgorithmType = AlgorithmType.PPO) -> bool:
def train_iteration(self, algo_type: AlgorithmType = AlgorithmType.PPO) -> Tuple[bool, int]:
"""Train one iteration.

Args:
@@ -108,15 +114,15 @@ def prepare(self) -> None:
"""Do some preparation before training started."""

@abstractmethod
def train_rft_iteration(self, experiences) -> bool:
def train_rft_iteration(self, experiences) -> Tuple[bool, int]:
"""Train on the RFT data."""

@abstractmethod
def train_sft_iteration(self, experiences) -> bool:
def train_sft_iteration(self, experiences) -> Tuple[bool, int]:
"""Train on the SFT data."""

@abstractmethod
def train_dpo_iteration(self, experiences) -> bool:
def train_dpo_iteration(self, experiences) -> Tuple[bool, int]:
"""Train on the DPO data."""

@abstractmethod
15 changes: 8 additions & 7 deletions trinity/trainer/verl_trainer.py
@@ -4,6 +4,7 @@
Modified from verl/trainer/ppo/ray_trainer.py
"""
import os
from typing import Tuple

import pandas as pd
import ray
@@ -182,7 +183,7 @@ def _create_dataloader(self):
# else:
self.total_training_steps = float("inf")

def train_dpo_iteration(self, experiences: Experiences) -> bool:
def train_dpo_iteration(self, experiences: Experiences) -> Tuple[bool, int]:
metrics = {}
timing_raw = {}

@@ -243,9 +244,9 @@ def train_dpo_iteration(self, experiences: Experiences) -> bool:
self._save_checkpoint()

self.global_steps += 1
return True
return True, self.global_steps

def train_sft_iteration(self, experiences: Experiences) -> bool:
def train_sft_iteration(self, experiences: Experiences) -> Tuple[bool, int]:
metrics = {}
timing_raw = {}

@@ -309,9 +310,9 @@ def train_sft_iteration(self, experiences: Experiences) -> bool:
with _timer("save_checkpoint", timing_raw):
self._save_checkpoint()
self.global_steps += 1
return True
return True, self.global_steps

def train_rft_iteration(self, experiences: Experiences) -> bool:
def train_rft_iteration(self, experiences: Experiences) -> Tuple[bool, int]:
metrics = {}
timing_raw = {}

@@ -456,10 +457,10 @@ def train_rft_iteration(self, experiences: Experiences) -> bool:
with _timer("save_checkpoint", timing_raw):
self._save_checkpoint()
# stop training
return False
return False, self.global_steps
else:
# continue
return True
return True, self.global_steps

def _log_single_experience(
self, experiences: Experiences, idx: int, skip_special_tokens: bool