2 changes: 2 additions & 0 deletions docs/sphinx_doc/source/tutorial/trinity_configs.md
@@ -10,12 +10,14 @@ global_config:
total_epochs: 1
batch_size: 96
eval_interval: 1000
eval_on_latest_ckp: true
```

- `mode`: The mode of the experiment, chosen from `both`, `train`, `explore`, or `bench`. `both` launches both the trainer and the explorer; `train` launches only the trainer; `explore` launches only the explorer; `bench` conducts benchmark evaluation. Default is `both`.
- `global_config.total_epochs`: The total number of training epochs. Check this value manually for each experiment.
- `global_config.batch_size`: The batch size used for training. Check this value manually for each experiment.
- `global_config.eval_interval`: The number of steps between two evaluations. Default is `1000`.
- `global_config.eval_on_latest_ckp`: Whether to evaluate only the latest checkpoint (`true`) or every checkpoint under the checkpoint path (`false`). Only valid in `bench` mode. Default is `true`. A configuration sketch follows this list.
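
As a minimal sketch, a `bench`-mode configuration that evaluates every saved checkpoint rather than only the latest might look like this (the values shown are illustrative, not prescribed defaults):

```yaml
mode: bench
global_config:
  total_epochs: 1
  batch_size: 96
  eval_interval: 1000
  # false: benchmark every checkpoint under the checkpoint path
  eval_on_latest_ckp: false
```

With this setting, the `bench` entry point calls `Explorer.benchmark()`, which scans the checkpoint path for `global_step_*` directories and evaluates each one in step order.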


## Monitor
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -20,7 +20,7 @@ classifiers = [
requires-python = ">=3.10"
dependencies = [
"verl==0.3.0.post1",
"ray[default]==2.43.0",
"ray[default]>=2.45.0",
"vllm>=0.8.5",
"tensordict==0.6.2",
"wandb",
8 changes: 5 additions & 3 deletions tests/tools.py
@@ -36,7 +36,7 @@ def get_unittest_dataset_config(
dataset_name: str = "countdown", split: str = "train"
) -> StorageConfig:
"""Countdown sample dataset for 8 steps"""
if dataset_name == "countdown":
if dataset_name == "countdown" or dataset_name == "copy_countdown":
return StorageConfig(
name=dataset_name,
path=os.path.join(os.path.dirname(__file__), "template", "data", "countdown"),
@@ -86,10 +86,12 @@ def metric_exist(self, metric_name: str) -> bool:
return metric_name in self._metrics

def metric_max_step(self, metric_name: str) -> int:
return max(self.metric_steps(metric_name))

def metric_steps(self, metric_name: str) -> List[int]:
if not self.metric_exist(metric_name):
raise ValueError(f"Metric '{metric_name}' does not exist.")
steps = list(self._metrics[metric_name].keys())
return max(steps)
return list(self._metrics[metric_name].keys())

def metric_list(self, metric_prefix: str) -> List[str]:
return [name for name in self._metrics if name.startswith(metric_prefix)]
50 changes: 41 additions & 9 deletions tests/trainer/trainer_test.py
@@ -14,7 +14,7 @@
get_template_config,
get_unittest_dataset_config,
)
from trinity.cli.launcher import both
from trinity.cli.launcher import bench, both
from trinity.common.constants import MonitorType, SyncMethod


@@ -27,9 +27,11 @@ def setUp(self):
self.config.model.model_path = get_model_path()
self.config.explorer.engine_type = "vllm_async"
self.config.explorer.repeat_times = 3
self.config.explorer.use_v1 = False
self.config.monitor.name = f"trainer-{datetime.now().strftime('%Y%m%d%H%M%S')}"
self.config.monitor.monitor_type = MonitorType.TENSORBOARD
self.config.model.checkpoint_path = os.path.join(
get_checkpoint_path(), f"train-{datetime.now().strftime('%Y%m%d%H%M%S')}"
get_checkpoint_path(), f"trainer-{datetime.now().strftime('%Y%m%d%H%M%S')}"
)
self.config.synchronizer.sync_interval = 2
self.config.synchronizer.sync_method = SyncMethod.NCCL
@@ -42,15 +44,20 @@ def test_trainer(self):

class TestTrainerCountdown(BaseTrainerCase):
def test_trainer(self):
"""Test the trainer."""
"""Test the both and bench mode."""
# test both mode
self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown")
self.config.buffer.explorer_input.eval_tasksets.append(
get_unittest_dataset_config("countdown", "test")
)
self.config.buffer.explorer_input.eval_tasksets.append(
get_unittest_dataset_config("copy_countdown", "test")
)
self.config.trainer.save_interval = 4
self.config.check_and_update()
self.config.trainer.trainer_config.trainer.save_freq = 8
self.config.trainer.trainer_config.trainer.max_actor_ckpt_to_keep = 2
self.config.trainer.trainer_config.trainer.max_critic_ckpt_to_keep = 2
both(self.config)
# check tensorboard
parser = TensorBoardParser(os.path.join(self.config.monitor.job_dir, "tensorboard"))
rollout_metrics = parser.metric_list("rollout")
self.assertTrue(len(rollout_metrics) > 0)
@@ -64,16 +71,41 @@ def test_trainer(self):
response_metrics = parser.metric_list("response_length")
self.assertTrue(len(response_metrics) > 0)
self.assertEqual(parser.metric_max_step(response_metrics[0]), 8)
ray.shutdown(_exiting_interpreter=True)
# check checkpoint
from trinity.common.models.utils import get_checkpoint_dir_with_step_num

checkpoint_dir = get_checkpoint_dir_with_step_num(
checkpoint_step_4 = get_checkpoint_dir_with_step_num(
checkpoint_root_path=self.config.model.checkpoint_path,
trainer_type=self.config.trainer.trainer_type,
step_num=4,
)
checkpoint_step_8 = get_checkpoint_dir_with_step_num(
checkpoint_root_path=self.config.model.checkpoint_path,
trainer_type=self.config.trainer.trainer_type,
step_num=None,
step_num=8,
)
self.assertTrue(os.path.exists(checkpoint_dir))
self.assertTrue(checkpoint_dir.endswith("step_8"))
self.assertTrue(os.path.exists(checkpoint_step_4))
self.assertTrue(os.path.exists(checkpoint_step_8))

ray.init(ignore_reinit_error=True)
# test bench mode
self.config.mode = "bench"
self.config.synchronizer.sync_method = SyncMethod.CHECKPOINT
self.config.global_config.eval_on_latest_ckp = False
self.config.check_and_update()
bench(self.config)
parser = TensorBoardParser(os.path.join(self.config.monitor.job_dir, "tensorboard"))
countdown_metrics = parser.metric_list("eval/countdown")
copy_countdown_metrics = parser.metric_list("eval/copy_countdown")
self.assertTrue(len(countdown_metrics) > 0)
self.assertTrue(len(copy_countdown_metrics) > 0)
countdown_metric_steps = parser.metric_steps(countdown_metrics[0])
countdown_copy_metric_steps = parser.metric_steps(copy_countdown_metrics[0])
self.assertEqual(2, len(countdown_metric_steps))
self.assertEqual(2, len(countdown_copy_metric_steps))
self.assertTrue(4 in countdown_metric_steps)
self.assertTrue(8 in countdown_metric_steps)

def tearDown(self):
# remove dir only when the test passed
14 changes: 9 additions & 5 deletions trinity/cli/launcher.py
@@ -18,12 +18,11 @@ def bench(config: Config) -> None:
explorer = Explorer.remote(config)
try:
ray.get(explorer.prepare.remote())
ray.get(explorer.sync_weight.remote())
_, step = ray.get(explorer.eval.remote())
logger.info("Evaluation finished.")
ray.get(explorer.flush_log.remote(step=step))
ray.get(explorer.benchmark.remote())
logger.info("Benchmark finished.")
ray.get(explorer.shutdown.remote())
except Exception as e:
logger.error(f"Evaluation failed: {e}")
logger.error(f"Benchmark failed: {e}")
raise e


@@ -35,6 +34,7 @@ def explore(config: Config) -> None:
ray.get(explorer.sync_weight.remote())
ray.get(explorer.explore.remote())
logger.info("Explore finished.")
ray.get(explorer.shutdown.remote())
except Exception as e:
logger.error(f"Explore failed: {e}")
raise e
@@ -60,6 +60,7 @@ def train(config: Config) -> None:
try:
ray.get(trainer.train.remote(algo_type))
logger.info("Train finished.")
ray.get(trainer.shutdown.remote())
except Exception as e:
logger.error(f"Train failed {e}.")
raise e
@@ -133,6 +134,9 @@ def both(config: Config) -> None:
ray.get(explorer.flush_log.remote(step=explore_step_num))
ray.get(trainer.flush_log.remote(step=train_step_num))

ray.get(explorer.shutdown.remote())
ray.get(trainer.shutdown.remote())


def activate_data_module(data_workflow_url: str, config_path: str):
"""Check whether to activate data module and preprocess datasets."""
16 changes: 10 additions & 6 deletions trinity/common/config.py
@@ -107,6 +107,7 @@ class GlobalConfig:
total_epochs: int = 1
batch_size: int = 1
eval_interval: int = 100
eval_on_latest_ckp: bool = True


@dataclass
@@ -299,7 +300,8 @@ def _check_interval(self) -> None:

# check eval_interval
if (
self.trainer.algorithm_type != AlgorithmType.DPO
self.mode != "bench"
and self.trainer.algorithm_type != AlgorithmType.DPO
and self.global_config.eval_interval % self.synchronizer.sync_interval != 0
):
self.global_config.eval_interval = (
@@ -311,12 +313,13 @@

# check save_interval
if (
self.trainer.algorithm_type != AlgorithmType.DPO
self.mode != "bench"
and self.trainer.algorithm_type != AlgorithmType.DPO
and self.synchronizer.sync_method == SyncMethod.CHECKPOINT
):
if self.trainer.save_interval != self.synchronizer.sync_interval:
logger.warning(
f"When `trainer.algorithm_type != DPO` and `synchronizer.sync_method == checkpoint`, "
f"When `trainer.algorithm_type` != `DPO` and `synchronizer.sync_method` == `checkpoint`, "
f"`trainer.save_interval` will be set to "
f"`synchronizer.sync_interval = {self.synchronizer.sync_interval}`."
)
@@ -356,7 +359,7 @@ def _check_buffer(self) -> None: # noqa: C901
logger.info(
f"Auto set `buffer.trainer_input.experience_buffer` to {self.buffer.trainer_input.experience_buffer}"
)
else: # TODO: to be check
elif self.mode == "train": # TODO: to be check
if self.trainer.algorithm_type.is_dpo():
if (
self.buffer.trainer_input.experience_buffer is None
@@ -365,7 +368,8 @@
raise ValueError(
"`buffer.trainer_input.experience_buffer.path` is required when `trainer.algorithm_type == AlgorithmType.DPO`"
)
self.buffer.trainer_input.experience_buffer.algorithm_type = self.trainer.algorithm_type
if self.mode in ["both", "train"]:
self.buffer.trainer_input.experience_buffer.algorithm_type = self.trainer.algorithm_type

# set buffer.explorer_output
if self.buffer.explorer_output is None:
@@ -418,7 +422,7 @@ def check_and_update(self) -> None: # noqa: C901
)
self.synchronizer.backend = self.explorer.backend
if self.mode == "bench" and self.synchronizer.sync_method != SyncMethod.CHECKPOINT:
self.synchronizer.sync_method = "checkpoint"
self.synchronizer.sync_method = SyncMethod.CHECKPOINT
logger.warning(
"Bench mode only supports checkpoint synchronization, set `synchronizer.sync_method` to `checkpoint`."
)
35 changes: 31 additions & 4 deletions trinity/explorer/explorer.py
@@ -34,10 +34,11 @@ def __init__(self, config: Config):
self.step_num = explorer_meta.get("latest_iteration", 0)
self.config = config
self.models = create_rollout_models(config)
self.experience_buffer = get_buffer_writer(
self.config.buffer.explorer_output, # type: ignore
self.config.buffer,
)
if self.config.mode != "bench":
self.experience_buffer = get_buffer_writer(
self.config.buffer.explorer_output, # type: ignore
self.config.buffer,
)
self.config.buffer.explorer_input.taskset.index = explorer_meta.get("latest_task_index", 0)
self.taskset = get_buffer_reader(
self.config.buffer.explorer_input.taskset, self.config.buffer
@@ -261,6 +262,29 @@ def wait():
self.monitor.log(log_metrics, step=self.step_num) # type: ignore
return True, self.step_num

def benchmark(self) -> bool:
"""Benchmark the model checkpoints."""
# benchmark on the latest checkpoint
if self.config.global_config.eval_on_latest_ckp:
self._checkpoint_weights_update()
self.eval()
return True

# benchmark on all checkpoints
all_ckp_steps = sorted(
[
int(ckp.split("global_step_")[-1])
for ckp in os.listdir(self.config.model.checkpoint_path)
if os.path.isdir(os.path.join(self.config.model.checkpoint_path, ckp))
and ckp.startswith("global_step_")
]
)
for step_num in all_ckp_steps:
self.step_num = step_num
self._checkpoint_weights_update(step_num=step_num)
self.eval()
return True

def sync_weight(self) -> None:
"""Synchronize model weights."""
# call this method before training start to load the latest model weights
@@ -272,3 +296,6 @@ def sync_weight(self) -> None:
def flush_log(self, step: int) -> None:
"""Flush the log of the current step."""
self.monitor.log({}, step=step, commit=True)

def shutdown(self) -> None:
self.monitor.close()
11 changes: 10 additions & 1 deletion trinity/trainer/trainer.py
@@ -6,6 +6,7 @@

Note that we don't combine the main with ray_trainer, as ray_trainer is also used by other mains.
"""
import os
from abc import ABC, abstractmethod
from typing import Tuple

@@ -59,7 +60,7 @@ def train_one_period(self, algo_type: AlgorithmType = AlgorithmType.PPO) -> Tuple[bool, int]:
train_status, train_step_num = self.train_step(algo_type)
if not train_status:
return False, train_step_num
self.logger.info(f"Trainer steps {train_step_num} finished.")
self.logger.info(f"Train step {train_step_num} finished.")
return True, train_step_num

def train_step(self, algo_type: AlgorithmType = AlgorithmType.PPO) -> Tuple[bool, int]:
Expand Down Expand Up @@ -119,6 +120,14 @@ def flush_log(self, step: int) -> None:
"""Flush the log of the current step."""
self.engine.logger.log({}, step=step, commit=True)

def shutdown(self) -> None:
# if checkpoint not saved, save the last checkpoint
step_num = self.engine.global_steps - 1
path = os.path.join(self.config.model.checkpoint_path, f"global_step_{step_num}")
if not os.path.isdir(path) or len(os.listdir(path)) == 0:
self.engine.save_checkpoint()
self.engine.logger.close()


class TrainEngineWrapper(ABC):
"""A wrapper class to wrap various training engines."""
9 changes: 9 additions & 0 deletions trinity/utils/monitor.py
@@ -54,6 +54,9 @@ def log(self, data: dict, step: int, commit: bool = False) -> None:
"""Log metrics."""
self.logger.log(data, step=step, commit=commit)

def close(self) -> None:
self.logger.close()


class TensorboardLogger:
def __init__(self, project: str, name: str, role: str, config: Any = None) -> None:
@@ -70,6 +73,9 @@ def log(self, data: dict, step: int, commit: bool = False) -> None:
for key in data:
self.logger.add_scalar(key, data[key], step)

def close(self) -> None:
self.logger.close()

def __del__(self) -> None:
self.logger.close()

@@ -95,5 +101,8 @@ def log(self, data: dict, step: int, commit: bool = False) -> None:
self.logger.log(data, step=step, commit=commit)
self.console_logger.info(f"Step {step}: {data}")

def close(self) -> None:
self.logger.finish()

def __del__(self) -> None:
self.logger.finish()