Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions trinity/cli/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,20 @@
logger = get_logger(__name__)


def bench(config: Config) -> None:
    """Run a one-shot evaluation ("bench" mode) of the configured model.

    Spins up a remote ``Explorer`` actor, prepares it, synchronizes model
    weights (bench mode uses checkpoint-based sync), runs the evaluation
    pass, and flushes the logged metrics at the step returned by ``eval``.

    Args:
        config: Global configuration; expected to have ``mode == "bench"``.

    Raises:
        Exception: Re-raises whatever the remote explorer call raised,
            after logging it with the full traceback.
    """
    explorer = Explorer.remote(config)
    try:
        ray.get(explorer.prepare.remote())
        ray.get(explorer.sync_weight.remote())
        # eval() returns (success_flag, step_num); only the step is needed
        # to flush the monitor log at the right position.
        _, step = ray.get(explorer.eval.remote())
        logger.info("Evaluation finished.")
        ray.get(explorer.flush_log.remote(step=step))
    except Exception as e:
        # logger.exception records the traceback (logger.error with an
        # f-string would not); bare `raise` re-raises without stacking an
        # extra frame the way `raise e` does.
        logger.exception("Evaluation failed: %s", e)
        raise


def explore(config: Config) -> None:
"""Run explorer."""
explorer = Explorer.remote(config)
Expand Down Expand Up @@ -151,6 +165,8 @@ def run(config_path: str):
train(config)
elif config.mode == "both":
both(config)
elif config.mode == "bench":
bench(config)


def studio(port: int = 8501):
Expand Down
9 changes: 7 additions & 2 deletions trinity/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ class SynchronizerConfig:
class Config:
"""Global Configuration"""

mode: str = "both" # `explore`, `train` or `both`
mode: str = "both" # `explore`, `train`, `both` or `bench`
data: DataConfig = field(default_factory=DataConfig)
model: ModelConfig = field(default_factory=ModelConfig)
cluster: ClusterConfig = field(default_factory=ClusterConfig)
Expand Down Expand Up @@ -302,7 +302,7 @@ def _check_buffer(self) -> None:
def check_and_update(self) -> None: # noqa: C901
"""Check and update the config."""
# check mode
if self.mode not in ["explore", "train", "both"]:
if self.mode not in ["explore", "train", "both", "bench"]:
raise ValueError(f"Invalid mode: {self.mode}")
if self.trainer.algorithm_type == AlgorithmType.DPO and self.mode == "both":
raise ValueError("DPO does not support `both` mode")
Expand All @@ -325,6 +325,11 @@ def check_and_update(self) -> None: # noqa: C901
self.explorer.engine_num * self.explorer.tensor_parallel_size
)
self.synchronizer.backend = self.explorer.backend
if self.mode == "bench" and self.synchronizer.sync_method != SyncMethod.CHECKPOINT:
self.synchronizer.sync_method = "checkpoint"
logger.warning(
"Bench mode only supports checkpoint synchronization, set `synchronizer.sync_method` to `checkpoint`."
)
if (
self.trainer.algorithm_type == AlgorithmType.DPO
and self.synchronizer.sync_method != SyncMethod.CHECKPOINT
Expand Down
4 changes: 3 additions & 1 deletion trinity/common/verl_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,9 +318,11 @@ def synchronize_config(self, config: Config) -> None:
self.actor_rollout_ref.actor.use_kl_loss = True
logger.warning("DPO must use KL loss.")
logger.warning("DPO micro batch size is doubled for computing loss.")
self.actor_rollout_ref.actor.ppo_mini_batch_size *= 2
self.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu *= 2 # type: ignore
self.actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu *= 2 # type: ignore
if self.actor_rollout_ref.rollout.n != 2:
self.actor_rollout_ref.rollout.n = 2
logger.warning("In DPO, actor_rollout_ref.rollout.n is set to 2.")
# TODO: check other fields
self.enable_preview = config.trainer.enable_preview

Expand Down
6 changes: 3 additions & 3 deletions trinity/explorer/explorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,11 +228,11 @@ def explore_one_period(self) -> Tuple[bool, int]:
self.logger.info(f"Explore step {self.step_num} finished.")
return True, self.step_num

def eval(self) -> bool:
def eval(self) -> Tuple[bool, int]:
"""Evaluation on all evaluation data samples."""
if self.eval_taskset is None:
self.logger.warning("No evaluation data samples. Skip evaluation.")
return True
return True, self.step_num
self.logger.info("Evaluation started.")
st = time.time()
all_metrics = defaultdict(list)
Expand All @@ -255,7 +255,7 @@ def eval(self) -> bool:
log_metrics = self.monitor.calculate_metrics(all_metrics, prefix="eval") # type: ignore
log_metrics["eval/total_time"] = time.time() - st
self.monitor.log(log_metrics, step=self.step_num) # type: ignore
return True
return True, self.step_num

def sync_weight(self) -> None:
"""Synchronize model weights."""
Expand Down
24 changes: 13 additions & 11 deletions trinity/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,32 +73,34 @@ def train_step(self, algo_type: AlgorithmType = AlgorithmType.PPO) -> Tuple[bool
bool: Whether to continue training.
"""
self.engine.set_mode(algo_type)
if algo_type.is_rft() and self.config.trainer.get_exp_strategy:
strategy = ReadStrategy(self.config.trainer.get_exp_strategy)
else:
strategy = None
try:
if algo_type.is_sft():
exps = self.sft_warmup_buffer.read()
else:
exps = self.train_buffer.read(strategy=strategy)
except StopIteration:
self.logger.warning("No more data to train. Stop training.")
return False, 0 # TODO: get the actual step number

if algo_type.is_sft():
exps = self.sft_warmup_buffer.read()
return self.engine.train_sft_step(
Experiences.gather_experiences(
exps,
pad_token_id=self.config.buffer.pad_token_id, # type: ignore
)
)
elif algo_type.is_rft():
if self.config.trainer.get_exp_strategy:
strategy = ReadStrategy(self.config.trainer.get_exp_strategy)
else:
strategy = None
try:
exps = self.train_buffer.read(strategy=strategy)
except StopIteration:
self.logger.warning("No more data to train. Stop training.")
return False, 0 # TODO: get the actual step number
return self.engine.train_rft_step(
Experiences.gather_experiences(
exps,
pad_token_id=self.config.buffer.pad_token_id, # type: ignore
)
)
elif algo_type.is_dpo():
exps = self.train_buffer.read()
return self.engine.train_dpo_step(
Experiences.gather_dpo_experiences(
exps,
Expand Down