3 changes: 3 additions & 0 deletions trinity/algorithm/__init__.py
@@ -2,6 +2,7 @@
from trinity.algorithm.entropy_loss_fn import ENTROPY_LOSS_FN, EntropyLossFn
from trinity.algorithm.kl_fn import KL_FN, KLFn
from trinity.algorithm.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
from trinity.algorithm.sample_strategy import SAMPLE_STRATEGY, SampleStrategy

__all__ = [
"AdvantageFn",
@@ -12,4 +13,6 @@
"KL_FN",
"EntropyLossFn",
"ENTROPY_LOSS_FN",
"SampleStrategy",
"SAMPLE_STRATEGY",
]
5 changes: 5 additions & 0 deletions trinity/algorithm/algorithm.py
@@ -62,6 +62,7 @@ class SFTAlgorithm(AlgorithmType):
@classmethod
def get_default_config(cls) -> Dict:
return {
"sample_strategy": "default",
"policy_loss_fn": "sft",
"kl_loss_fn": "none",
"entropy_loss_fn": "none",
@@ -83,6 +84,7 @@ class PPOAlgorithm(AlgorithmType):
def get_default_config(cls) -> Dict:
return {
"repeat_times": 1,
"sample_strategy": "warmup",
"policy_loss_fn": "ppo",
"advantage_fn": "ppo",
"kl_penalty_fn": "none",
@@ -106,6 +108,7 @@ class GRPOAlgorithm(AlgorithmType):
def get_default_config(cls) -> Dict:
return {
"repeat_times": 2,
"sample_strategy": "warmup",
"policy_loss_fn": "ppo",
"advantage_fn": "grpo",
"kl_penalty_fn": "none",
@@ -129,6 +132,7 @@ class OPMDAlgorithm(AlgorithmType):
def get_default_config(cls) -> Dict:
return {
"repeat_times": 2,
"sample_strategy": "warmup",
"policy_loss_fn": "opmd",
"advantage_fn": "opmd",
"kl_penalty_fn": "none",
@@ -156,6 +160,7 @@ def gather_experience(cls, exps: list[Experience], pad_token_id: int = 0) -> Exp
def get_default_config(cls) -> Dict:
return {
"repeat_times": 2, # fake repeat times
"sample_strategy": "warmup",
"policy_loss_fn": "dpo",
"kl_loss_fn": "k2",
"entropy_loss_fn": "basic",
13 changes: 13 additions & 0 deletions trinity/algorithm/sample_strategy/__init__.py
@@ -0,0 +1,13 @@
from trinity.algorithm.sample_strategy.sample_strategy import (
SAMPLE_STRATEGY,
DefaultSampleStrategy,
SampleStrategy,
WarmupSampleStrategy,
)

__all__ = [
"SAMPLE_STRATEGY",
"SampleStrategy",
"DefaultSampleStrategy",
"WarmupSampleStrategy",
]
64 changes: 64 additions & 0 deletions trinity/algorithm/sample_strategy/sample_strategy.py
@@ -0,0 +1,64 @@
from abc import ABC, abstractmethod
from typing import List

from trinity.buffer import get_buffer_reader
from trinity.common.config import BufferConfig
from trinity.common.experience import Experience
from trinity.utils.registry import Registry

SAMPLE_STRATEGY = Registry("sample_strategy")


class SampleStrategy(ABC):
def __init__(self, buffer_config: BufferConfig, **kwargs):
self.buffer_config = buffer_config

@abstractmethod
def sample(self, step: int, **kwargs) -> List[Experience]:
"""Sample experiences from buffer.

Args:
step (`int`): The current training step number.
"""

@classmethod
def default_args(cls) -> dict:
return {}


@SAMPLE_STRATEGY.register_module("warmup")
class WarmupSampleStrategy(SampleStrategy):
"""The default sample strategy."""

def __init__(self, buffer_config: BufferConfig, **kwargs):
super().__init__(buffer_config)
self.exp_buffer = get_buffer_reader(
buffer_config.trainer_input.experience_buffer, buffer_config # type: ignore
)
self.sft_warmup_step = buffer_config.trainer_input.sft_warmup_steps
if self.sft_warmup_step > 0 and buffer_config.trainer_input.sft_warmup_dataset is None:
raise ValueError("sft_warmup_dataset is required when sft_warmup_steps > 0")
if buffer_config.trainer_input.sft_warmup_dataset is not None:
self.sft_buffer = get_buffer_reader(
buffer_config.trainer_input.sft_warmup_dataset, buffer_config
)
else:
self.sft_buffer = None

def sample(self, step: int, **kwargs) -> List[Experience]:
if step <= self.sft_warmup_step:
return self.sft_buffer.read()
else:
return self.exp_buffer.read()


@SAMPLE_STRATEGY.register_module("default")
class DefaultSampleStrategy(SampleStrategy):
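"""The default sample strategy: always read from the experience buffer."""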
def __init__(self, buffer_config: BufferConfig, **kwargs):
super().__init__(buffer_config)
self.exp_buffer = get_buffer_reader(
buffer_config.trainer_input.experience_buffer, buffer_config # type: ignore
)

def sample(self, step: int, **kwargs) -> List[Experience]:
return self.exp_buffer.read()
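As an aside for reviewers, here is a minimal sketch of how a downstream strategy could hook into the new registry. The strategy name `"repeat"` and the `num_reads` argument are hypothetical; the sketch only uses interfaces shown above (`SAMPLE_STRATEGY`, `SampleStrategy`, `get_buffer_reader`, `BufferConfig`, `Experience`):

```python
from typing import List

from trinity.algorithm.sample_strategy import SAMPLE_STRATEGY, SampleStrategy
from trinity.buffer import get_buffer_reader
from trinity.common.config import BufferConfig
from trinity.common.experience import Experience


@SAMPLE_STRATEGY.register_module("repeat")  # hypothetical name, for illustration only
class RepeatSampleStrategy(SampleStrategy):
    """Hypothetical strategy: reads the experience buffer `num_reads` times per step."""

    def __init__(self, buffer_config: BufferConfig, num_reads: int = 1, **kwargs):
        super().__init__(buffer_config)
        self.num_reads = num_reads
        self.exp_buffer = get_buffer_reader(
            buffer_config.trainer_input.experience_buffer, buffer_config
        )

    def sample(self, step: int, **kwargs) -> List[Experience]:
        # Concatenate several reads into one training batch.
        exps: List[Experience] = []
        for _ in range(self.num_reads):
            exps.extend(self.exp_buffer.read())
        return exps

    @classmethod
    def default_args(cls) -> dict:
        return {"num_reads": 1}
```

Such a strategy would then be selected by setting `algorithm.sample_strategy` to `"repeat"` and passing `{"num_reads": 2}` via `algorithm.sample_strategy_args`, which the trainer forwards as keyword arguments (see trainer.py below).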
18 changes: 15 additions & 3 deletions trinity/common/config.py
@@ -176,9 +176,8 @@ class AlgorithmConfig:
# for GRPO-like algorithms, repeat each task for `repeat_times` times
repeat_times: int = 1

policy_loss_fn: Optional[str] = None # "ppo"
# If not set, use PolicyLossFn.default_args()
policy_loss_fn_args: Optional[dict] = None
sample_strategy: Optional[str] = None
sample_strategy_args: Optional[dict] = None

advantage_fn: Optional[str] = None # "ppo"
# If not set, use AdvantageFn.default_args()
@@ -188,6 +187,10 @@
# If not set, use kl_penalty_fn.default_args()
kl_penalty_fn_args: Optional[dict] = None

policy_loss_fn: Optional[str] = None # "ppo"
# If not set, use PolicyLossFn.default_args()
policy_loss_fn_args: Optional[dict] = None

kl_loss_fn: Optional[str] = None # "k2" # set to "none" to disable kl loss
# If not set, use kl_loss_fn.default_args()
kl_loss_fn_args: Optional[dict] = None
@@ -489,12 +492,14 @@ def _check_algorithm(self) -> None:
ENTROPY_LOSS_FN,
KL_FN,
POLICY_LOSS_FN,
SAMPLE_STRATEGY,
)
from trinity.algorithm.algorithm import ALGORITHM_TYPE

algorithm = ALGORITHM_TYPE.get(self.algorithm.algorithm_type)
algorithm.check_config(self)
default_config = {
"sample_strategy": "warmup",
"policy_loss_fn": "ppo",
"advantage_fn": "ppo",
"kl_penalty_fn": "none",
@@ -506,6 +511,13 @@
if getattr(self.algorithm, key, None) is None:
setattr(self.algorithm, key, value)

# TODO: simplify the following code
sample_strategy_cls = SAMPLE_STRATEGY.get(self.algorithm.sample_strategy)
if sample_strategy_cls is None:
raise ValueError(f"Invalid sample_strategy: {self.algorithm.sample_strategy}")
if self.algorithm.sample_strategy_args is None:
self.algorithm.sample_strategy_args = sample_strategy_cls.default_args()

policy_fn_cls = POLICY_LOSS_FN.get(self.algorithm.policy_loss_fn)
if policy_fn_cls is None:
raise ValueError(f"Invalid policy_loss_fn: {self.algorithm.policy_loss_fn}")
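For reference, a small sketch of the resolution order implied by the changes above; it mirrors `_check_algorithm` rather than adding new behavior, and assumes `AlgorithmConfig` can be instantiated with its field defaults:

```python
from trinity.algorithm.sample_strategy import SAMPLE_STRATEGY
from trinity.common.config import AlgorithmConfig

algo = AlgorithmConfig()  # sample_strategy / sample_strategy_args default to None

# 1. Unset fields fall back to the algorithm's get_default_config(), e.g. "warmup".
algo.sample_strategy = algo.sample_strategy or "warmup"

# 2. The name is resolved through the registry; unknown names raise early.
strategy_cls = SAMPLE_STRATEGY.get(algo.sample_strategy)
if strategy_cls is None:
    raise ValueError(f"Invalid sample_strategy: {algo.sample_strategy}")

# 3. Missing args are filled from the strategy's default_args() ({} for both built-ins).
if algo.sample_strategy_args is None:
    algo.sample_strategy_args = strategy_cls.default_args()
```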
27 changes: 6 additions & 21 deletions trinity/trainer/trainer.py
@@ -12,9 +12,9 @@

import ray

from trinity.algorithm.algorithm import ALGORITHM_TYPE, SFTAlgorithm
from trinity.algorithm.algorithm import ALGORITHM_TYPE
from trinity.algorithm.algorithm_manager import AlgorithmManager
from trinity.buffer import get_buffer_reader
from trinity.algorithm.sample_strategy import SAMPLE_STRATEGY
from trinity.common.config import Config
from trinity.common.constants import SyncMethod
from trinity.utils.log import get_logger
@@ -28,17 +28,9 @@ def __init__(self, config: Config) -> None:
self.config = config
self.logger = get_logger(__name__)
self.algorithm_manager = AlgorithmManager(config)
self.train_buffer = get_buffer_reader(
self.config.buffer.trainer_input.experience_buffer, # type: ignore
self.config.buffer,
)
self.sft_warmup_buffer = (
get_buffer_reader(
self.config.buffer.trainer_input.sft_warmup_dataset, # type: ignore
self.config.buffer,
)
if self.config.buffer.trainer_input.sft_warmup_steps > 0
else None
self.sample_strategy = SAMPLE_STRATEGY.get(config.algorithm.sample_strategy)(
buffer_config=config.buffer,
**config.algorithm.sample_strategy_args,
)
self.engine = get_trainer_wrapper(config)

@@ -76,15 +68,8 @@ def train_step(self) -> Tuple[bool, int]:
)
algo_type = algo_config.algorithm_type
algorithm = ALGORITHM_TYPE.get(algo_type)
if algorithm.use_rollout:
strategy = self.config.buffer.trainer_input.read_experience_strategy
else:
strategy = None
try:
if algorithm == SFTAlgorithm:
exps = self.sft_warmup_buffer.read()
else:
exps = self.train_buffer.read(strategy=strategy)
exps = self.sample_strategy.sample(self.engine.train_step_num + 1)
except StopIteration:
self.logger.warning("No more data to train. Stop training.")
return False, self.engine.train_step_num