4 changes: 2 additions & 2 deletions scripts/config/alfworld.yaml
@@ -3,8 +3,8 @@ data:
   batch_size: 4
   dataset_path: 'scripts/data_prepare/alfworld_data'
   default_workflow_type: 'alfworld_workflow'
-  dataset_config:
-    split: 'train'
+  train_split: 'train'
+  eval_split: ''
   format_config:
     prompt_key: 'game_file'
 model:
4 changes: 2 additions & 2 deletions scripts/config/countdown.yaml
@@ -3,8 +3,8 @@ data:
   batch_size: 96
   dataset_path: 'countdown_dataset/oneshot-split'
   default_workflow_type: 'math_workflow'
-  dataset_config:
-    split: 'train'
+  train_split: 'train'
+  eval_split: ''
   default_reward_fn_type: 'countdown_reward'
   format_config:
     prompt_key: 'question'
5 changes: 3 additions & 2 deletions scripts/config/gsm8k.yaml
@@ -1,8 +1,8 @@
 data:
   # basic info
   dataset_path: '/PATH/TO/DATASET/'
-  dataset_config:
-    split: 'train'
+  train_split: 'train'
+  eval_split: ''
   format_config:
     prompt_key: 'question'
     response_key: 'answer'
@@ -70,6 +70,7 @@ trainer:
   algorithm_type: ppo
   trainer_config_path: 'scripts/config/train_gsm8k.yaml'
   sft_warmup_iteration: 0  # Set to integer to enable sft warmup
+  eval_interval: 10
 monitor:
   cache_root_dir: ""
   project: "Trinity-RFT-gsm8k"
4 changes: 2 additions & 2 deletions scripts/config/webshop.yaml
@@ -3,8 +3,8 @@ data:
   batch_size: 4
   dataset_path: 'scripts/data_prepare/webshop_data'
   default_workflow_type: 'webshop_workflow'
-  dataset_config:
-    split: 'train'
+  train_split: 'train'
+  eval_split: ''
   format_config:
     prompt_key: 'task_id'
 model:
3 changes: 2 additions & 1 deletion tests/common/tmp/template_config.yaml
@@ -3,7 +3,8 @@ data:
   dataset_path: ''
   total_epoch: 1
   batch_size: 1
-  split: train
+  train_split: 'train'
+  eval_split: ''
   default_workflow_type: ''
   default_reward_fn_type: ''
   dataset_config: {}
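Across these config files, the single `dataset_config.split` key is replaced by top-level `train_split` and `eval_split` keys, with an empty string presumably meaning that no eval split is configured. A minimal sketch of how such a config could be read with a generic YAML loader (this uses plain `yaml.safe_load`, not the project's config parser):

```python
import yaml

cfg = yaml.safe_load("""
data:
  dataset_path: '/PATH/TO/DATASET/'
  train_split: 'train'
  eval_split: ''
""")

train_split = cfg["data"]["train_split"]
eval_split = cfg["data"]["eval_split"] or None  # '' treated as "no eval split"
print(train_split, eval_split)  # train None
```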
31 changes: 17 additions & 14 deletions trinity/cli/launcher.py
@@ -19,8 +19,7 @@ def explore(config: Config) -> None:
     try:
         ray.get(explorer.prepare.remote())
         ray.get(explorer.sync_weight.remote())
-        ref, _ = ray.wait([explorer.explore.remote()])
-        ray.get(ref)
+        ray.get([explorer.explore.remote()])
         logger.info("Explore finished.")
     except Exception as e:
         logger.error(f"Explore failed: {e}")
@@ -34,8 +33,7 @@ def train(config: Config) -> None:
     trainer = Trainer.remote(config)
     try:
         ray.get(trainer.prepare.remote())
-        ref, _ = ray.wait([trainer.train.remote(algo_type)])
-        ray.get(ref)
+        ray.get([trainer.train.remote(algo_type)])
         logger.info("Train finished.")
     except Exception as e:
         logger.error(f"Train failed {e}.")
@@ -67,20 +65,21 @@ def both(config: Config) -> None:

     if config.trainer.sft_warmup_iteration > 0:
         for step in range(config.trainer.sft_warmup_iteration):
-            ray.get([trainer.train_step.remote(AlgorithmType.SFT)])
+            ray.get(trainer.train_step.remote(AlgorithmType.SFT))
             logger.info(f"SFT warmup step {step} finished.")
         ray.get([explorer.sync_weight.remote(), trainer.sync_weight.remote()])

     algo_type = config.trainer.algorithm_type
-    global_iter_num = 0
     while True:
         try:
-            explore_continue = explorer.explore_step.remote()
-            train_continue = trainer.train_step.remote(algo_type)
-            if not ray.get(explore_continue):
+            ref = explorer.explore_step.remote()
+            explore_continue, explore_iter_num = ray.get(ref)
+            ref = trainer.train_step.remote(algo_type)
+            train_continue, train_iter_num = ray.get(ref)
+            if not explore_continue:
                 logger.info("Explorer finished, stopping...")
                 break
-            if not ray.get(train_continue):
+            if not train_continue:
                 logger.info("Trainer finished, stopping...")
                 break
             ray.get([explorer.sync_weight.remote(), trainer.sync_weight.remote()])
@@ -89,10 +88,14 @@
             logger.error(e)
             logger.error("Training stopped due to exception.")
             raise e
-        global_iter_num += 1
-        if global_iter_num % config.trainer.eval_interval == 0:
-            ray.wait([explorer.eval.remote()])
-            logger.info("Eval step finished.")
+        if (train_iter_num - 1) % config.trainer.eval_interval == 0:
+            try:
+                ray.get(explorer.eval.remote(train_iter_num))
+                logger.info("Evaluation finished.")
+            except Exception as e:
+                logger.error(e)
+                logger.error("Evaluation failed.")
+                raise e


 def main() -> None:
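The driver loop in `both()` now consumes `(should_continue, iteration)` tuples from both actors and triggers evaluation off the trainer's iteration counter instead of a separate `global_iter_num`. A minimal sketch of that control flow with the Ray actors stubbed out (the stub classes and interval values are illustrative, not part of the PR):

```python
# Sketch of the new both() loop shape; only the control flow mirrors the PR,
# the stub classes and numbers are hypothetical.
class FakeExplorer:
    def __init__(self):
        self.iteration = 0

    def explore_step(self):
        self.iteration += 1
        return self.iteration < 5, self.iteration  # (explore_continue, explore_iter_num)

    def eval(self, step):
        print(f"eval logged at trainer step {step}")


class FakeTrainer:
    def __init__(self):
        self.iteration = 0

    def train_step(self):
        self.iteration += 1
        return self.iteration < 5, self.iteration  # (train_continue, train_iter_num)


eval_interval = 2
explorer, trainer = FakeExplorer(), FakeTrainer()
while True:
    explore_continue, _ = explorer.explore_step()
    train_continue, train_iter_num = trainer.train_step()
    if not explore_continue or not train_continue:
        break
    # Evaluation is keyed off the trainer's own iteration count; with an
    # interval of 2 this fires at train_iter_num = 1, 3, ...
    if (train_iter_num - 1) % eval_interval == 0:
        explorer.eval(train_iter_num)
```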
10 changes: 10 additions & 0 deletions trinity/common/config.py
@@ -306,6 +306,16 @@ def check_and_update(self) -> None:
             self.synchronizer.backend = self.explorer.backend
         if self.synchronizer.sync_method == "online" and self.mode != "both":
             raise ValueError("Online synchronization is only supported in both mode")
+
+        # check eval_interval
+        if self.trainer.eval_interval % self.synchronizer.sync_iteration_interval != 0:
+            self.trainer.eval_interval = (
+                self.trainer.eval_interval // self.synchronizer.sync_iteration_interval
+            ) * self.synchronizer.sync_iteration_interval
+            print(
+                f"Warning: eval_interval is not a multiple of sync_iteration_interval; adjusted down to {self.trainer.eval_interval}."
+            )
+
         # check monitor
         if not self.monitor.cache_root_dir:
             # create a cache dir in <checkpoint_path>/.cache
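The new check floors `eval_interval` down to the nearest multiple of `sync_iteration_interval`, so evaluation only fires on synchronization boundaries. The same arithmetic worked through with made-up values:

```python
# Same floor-to-multiple arithmetic as the check in check_and_update();
# the concrete values are hypothetical.
sync_iteration_interval = 4
eval_interval = 10

if eval_interval % sync_iteration_interval != 0:
    eval_interval = (eval_interval // sync_iteration_interval) * sync_iteration_interval

print(eval_interval)  # 8 -> evaluation now lines up with every 2nd sync point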
9 changes: 5 additions & 4 deletions trinity/common/workflows/workflow.py
@@ -58,7 +58,8 @@ def run(self) -> List[Experience]:
         else:
             messages = [{"role": "user", "content": self.task_desc}]
         logger.debug("start chat")
-        responses = self.model.chat(messages, n=self.repeat_times)
+        n = 1 if self.is_eval else self.repeat_times
+        responses = self.model.chat(messages, n=n)
         for response in responses:
             reward = self.reward_fn(  # type: ignore [misc]
                 response=response.response_text,  # type: ignore [arg-type]
@@ -69,9 +70,9 @@ def run(self) -> List[Experience]:
                 f"self.task_desc: {self.task_desc}, messages: {messages}, response: {response.response_text}, reward: {reward}"
             )
             if isinstance(reward, dict):
-                if response.info is None:
-                    response.info = {}
-                response.info.update(reward)
+                if response.metrics is None:
+                    response.metrics = {}
+                response.metrics.update(reward)
                 reward = sum(reward.values())
             response.reward = reward
         return responses
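Two behaviors are bundled here: evaluation runs sample a single response (`n=1`) instead of `repeat_times`, and a dict-valued reward is recorded per key on `response.metrics` before being summed into the scalar reward. A rough sketch of that reward handling, using a stand-in response object rather than the project's real response type:

```python
from dataclasses import dataclass
from typing import Dict, Optional, Union


@dataclass
class FakeResponse:  # stand-in for the real response object
    response_text: str
    metrics: Optional[Dict[str, float]] = None
    reward: float = 0.0


def apply_reward(response: FakeResponse, reward: Union[float, Dict[str, float]]) -> None:
    # Mirrors the PR's handling: dict rewards are kept per key as metrics,
    # then collapsed into a single scalar via sum().
    if isinstance(reward, dict):
        if response.metrics is None:
            response.metrics = {}
        response.metrics.update(reward)
        reward = sum(reward.values())
    response.reward = reward


r = FakeResponse("42")
apply_reward(r, {"accuracy": 1.0, "format": 0.5})
print(r.metrics, r.reward)  # {'accuracy': 1.0, 'format': 0.5} 1.5
```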
24 changes: 13 additions & 11 deletions trinity/explorer/explorer.py
@@ -3,7 +3,7 @@
 import os
 import time
 from collections import defaultdict
-from typing import List, Optional
+from typing import List, Optional, Tuple

 import ray
 import torch
@@ -149,16 +149,20 @@ def get_weight(self, name: str) -> torch.Tensor:

     def explore(self) -> None:
         """Explore the entire dataset."""
-        while self.explore_step():
+        explore_status, _ = self.explore_step()
+        while explore_status:
             self.sync_weight()
         self.logger.info("Explorer finished.")

-    def explore_step(self) -> bool:
+    def explore_step(self) -> Tuple[bool, int]:
         """Explore for one step.

         Different from `explore()` which consumes all tasks in the task set,
         `explore_step()` only consume `sync_iteration_interval * batch_size`
         number of tasks.
+        Returns:
+            explore_status: whether there are more tasks to explore.
+            explore_iter_num: the number of explore iterations.
         """
         if self.task_iter is None:
             self.task_iter = iter(self.taskset)
@@ -175,7 +179,7 @@ def explore_step(self) -> bool:
             self.runner_pool.run_tasks(tasks)
         except StopIteration:
             self.logger.warning("No more tasks in the task set. Stop exploring.")
-            return False
+            return False, self.iteration

         # wait for all tasks of this step to finish
         while self.runner_pool.has_next():
@@ -190,7 +194,7 @@ def explore_step(self) -> bool:
                 self.runner_pool.run_tasks(next(self.task_iter))  # type: ignore
             except StopIteration:
                 self.logger.warning("No more tasks in the task set. Stop exploring.")
                 return False, self.iteration
-                return False
             else:
                 for metric_name, metric_value in status.metric.items():
                     all_metrics[metric_name].append(metric_value)
@@ -208,11 +212,11 @@
         )

         self.logger.info("Explore step finished.")
-        return True
+        return True, self.iteration

-    def eval(self) -> bool:
+    def eval(self, step) -> bool:
         """Evaluation on all evaluation data samples."""
-        self.logger.info("\n\nEvaluation started.\n\n")
+        self.logger.info("Evaluation started.")
         st = time.time()
         all_metrics = defaultdict(list)

@@ -231,11 +235,9 @@ def eval(self) -> bool:
             for metric_name, metric_value in status.metric.items():
                 all_metrics[metric_name].append(metric_value)

-        self.logger.info("Evaluation finished.")
-
         log_metrics = self.monitor.calculate_metrics(all_metrics, prefix="eval")  # type: ignore
         log_metrics["eval/total_time"] = time.time() - st
-        self.monitor.log(log_metrics, step=self.iteration)  # type: ignore
+        self.monitor.log(log_metrics, step=step)  # type: ignore
         return True

     def sync_weight(self) -> None:
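With `eval()` taking the trainer's iteration number, evaluation metrics land on the same x-axis as training metrics in the monitor. A small sketch of that alignment; the `Monitor` class below is a stand-in, not the project's monitor API:

```python
# Sketch of why eval() now receives `step`: train and eval metrics are
# logged against the same trainer iteration. Monitor here is hypothetical.
class Monitor:
    def log(self, metrics: dict, step: int) -> None:
        print(f"step={step}: {metrics}")


monitor = Monitor()
train_iter_num = 12
monitor.log({"train/reward": 0.7}, step=train_iter_num)
# Previously eval used the explorer's own counter; now it reuses the same step:
monitor.log({"eval/accuracy": 0.55}, step=train_iter_num)
```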
26 changes: 16 additions & 10 deletions trinity/trainer/trainer.py
@@ -7,6 +7,7 @@
 Note that we don't combine the main with ray_trainer as ray_trainer is used by other main.
 """
 from abc import ABC, abstractmethod
+from typing import Tuple

 import ray

@@ -45,18 +46,23 @@ def prepare(self) -> None:
     def train(self, algo_type: AlgorithmType = AlgorithmType.PPO):
         """Train the model."""
         while True:
-            if not self.train_iteration(algo_type):
+            train_status, _ = self.train_iteration(algo_type)
+            if not train_status:
                 break

-    def train_step(self, algo_type: AlgorithmType = AlgorithmType.PPO) -> bool:
-        """Train one step. Each step contains `sync_iteration_interval` iteration."""
+    def train_step(self, algo_type: AlgorithmType = AlgorithmType.PPO) -> Tuple[bool, int]:
+        """Train one step. Each step contains `sync_iteration_interval` iterations.
+        Returns:
+            train_status: Whether to continue training.
+            train_iter_num: The number of training iterations."""
         for _ in range(self.config.synchronizer.sync_iteration_interval):
-            if not self.train_iteration(algo_type):
-                return False
+            train_status, train_iter_num = self.train_iteration(algo_type)
+            if not train_status:
+                return False, train_iter_num
         self.logger.info("Trainer finished.")
-        return True
+        return True, train_iter_num

-    def train_iteration(self, algo_type: AlgorithmType = AlgorithmType.PPO) -> bool:
+    def train_iteration(self, algo_type: AlgorithmType = AlgorithmType.PPO) -> Tuple[bool, int]:
         """Train one iteration.

         Args:
@@ -108,15 +114,15 @@ def prepare(self) -> None:
         """Do some preparation before training started."""

     @abstractmethod
-    def train_rft_iteration(self, experiences) -> bool:
+    def train_rft_iteration(self, experiences) -> Tuple[bool, int]:
         """Train on the RFT data."""

     @abstractmethod
-    def train_sft_iteration(self, experiences) -> bool:
+    def train_sft_iteration(self, experiences) -> Tuple[bool, int]:
         """Train on the SFT data."""

     @abstractmethod
-    def train_dpo_iteration(self, experiences) -> bool:
+    def train_dpo_iteration(self, experiences) -> Tuple[bool, int]:
         """Train on the DPO data."""

     @abstractmethod
15 changes: 8 additions & 7 deletions trinity/trainer/verl_trainer.py
@@ -4,6 +4,7 @@
 Modified from verl/trainer/ppo/ray_trainer.py
 """
 import os
+from typing import Tuple

 import pandas as pd
 import ray
@@ -182,7 +183,7 @@ def _create_dataloader(self):
         # else:
         self.total_training_steps = float("inf")

-    def train_dpo_iteration(self, experiences: Experiences) -> bool:
+    def train_dpo_iteration(self, experiences: Experiences) -> Tuple[bool, int]:
         metrics = {}
         timing_raw = {}

@@ -243,9 +244,9 @@ def train_dpo_iteration(self, experiences: Experiences) -> bool:
                 self._save_checkpoint()

         self.global_steps += 1
-        return True
+        return True, self.global_steps

-    def train_sft_iteration(self, experiences: Experiences) -> bool:
+    def train_sft_iteration(self, experiences: Experiences) -> Tuple[bool, int]:
         metrics = {}
         timing_raw = {}

@@ -309,9 +310,9 @@ def train_sft_iteration(self, experiences: Experiences) -> bool:
         with _timer("save_checkpoint", timing_raw):
             self._save_checkpoint()
         self.global_steps += 1
-        return True
+        return True, self.global_steps

-    def train_rft_iteration(self, experiences: Experiences) -> bool:
+    def train_rft_iteration(self, experiences: Experiences) -> Tuple[bool, int]:
         metrics = {}
         timing_raw = {}

@@ -456,10 +457,10 @@ def train_rft_iteration(self, experiences: Experiences) -> bool:
             with _timer("save_checkpoint", timing_raw):
                 self._save_checkpoint()
             # stop training
-            return False
+            return False, self.global_steps
         else:
             # continue
-            return True
+            return True, self.global_steps

     def _log_single_experience(
         self, experiences: Experiences, idx: int, skip_special_tokens: bool