Skip to content

Commit 11a95be

Browse files
authored
Add benchmark mode (#39)
1 parent fd69ba4 commit 11a95be

File tree

9 files changed

+118
-29
lines changed

9 files changed

+118
-29
lines changed

docs/sphinx_doc/source/tutorial/trinity_configs.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,14 @@ global_config:
1010
total_epochs: 1
1111
batch_size: 96
1212
eval_interval: 1000
13+
eval_on_latest_ckp: true
1314
```
1415
1516
- `mode`: The mode of the experiment, chosen from `both`, `train`, `explore` or `bench`. `both` means both trainer and explorer are launched; `train` means only trainer is launched; `explore` means only explorer is launched; `bench` conducts benchmark evaluation. Default is `both`.
1617
- `global_config.total_epochs`: The total number of epochs. It should be checked manually.
1718
- `global_config.batch_size`: The batch size used for training. It should be checked manually.
1819
- `global_config.eval_interval`: The number of steps between two consecutive evaluations. Default is `100` (the example above sets it to `1000`).
20+
- `global_config.eval_on_latest_ckp`: Whether to evaluate only the latest checkpoint (`true`) or every checkpoint found under `model.checkpoint_path` (`false`). Only takes effect in `bench` mode. Default is `true`.
1921

2022

2123
## Monitor

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ classifiers = [
2020
requires-python = ">=3.10"
2121
dependencies = [
2222
"verl==0.3.0.post1",
23-
"ray[default]==2.43.0",
23+
"ray[default]>=2.45.0",
2424
"vllm>=0.8.5",
2525
"tensordict==0.6.2",
2626
"wandb",

tests/tools.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def get_unittest_dataset_config(
3636
dataset_name: str = "countdown", split: str = "train"
3737
) -> StorageConfig:
3838
"""Countdown sample dataset for 8 steps"""
39-
if dataset_name == "countdown":
39+
if dataset_name == "countdown" or dataset_name == "copy_countdown":
4040
return StorageConfig(
4141
name=dataset_name,
4242
path=os.path.join(os.path.dirname(__file__), "template", "data", "countdown"),
@@ -86,10 +86,12 @@ def metric_exist(self, metric_name: str) -> bool:
8686
return metric_name in self._metrics
8787

8888
def metric_max_step(self, metric_name: str) -> int:
89+
return max(self.metric_steps(metric_name))
90+
91+
def metric_steps(self, metric_name: str) -> List[int]:
8992
if not self.metric_exist(metric_name):
9093
raise ValueError(f"Metric '{metric_name}' does not exist.")
91-
steps = list(self._metrics[metric_name].keys())
92-
return max(steps)
94+
return list(self._metrics[metric_name].keys())
9395

9496
def metric_list(self, metric_prefix: str) -> List[str]:
9597
return [name for name in self._metrics if name.startswith(metric_prefix)]

tests/trainer/trainer_test.py

Lines changed: 41 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
get_template_config,
1515
get_unittest_dataset_config,
1616
)
17-
from trinity.cli.launcher import both
17+
from trinity.cli.launcher import bench, both
1818
from trinity.common.constants import MonitorType, SyncMethod
1919

2020

@@ -27,9 +27,11 @@ def setUp(self):
2727
self.config.model.model_path = get_model_path()
2828
self.config.explorer.engine_type = "vllm_async"
2929
self.config.explorer.repeat_times = 3
30+
self.config.explorer.use_v1 = False
31+
self.config.monitor.name = f"trainer-{datetime.now().strftime('%Y%m%d%H%M%S')}"
3032
self.config.monitor.monitor_type = MonitorType.TENSORBOARD
3133
self.config.model.checkpoint_path = os.path.join(
32-
get_checkpoint_path(), f"train-{datetime.now().strftime('%Y%m%d%H%M%S')}"
34+
get_checkpoint_path(), f"trainer-{datetime.now().strftime('%Y%m%d%H%M%S')}"
3335
)
3436
self.config.synchronizer.sync_interval = 2
3537
self.config.synchronizer.sync_method = SyncMethod.NCCL
@@ -42,15 +44,20 @@ def test_trainer(self):
4244

4345
class TestTrainerCountdown(BaseTrainerCase):
4446
def test_trainer(self):
45-
"""Test the trainer."""
47+
"""Test the both and bench mode."""
48+
# test both mode
4649
self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown")
4750
self.config.buffer.explorer_input.eval_tasksets.append(
4851
get_unittest_dataset_config("countdown", "test")
4952
)
53+
self.config.buffer.explorer_input.eval_tasksets.append(
54+
get_unittest_dataset_config("copy_countdown", "test")
55+
)
56+
self.config.trainer.save_interval = 4
5057
self.config.check_and_update()
51-
self.config.trainer.trainer_config.trainer.save_freq = 8
58+
self.config.trainer.trainer_config.trainer.max_actor_ckpt_to_keep = 2
59+
self.config.trainer.trainer_config.trainer.max_critic_ckpt_to_keep = 2
5260
both(self.config)
53-
# check tensorboard
5461
parser = TensorBoardParser(os.path.join(self.config.monitor.job_dir, "tensorboard"))
5562
rollout_metrics = parser.metric_list("rollout")
5663
self.assertTrue(len(rollout_metrics) > 0)
@@ -64,16 +71,41 @@ def test_trainer(self):
6471
response_metrics = parser.metric_list("response_length")
6572
self.assertTrue(len(response_metrics) > 0)
6673
self.assertEqual(parser.metric_max_step(response_metrics[0]), 8)
74+
ray.shutdown(_exiting_interpreter=True)
6775
# check checkpoint
6876
from trinity.common.models.utils import get_checkpoint_dir_with_step_num
6977

70-
checkpoint_dir = get_checkpoint_dir_with_step_num(
78+
checkpoint_step_4 = get_checkpoint_dir_with_step_num(
79+
checkpoint_root_path=self.config.model.checkpoint_path,
80+
trainer_type=self.config.trainer.trainer_type,
81+
step_num=4,
82+
)
83+
checkpoint_step_8 = get_checkpoint_dir_with_step_num(
7184
checkpoint_root_path=self.config.model.checkpoint_path,
7285
trainer_type=self.config.trainer.trainer_type,
73-
step_num=None,
86+
step_num=8,
7487
)
75-
self.assertTrue(os.path.exists(checkpoint_dir))
76-
self.assertTrue(checkpoint_dir.endswith("step_8"))
88+
self.assertTrue(os.path.exists(checkpoint_step_4))
89+
self.assertTrue(os.path.exists(checkpoint_step_8))
90+
91+
ray.init(ignore_reinit_error=True)
92+
# test bench mode
93+
self.config.mode = "bench"
94+
self.config.synchronizer.sync_method = SyncMethod.CHECKPOINT
95+
self.config.global_config.eval_on_latest_ckp = False
96+
self.config.check_and_update()
97+
bench(self.config)
98+
parser = TensorBoardParser(os.path.join(self.config.monitor.job_dir, "tensorboard"))
99+
countdown_metrics = parser.metric_list("eval/countdown")
100+
copy_countdown_metrics = parser.metric_list("eval/copy_countdown")
101+
self.assertTrue(len(countdown_metrics) > 0)
102+
self.assertTrue(len(copy_countdown_metrics) > 0)
103+
countdown_metric_steps = parser.metric_steps(countdown_metrics[0])
104+
countdown_copy_metric_steps = parser.metric_steps(copy_countdown_metrics[0])
105+
self.assertEqual(2, len(countdown_metric_steps))
106+
self.assertEqual(2, len(countdown_copy_metric_steps))
107+
self.assertTrue(4 in countdown_metric_steps)
108+
self.assertTrue(8 in countdown_metric_steps)
77109

78110
def tearDown(self):
79111
# remove dir only when the test passed

trinity/cli/launcher.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,11 @@ def bench(config: Config) -> None:
1818
explorer = Explorer.remote(config)
1919
try:
2020
ray.get(explorer.prepare.remote())
21-
ray.get(explorer.sync_weight.remote())
22-
_, step = ray.get(explorer.eval.remote())
23-
logger.info("Evaluation finished.")
24-
ray.get(explorer.flush_log.remote(step=step))
21+
ray.get(explorer.benchmark.remote())
22+
logger.info("Benchmark finished.")
23+
ray.get(explorer.shutdown.remote())
2524
except Exception as e:
26-
logger.error(f"Evaluation failed: {e}")
25+
logger.error(f"Benchmark failed: {e}")
2726
raise e
2827

2928

@@ -35,6 +34,7 @@ def explore(config: Config) -> None:
3534
ray.get(explorer.sync_weight.remote())
3635
ray.get(explorer.explore.remote())
3736
logger.info("Explore finished.")
37+
ray.get(explorer.shutdown.remote())
3838
except Exception as e:
3939
logger.error(f"Explore failed: {e}")
4040
raise e
@@ -60,6 +60,7 @@ def train(config: Config) -> None:
6060
try:
6161
ray.get(trainer.train.remote(algo_type))
6262
logger.info("Train finished.")
63+
ray.get(trainer.shutdown.remote())
6364
except Exception as e:
6465
logger.error(f"Train failed {e}.")
6566
raise e
@@ -133,6 +134,9 @@ def both(config: Config) -> None:
133134
ray.get(explorer.flush_log.remote(step=explore_step_num))
134135
ray.get(trainer.flush_log.remote(step=train_step_num))
135136

137+
ray.get(explorer.shutdown.remote())
138+
ray.get(trainer.shutdown.remote())
139+
136140

137141
def activate_data_module(data_workflow_url: str, config_path: str):
138142
"""Check whether to activate data module and preprocess datasets."""

trinity/common/config.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ class GlobalConfig:
107107
total_epochs: int = 1
108108
batch_size: int = 1
109109
eval_interval: int = 100
110+
eval_on_latest_ckp: bool = True
110111

111112

112113
@dataclass
@@ -299,7 +300,8 @@ def _check_interval(self) -> None:
299300

300301
# check eval_interval
301302
if (
302-
self.trainer.algorithm_type != AlgorithmType.DPO
303+
self.mode != "bench"
304+
and self.trainer.algorithm_type != AlgorithmType.DPO
303305
and self.global_config.eval_interval % self.synchronizer.sync_interval != 0
304306
):
305307
self.global_config.eval_interval = (
@@ -311,12 +313,13 @@ def _check_interval(self) -> None:
311313

312314
# check save_interval
313315
if (
314-
self.trainer.algorithm_type != AlgorithmType.DPO
316+
self.mode != "bench"
317+
and self.trainer.algorithm_type != AlgorithmType.DPO
315318
and self.synchronizer.sync_method == SyncMethod.CHECKPOINT
316319
):
317320
if self.trainer.save_interval != self.synchronizer.sync_interval:
318321
logger.warning(
319-
f"When `trainer.algorithm_type != DPO` and `synchronizer.sync_method == checkpoint`, "
322+
f"When `trainer.algorithm_type` != `DPO` and `synchronizer.sync_method` == `checkpoint`, "
320323
f"`trainer.save_interval` will be set to "
321324
f"`synchronizer.sync_interval = {self.synchronizer.sync_interval}`."
322325
)
@@ -356,7 +359,7 @@ def _check_buffer(self) -> None: # noqa: C901
356359
logger.info(
357360
f"Auto set `buffer.trainer_input.experience_buffer` to {self.buffer.trainer_input.experience_buffer}"
358361
)
359-
else: # TODO: to be check
362+
elif self.mode == "train": # TODO: to be check
360363
if self.trainer.algorithm_type.is_dpo():
361364
if (
362365
self.buffer.trainer_input.experience_buffer is None
@@ -365,7 +368,8 @@ def _check_buffer(self) -> None: # noqa: C901
365368
raise ValueError(
366369
"`buffer.trainer_input.experience_buffer.path` is required when `trainer.algorithm_type == AlgorithmType.DPO`"
367370
)
368-
self.buffer.trainer_input.experience_buffer.algorithm_type = self.trainer.algorithm_type
371+
if self.mode in ["both", "train"]:
372+
self.buffer.trainer_input.experience_buffer.algorithm_type = self.trainer.algorithm_type
369373

370374
# set buffer.explorer_output
371375
if self.buffer.explorer_output is None:
@@ -418,7 +422,7 @@ def check_and_update(self) -> None: # noqa: C901
418422
)
419423
self.synchronizer.backend = self.explorer.backend
420424
if self.mode == "bench" and self.synchronizer.sync_method != SyncMethod.CHECKPOINT:
421-
self.synchronizer.sync_method = "checkpoint"
425+
self.synchronizer.sync_method = SyncMethod.CHECKPOINT
422426
logger.warning(
423427
"Bench mode only supports checkpoint synchronization, set `synchronizer.sync_method` to `checkpoint`."
424428
)

trinity/explorer/explorer.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,11 @@ def __init__(self, config: Config):
3434
self.step_num = explorer_meta.get("latest_iteration", 0)
3535
self.config = config
3636
self.models = create_rollout_models(config)
37-
self.experience_buffer = get_buffer_writer(
38-
self.config.buffer.explorer_output, # type: ignore
39-
self.config.buffer,
40-
)
37+
if self.config.mode != "bench":
38+
self.experience_buffer = get_buffer_writer(
39+
self.config.buffer.explorer_output, # type: ignore
40+
self.config.buffer,
41+
)
4142
self.config.buffer.explorer_input.taskset.index = explorer_meta.get("latest_task_index", 0)
4243
self.taskset = get_buffer_reader(
4344
self.config.buffer.explorer_input.taskset, self.config.buffer
@@ -261,6 +262,29 @@ def wait():
261262
self.monitor.log(log_metrics, step=self.step_num) # type: ignore
262263
return True, self.step_num
263264

265+
def benchmark(self) -> bool:
    """Benchmark the model checkpoints.

    When ``global_config.eval_on_latest_ckp`` is true, the weights are
    updated once without an explicit step and a single evaluation is run.
    Otherwise every ``global_step_<N>`` directory found directly under
    ``model.checkpoint_path`` is evaluated in ascending order of ``N``;
    ``self.step_num`` is set to ``N`` before each evaluation.

    Entries whose name starts with ``global_step_`` but whose suffix is not
    a plain non-negative integer (for example ``global_step_8_tmp``) are
    skipped instead of aborting the whole benchmark.

    Returns:
        Always ``True``.
    """
    # benchmark on the latest checkpoint
    if self.config.global_config.eval_on_latest_ckp:
        self._checkpoint_weights_update()
        self.eval()
        return True

    # benchmark on all checkpoints
    checkpoint_root = self.config.model.checkpoint_path
    prefix = "global_step_"
    all_ckp_steps = []
    for ckp in os.listdir(checkpoint_root):
        if not ckp.startswith(prefix):
            continue
        suffix = ckp[len(prefix):]
        # Only accept a plain decimal step number so that `int()` cannot raise.
        if not (suffix.isascii() and suffix.isdigit()):
            continue
        if os.path.isdir(os.path.join(checkpoint_root, ckp)):
            all_ckp_steps.append(int(suffix))

    for step_num in sorted(all_ckp_steps):
        # Set the step first so the evaluation is attributed to this checkpoint
        # (presumably `eval` logs at `self.step_num` — confirm).
        self.step_num = step_num
        self._checkpoint_weights_update(step_num=step_num)
        self.eval()
    return True
287+
264288
def sync_weight(self) -> None:
265289
"""Synchronize model weights."""
266290
# call this method before training start to load the latest model weights
@@ -272,3 +296,6 @@ def sync_weight(self) -> None:
272296
def flush_log(self, step: int) -> None:
273297
"""Flush the log of the current step."""
274298
self.monitor.log({}, step=step, commit=True)
299+
300+
def shutdown(self) -> None:
    """Close the explorer's monitor."""
    monitor = self.monitor
    monitor.close()

trinity/trainer/trainer.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
77
Note that we don't combine the main with ray_trainer as ray_trainer is used by other main.
88
"""
9+
import os
910
from abc import ABC, abstractmethod
1011
from typing import Tuple
1112

@@ -59,7 +60,7 @@ def train_one_period(self, algo_type: AlgorithmType = AlgorithmType.PPO) -> Tupl
5960
train_status, train_step_num = self.train_step(algo_type)
6061
if not train_status:
6162
return False, train_step_num
62-
self.logger.info(f"Trainer steps {train_step_num} finished.")
63+
self.logger.info(f"Train step {train_step_num} finished.")
6364
return True, train_step_num
6465

6566
def train_step(self, algo_type: AlgorithmType = AlgorithmType.PPO) -> Tuple[bool, int]:
@@ -119,6 +120,14 @@ def flush_log(self, step: int) -> None:
119120
"""Flush the log of the current step."""
120121
self.engine.logger.log({}, step=step, commit=True)
121122

123+
def shutdown(self) -> None:
    """Save a final checkpoint if needed, then close the engine's logger.

    The checkpoint directory checked is
    ``<model.checkpoint_path>/global_step_<engine.global_steps - 1>``.
    ``engine.save_checkpoint()`` is called only when that directory is
    missing or empty. The logger is closed even if the check or the save
    raises; the exception still propagates to the caller.
    """
    # if checkpoint not saved, save the last checkpoint
    step_num = self.engine.global_steps - 1
    path = os.path.join(self.config.model.checkpoint_path, f"global_step_{step_num}")
    try:
        if not os.path.isdir(path) or not os.listdir(path):
            self.engine.save_checkpoint()
    finally:
        # Always release the logger, including when saving fails.
        self.engine.logger.close()
130+
122131

123132
class TrainEngineWrapper(ABC):
124133
"""A wrapper class to wrap various training engines."""

trinity/utils/monitor.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,9 @@ def log(self, data: dict, step: int, commit: bool = False) -> None:
5454
"""Log metrics."""
5555
self.logger.log(data, step=step, commit=commit)
5656

57+
def close(self) -> None:
    """Close the wrapped logger."""
    backend = self.logger
    backend.close()
59+
5760

5861
class TensorboardLogger:
5962
def __init__(self, project: str, name: str, role: str, config: Any = None) -> None:
@@ -70,6 +73,9 @@ def log(self, data: dict, step: int, commit: bool = False) -> None:
7073
for key in data:
7174
self.logger.add_scalar(key, data[key], step)
7275

76+
def close(self) -> None:
    """Close the underlying TensorBoard writer held in ``self.logger``."""
    writer = self.logger
    writer.close()
78+
7379
def __del__(self) -> None:
7480
self.logger.close()
7581

@@ -95,5 +101,8 @@ def log(self, data: dict, step: int, commit: bool = False) -> None:
95101
self.logger.log(data, step=step, commit=commit)
96102
self.console_logger.info(f"Step {step}: {data}")
97103

104+
def close(self) -> None:
    """Finish the run held in ``self.logger``."""
    run = self.logger
    run.finish()
106+
98107
def __del__(self) -> None:
99108
self.logger.finish()

0 commit comments

Comments
 (0)