Merged. Changes from 4 commits.
docs/sphinx_doc/source/tutorial/trinity_configs.md (2 additions, 0 deletions)

@@ -10,12 +10,14 @@ global_config:
   total_epochs: 1
   batch_size: 96
   eval_interval: 1000
+  eval_on_latest_ckp: true
 ```
 
 - `mode`: The mode of the experiment, chosen from `both`, `train`, `explore`, or `bench`. `both` launches both the trainer and the explorer; `train` launches only the trainer; `explore` launches only the explorer; `bench` runs benchmark evaluation. Default is `both`.
 - `global_config.total_epochs`: The total number of epochs. It should be checked manually.
 - `global_config.batch_size`: The batch size used for training. It should be checked manually.
 - `global_config.eval_interval`: The number of steps between two evaluations. Default is `1000`.
+- `global_config.eval_on_latest_ckp`: In bench mode, whether to evaluate only the latest checkpoint or all checkpoints under the checkpoint path. Default is `true`.
 
 
 ## Monitor
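For reference, a bench run that sweeps every saved checkpoint would combine these keys as below. This is a minimal sketch assuming only the keys documented above; any other fields your experiment needs are omitted.

```yaml
mode: bench
global_config:
  total_epochs: 1
  batch_size: 96
  eval_interval: 1000
  # false: evaluate every global_step_* checkpoint, not just the newest one
  eval_on_latest_ckp: false
```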
pyproject.toml (1 addition, 1 deletion)

@@ -20,7 +20,7 @@ classifiers = [
 requires-python = ">=3.10"
 dependencies = [
     "verl==0.3.0.post1",
-    "ray[default]==2.43.0",
+    "ray[default]>=2.45.0",
     "vllm>=0.8.5",
     "tensordict==0.6.2",
     "wandb",
trinity/cli/launcher.py (9 additions, 4 deletions)

@@ -19,11 +19,11 @@ def bench(config: Config) -> None:
     try:
         ray.get(explorer.prepare.remote())
         ray.get(explorer.sync_weight.remote())
-        _, step = ray.get(explorer.eval.remote())
-        logger.info("Evaluation finished.")
-        ray.get(explorer.flush_log.remote(step=step))
+        ray.get(explorer.benchmark.remote())
+        logger.info("Benchmark finished.")
+        ray.get(explorer.shutdown.remote())
     except Exception as e:
-        logger.error(f"Evaluation failed: {e}")
+        logger.error(f"Benchmark failed: {e}")
         raise e


@@ -35,6 +35,7 @@ def explore(config: Config) -> None:
         ray.get(explorer.sync_weight.remote())
         ray.get(explorer.explore.remote())
         logger.info("Explore finished.")
+        ray.get(explorer.shutdown.remote())
     except Exception as e:
         logger.error(f"Explore failed: {e}")
         raise e
@@ -60,6 +61,7 @@ def train(config: Config) -> None:
     try:
         ray.get(trainer.train.remote(algo_type))
         logger.info("Train finished.")
+        ray.get(trainer.shutdown.remote())
     except Exception as e:
         logger.error(f"Train failed {e}.")
         raise e
@@ -133,6 +135,9 @@ def both(config: Config) -> None:
         ray.get(explorer.flush_log.remote(step=explore_step_num))
         ray.get(trainer.flush_log.remote(step=train_step_num))
 
+    ray.get(explorer.shutdown.remote())
+    ray.get(trainer.shutdown.remote())
+
 
 def activate_data_module(data_workflow_url: str, config_path: str):
     """Check whether to activate data module and preprocess datasets."""
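All four entry points now share the same lifecycle: run the mode's main loop through Ray, then call `shutdown` so monitors are flushed and final state is persisted. The sketch below condenses that shared pattern; `run_mode` and its parameters are hypothetical stand-ins, not helpers from this repository, and mode-specific setup such as `prepare`/`sync_weight` is elided.

```python
import ray

def run_mode(actor, main_loop_name: str, logger) -> None:
    """Run one launcher mode: main loop, then explicit shutdown."""
    try:
        # each remote call is awaited with ray.get so failures surface here
        ray.get(getattr(actor, main_loop_name).remote())
        logger.info(f"{main_loop_name} finished.")
        ray.get(actor.shutdown.remote())  # close monitors, persist final state
    except Exception as e:
        logger.error(f"{main_loop_name} failed: {e}")
        raise
```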
trinity/common/config.py (4 additions, 2 deletions)

@@ -107,6 +107,7 @@ class GlobalConfig:
     total_epochs: int = 1
     batch_size: int = 1
     eval_interval: int = 100
+    eval_on_latest_ckp: bool = True
 
 
 @dataclass
@@ -299,7 +300,8 @@ def _check_interval(self) -> None:

         # check eval_interval
         if (
-            self.trainer.algorithm_type != AlgorithmType.DPO
+            self.mode != "bench"
+            and self.trainer.algorithm_type != AlgorithmType.DPO
             and self.global_config.eval_interval % self.synchronizer.sync_interval != 0
         ):
             self.global_config.eval_interval = (
@@ -316,7 +318,7 @@
         ):
             if self.trainer.save_interval != self.synchronizer.sync_interval:
                 logger.warning(
-                    f"When `trainer.algorithm_type != DPO` and `synchronizer.sync_method == checkpoint`, "
+                    f"When `trainer.algorithm_type` != `DPO` and `synchronizer.sync_method` == `checkpoint`, "
                     f"`trainer.save_interval` will be set to "
                     f"`synchronizer.sync_interval = {self.synchronizer.sync_interval}`."
                 )
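The assignment that the diff truncates aligns `eval_interval` to `synchronizer.sync_interval`; bench mode is now exempt because no synchronizer runs there. The exact rounding rule is not visible in this hunk, so the following is a sketch under the assumption that the value is rounded up to the nearest multiple:

```python
def align_eval_interval(eval_interval: int, sync_interval: int) -> int:
    """Round eval_interval up to the nearest multiple of sync_interval."""
    if eval_interval % sync_interval == 0:
        return eval_interval
    return (eval_interval // sync_interval + 1) * sync_interval

# e.g. an eval_interval of 1000 with sync_interval 96 becomes 1056 (11 * 96)
assert align_eval_interval(1000, 96) == 1056
```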
trinity/explorer/explorer.py (23 additions, 0 deletions)

@@ -261,6 +261,26 @@ def wait():
         self.monitor.log(log_metrics, step=self.step_num)  # type: ignore
         return True, self.step_num
 
+    def benchmark(self) -> bool:
+        """Benchmark the model checkpoints."""
+        # benchmark on the latest checkpoint only
+        if self.config.global_config.eval_on_latest_ckp:
+            self.eval()
+            return True
+
+        # benchmark on all checkpoints (sorted so steps are evaluated in order)
+        all_ckp_steps = [
+            int(ckp.split("global_step_")[-1])
+            for ckp in os.listdir(self.config.model.checkpoint_path)
+            if os.path.isdir(os.path.join(self.config.model.checkpoint_path, ckp))
+            and ckp.startswith("global_step_")
+        ]
+        for step_num in sorted(all_ckp_steps):
+            self.step_num = step_num
+            self._checkpoint_weights_update(step_num=step_num)
+            self.eval()
+        return True
+
     def sync_weight(self) -> None:
         """Synchronize model weights."""
         # call this method before training start to load the latest model weights
@@ -272,3 +292,6 @@ def sync_weight(self) -> None:
     def flush_log(self, step: int) -> None:
         """Flush the log of the current step."""
         self.monitor.log({}, step=step, commit=True)
+
+    def shutdown(self) -> None:
+        self.monitor.close()
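The checkpoint discovery in `benchmark` can be exercised on its own. The toy script below builds a throwaway checkpoint tree and shows why sorting matters: `os.listdir` returns entries in arbitrary order, and stray files must be filtered by the `isdir` check. The file name `global_step_9999` here is purely illustrative.

```python
import os
import tempfile

with tempfile.TemporaryDirectory() as root:
    for step in (200, 50, 100):
        os.makedirs(os.path.join(root, f"global_step_{step}"))
    # a stray *file* whose name matches the prefix; isdir() filters it out
    open(os.path.join(root, "global_step_9999"), "w").close()

    steps = sorted(
        int(name.split("global_step_")[-1])
        for name in os.listdir(root)
        if name.startswith("global_step_")
        and os.path.isdir(os.path.join(root, name))
    )
    print(steps)  # [50, 100, 200]
```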
trinity/trainer/trainer.py (9 additions, 0 deletions)

@@ -6,6 +6,7 @@

 Note that we don't combine the main with ray_trainer as ray_trainer is used by other main.
 """
+import os
 from abc import ABC, abstractmethod
 from typing import Tuple

@@ -119,6 +120,14 @@ def flush_log(self, step: int) -> None:
"""Flush the log of the current step."""
self.engine.logger.log({}, step=step, commit=True)

def shutdown(self) -> None:
# if checkpoint not saved, save the last checkpoint
step_num = self.engine.global_steps
path = os.path.join(self.config.model.checkpoint_path, f"global_step_{step_num}")
if not os.path.isdir(path) or len(os.listdir(path)) == 0:
self.engine.save_checkpoint()
self.engine.logger.close()


class TrainEngineWrapper(ABC):
"""A wrapper class to wrap various training engines."""
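The guard in `Trainer.shutdown` writes a final checkpoint only when the directory for the current step is missing or empty, so a checkpoint already saved at the last `save_interval` is not duplicated. Extracted as a standalone predicate (a sketch mirroring the logic above, not an API from the repository):

```python
import os

def needs_final_save(checkpoint_path: str, step_num: int) -> bool:
    """True if no usable checkpoint exists for step_num."""
    path = os.path.join(checkpoint_path, f"global_step_{step_num}")
    # an empty directory counts as missing: a save may have started but not completed
    return not os.path.isdir(path) or len(os.listdir(path)) == 0
```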
trinity/utils/monitor.py (9 additions, 0 deletions)

@@ -54,6 +54,9 @@ def log(self, data: dict, step: int, commit: bool = False) -> None:
"""Log metrics."""
self.logger.log(data, step=step, commit=commit)

def close(self) -> None:
self.logger.close()


class TensorboardLogger:
def __init__(self, project: str, name: str, role: str, config: Any = None) -> None:
@@ -70,6 +73,9 @@ def log(self, data: dict, step: int, commit: bool = False) -> None:
         for key in data:
             self.logger.add_scalar(key, data[key], step)
 
+    def close(self) -> None:
+        self.logger.close()
+
     def __del__(self) -> None:
         self.logger.close()

@@ -95,5 +101,8 @@ def log(self, data: dict, step: int, commit: bool = False) -> None:
         self.logger.log(data, step=step, commit=commit)
         self.console_logger.info(f"Step {step}: {data}")
 
+    def close(self) -> None:
+        self.logger.finish()
+
     def __del__(self) -> None:
         self.logger.finish()
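With `close()` available on every backend, callers can release loggers deterministically instead of relying on `__del__`, which runs at an unpredictable point during interpreter shutdown. A hypothetical context-manager wrapper (not part of this PR) illustrates the intended usage:

```python
from contextlib import contextmanager

@contextmanager
def closing_monitor(monitor):
    """Yield the monitor and always close it, even if the run raises."""
    try:
        yield monitor
    finally:
        monitor.close()

# usage sketch, assuming a constructed monitor instance:
# with closing_monitor(monitor) as m:
#     m.log({"reward": 0.5}, step=1, commit=True)
```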