41 changes: 31 additions & 10 deletions trinity/trainer/trainer.py
@@ -10,6 +10,7 @@

import ray

from trinity.algorithm import SAMPLE_STRATEGY
from trinity.common.config import Config
from trinity.common.constants import RunningStatus, SyncMethod
from trinity.utils.log import get_logger
@@ -23,6 +24,11 @@ def __init__(self, config: Config) -> None:
self.logger = get_logger(__name__)
self.engine = get_trainer_wrapper(config)
self.explorer_ref = None
self.sample_strategy = SAMPLE_STRATEGY.get(config.algorithm.sample_strategy)(
buffer_config=config.buffer,
trainer_type=config.trainer.trainer_type,
**config.algorithm.sample_strategy_args,
)

def prepare(self) -> None:
"""Prepare the trainer."""
@@ -32,7 +38,30 @@ def train(self) -> str:
"""Train the model."""
while True:
try:
train_continue = self.train_step()
# sample experiences for train step
try:
batch, sample_metrics, exp_samples = self.sample_strategy.sample(
self.engine.global_steps + 1,
)
successful_sampling = True
except StopIteration:
print("No more data to train. Stop training.")
if (
self.engine.config.trainer.save_freq == 0
or self.engine.global_steps % self.engine.config.trainer.save_freq != 0
): # TODO: double-check this if-condition
self.engine.logger.info(f"Saving at step {self.engine.global_steps}.")
self.engine.save_checkpoint()
self.engine.logger.info(f"Saved at step {self.engine.global_steps}.")
successful_sampling = False
# TODO: get rid of self.engine.global_steps/config/logger?

# run train step
if successful_sampling:
train_continue = self.engine.train_step(batch, sample_metrics, exp_samples)
else:
train_continue = False

if not train_continue:
break
if self.need_sync():
@@ -43,14 +72,6 @@
self.logger.info("--------------------\n> Trainer finished.\n--------------------")
return self.config.trainer.name

def train_step(self) -> bool:
"""Train one step.

Returns:
bool: Whether to continue training.
"""
return self.engine.train_step()

def need_sync(self) -> bool:
"""Whether to sync the model weight."""
return self.engine.train_step_num % self.config.synchronizer.sync_interval == 0
@@ -95,7 +116,7 @@ def train_step_num(self) -> int:
"""Get the current training step number."""

@abstractmethod
def train_step(self) -> bool:
def train_step(self, batch, sample_metrics, exp_samples) -> bool:
"""Training."""

@abstractmethod
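How the two halves of this change fit together: Trainer.train now owns sampling (via SAMPLE_STRATEGY) and passes the result straight into the engine's train_step, which only consumes it. Below is a minimal sketch of an engine honoring the new train_step(batch, sample_metrics, exp_samples) contract; the class name and every field in it are illustrative assumptions, not part of this PR.

    # Hypothetical engine; only the train_step signature comes from the diff above.
    class ToyEngine:
        def __init__(self, total_steps: int) -> None:
            self.global_steps = 0
            self.total_steps = total_steps

        def train_step(self, batch, sample_metrics, exp_samples) -> bool:
            # Sampling already happened in Trainer.train; this method only
            # consumes the batch it was handed (the removed version sampled here).
            metrics = {f"sample/{k}": v for k, v in sample_metrics.items()}
            # ... run the actual optimization on `batch` and extend `metrics` ...
            self.global_steps += 1
            # Same return contract as before: whether to continue training.
            return self.global_steps < self.total_steps

Trainer.train then drives it exactly as in the diff: sample, catch StopIteration, call train_step, and break when it returns False.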
43 changes: 23 additions & 20 deletions trinity/trainer/verl_trainer.py
@@ -31,7 +31,7 @@
from verl.utils import hf_tokenizer
from verl.utils.fs import copy_local_path_from_hdfs

from trinity.algorithm import ADVANTAGE_FN, KL_FN, SAMPLE_STRATEGY
from trinity.algorithm import ADVANTAGE_FN, KL_FN
from trinity.algorithm.algorithm import ALGORITHM_TYPE, SFTAlgorithm
from trinity.algorithm.algorithm_manager import AlgorithmManager
from trinity.algorithm.utils import prefix_metrics
@@ -134,11 +134,11 @@ def __init__(
self.kl_fn = KL_FN.get(self.algorithm_config.kl_penalty_fn)(
**self.algorithm_config.kl_penalty_fn_args
)
self.sample_strategy = SAMPLE_STRATEGY.get(global_config.algorithm.sample_strategy)(
buffer_config=global_config.buffer,
trainer_type=global_config.trainer.trainer_type,
**global_config.algorithm.sample_strategy_args,
)
# self.sample_strategy = SAMPLE_STRATEGY.get(global_config.algorithm.sample_strategy)(
# buffer_config=global_config.buffer,
# trainer_type=global_config.trainer.trainer_type,
# **global_config.algorithm.sample_strategy_args,
# )
super().__init__(
config,
tokenizer,
@@ -287,22 +287,25 @@ def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampl
# TODO: compute total training steps
self.total_training_steps = self.config.trainer.total_training_steps or sys.maxsize

def train_step(self) -> bool: # noqa C901
def train_step(self, batch, sample_metrics, exp_samples) -> bool: # noqa C901
self.logger.info(f"Training at step {self.global_steps + 1} started.")
metrics = {}
try:
batch, sample_metrics, exp_samples = self.sample_strategy.sample(self.global_steps + 1)
prefix_metrics(sample_metrics, "sample", metrics)
except StopIteration:
print("No more data to train. Stop training.")
if (
self.config.trainer.save_freq == 0
or self.global_steps % self.config.trainer.save_freq != 0
):
self.logger.info(f"Saving at step {self.global_steps}.")
self._save_checkpoint()
self.logger.info(f"Saved at step {self.global_steps}.")
return False
prefix_metrics(sample_metrics, "sample", metrics)

# try:
# batch, sample_metrics, exp_samples = self.sample_strategy.sample(self.global_steps + 1)
# prefix_metrics(sample_metrics, "sample", metrics)
# except StopIteration:
# print("No more data to train. Stop training.")
# if (
# self.config.trainer.save_freq == 0
# or self.global_steps % self.config.trainer.save_freq != 0
# ):
# self.logger.info(f"Saving at step {self.global_steps}.")
# self._save_checkpoint()
# self.logger.info(f"Saved at step {self.global_steps}.")
# return False

self.global_steps += 1
self.logger.info(f"Sampling at step {self.global_steps} done.")
timing_raw = {}
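For context on the commented-out block above: the sample_strategy.sample(step) call now lives in Trainer.train and is expected to return a (batch, sample_metrics, exp_samples) triple, raising StopIteration when the buffer is exhausted so the trainer can save a checkpoint and stop. The sketch below is a hypothetical strategy illustrating just that contract; nothing in it is from the PR except the method name and the return shape.

    # Illustrative only; real strategies are registered in trinity.algorithm.SAMPLE_STRATEGY.
    class ListSampleStrategy:
        def __init__(self, batches):
            self._batches = list(batches)

        def sample(self, step: int):
            # Trainer.train calls this with engine.global_steps + 1.
            if step > len(self._batches):
                raise StopIteration  # triggers the save-and-stop path in Trainer.train
            batch = self._batches[step - 1]
            sample_metrics = {"num_experiences": len(batch)}
            exp_samples = batch  # placeholder for the experience samples real strategies return
            return batch, sample_metrics, exp_samples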