Skip to content

Commit 39ff2e4

Browse files
committed
add benchmark mode
1 parent 2086bec commit 39ff2e4

File tree

4 files changed

+34
-6
lines changed

4 files changed

+34
-6
lines changed

docs/sphinx_doc/source/tutorial/trinity_configs.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,14 @@ global_config:
1010
total_epochs: 1
1111
batch_size: 96
1212
eval_interval: 1000
13+
eval_on_latest_ckp: true
1314
```
1415
1516
- `mode`: The mode of the experiment, chosen from `both`, `train`, `explore` or `bench`. `both` means both trainer and explorer are launched; `train` means only trainer is launched; `explore` means only explorer is launched; `bench` conducts benchmark evaluation. Default is `both`.
1617
- `global_config.total_epochs`: The total number of epochs. It should be checked manually.
1718
- `global_config.batch_size`: The batch size used for training. It should be checked manually.
1819
- `global_config.eval_interval`: The interval steps between two evaluations. Default is `1000`.
20+
- `global_config.eval_on_latest_ckp`: In bench mode, whether to evaluate only the latest checkpoint (`true`) or every checkpoint found in the checkpoint path (`false`). Default is `true`.
1921

2022

2123
## Monitor

trinity/cli/launcher.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import sys
44

55
import ray
6+
import wandb
67

78
from trinity.common.config import Config, load_config
89
from trinity.common.constants import AlgorithmType
@@ -19,11 +20,12 @@ def bench(config: Config) -> None:
1920
try:
2021
ray.get(explorer.prepare.remote())
2122
ray.get(explorer.sync_weight.remote())
22-
_, step = ray.get(explorer.eval.remote())
23-
logger.info("Evaluation finished.")
24-
ray.get(explorer.flush_log.remote(step=step))
23+
bm_finished, step = ray.get(explorer.benchmark.remote())
24+
logger.info("Benchmark finished.")
25+
if bm_finished:
26+
ray.get(explorer.flush_log.remote(step=step))
2527
except Exception as e:
26-
logger.error(f"Evaluation failed: {e}")
28+
logger.error(f"Benchmark failed: {e}")
2729
raise e
2830

2931

@@ -168,6 +170,9 @@ def run(config_path: str):
168170
elif config.mode == "bench":
169171
bench(config)
170172

173+
if config.monitor.monitor_type == "wandb":
174+
wandb.finish()
175+
171176

172177
def studio(port: int = 8501):
173178
from streamlit.web import cli as stcli

trinity/common/config.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ class GlobalConfig:
107107
total_epochs: int = 1
108108
batch_size: int = 1
109109
eval_interval: int = 100
110+
eval_on_latest_ckp: bool = True
110111

111112

112113
@dataclass
@@ -299,7 +300,8 @@ def _check_interval(self) -> None:
299300

300301
# check eval_interval
301302
if (
302-
self.trainer.algorithm_type != AlgorithmType.DPO
303+
self.mode != "bench"
304+
and self.trainer.algorithm_type != AlgorithmType.DPO
303305
and self.global_config.eval_interval % self.synchronizer.sync_interval != 0
304306
):
305307
self.global_config.eval_interval = (
@@ -316,7 +318,7 @@ def _check_interval(self) -> None:
316318
):
317319
if self.trainer.save_interval != self.synchronizer.sync_interval:
318320
logger.warning(
319-
f"When `trainer.algorithm_type != DPO` and `synchronizer.sync_method == checkpoint`, "
321+
f"When `trainer.algorithm_type` != `DPO` and `synchronizer.sync_method` == `checkpoint`, "
320322
f"`trainer.save_interval` will be set to "
321323
f"`synchronizer.sync_interval = {self.synchronizer.sync_interval}`."
322324
)

trinity/explorer/explorer.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,25 @@ def wait():
261261
self.monitor.log(log_metrics, step=self.step_num) # type: ignore
262262
return True, self.step_num
263263

264+
def benchmark(self) -> Tuple[bool, int]:
265+
"""Benchmark the model checkpoints."""
266+
latest_step = self.step_num
267+
268+
# benchmark on the latest checkpoint
269+
if self.config.global_config.eval_on_latest_ckp:
270+
self.eval()
271+
return True, self.step_num
272+
273+
# benchmark on all checkpoints
274+
for step_num in range(latest_step + 1):
275+
path = os.path.join(self.config.model.checkpoint_path, f"global_step_{step_num}")
276+
if os.path.isdir(path) and os.listdir(path):
277+
self.logger.info(f"{path} exists.")
278+
self.step_num = step_num
279+
self._checkpoint_weights_update(step_num=step_num)
280+
self.eval()
281+
return True, self.step_num
282+
264283
def sync_weight(self) -> None:
265284
"""Synchronize model weights."""
266285
# call this method before training start to load the latest model weights

0 commit comments

Comments
 (0)