19 changes: 15 additions & 4 deletions docs/sphinx_doc/source/tutorial/trinity_configs.md
@@ -79,14 +79,25 @@ Specifies the algorithm type and its related hyperparameters.
algorithm:
  algorithm_type: grpo
  repeat_times: 1
  gamma: 1.0
  lam: 1.0

  # The following parameters are optional
  # If not specified, they will automatically be set based on the `algorithm_type`
  sample_strategy: "default"
  advantage_fn: "ppo"
  kl_penalty_fn: "none"
  kl_loss_fn: "k2"
  entropy_loss_fn: "default"
```

- `algorithm_type`: Type of reinforcement learning algorithm. Supported types: `ppo`, `grpo`, `opmd`, `dpo`.
- `repeat_times`: Number of times each task is repeated. Default is `1`. In `dpo`, this is automatically set to `2`.
- `gamma`: Discount factor for future rewards. Default is `1.0`.
- `lam`: Lambda value for Generalized Advantage Estimation (GAE). Default is `1.0`.

- `sample_strategy`: The strategy used to sample experiences from the experience buffer.
- `advantage_fn`: The function used to compute advantages.
- `kl_penalty_fn`: The function used to compute the KL penalty.
- `kl_loss_fn`: The function used to compute the KL loss.
- `entropy_loss_fn`: The function used to compute the entropy loss.
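
When the optional fields are omitted, each algorithm supplies its own defaults via `get_default_config()` (see `trinity/algorithm/algorithm.py` in this PR). The merge logic itself is not part of this diff; the snippet below is only a rough sketch of the fallback behaviour, using the GRPO defaults copied from that file and a hypothetical `user_cfg` dict.

```python
# Illustrative only: how unspecified algorithm fields could fall back to the
# per-algorithm defaults. The real resolution code is not shown in this PR.
GRPO_DEFAULTS = {  # copied from GRPOAlgorithm.get_default_config()
    "repeat_times": 2,
    "sample_strategy": "warmup",
    "policy_loss_fn": "ppo",
    "advantage_fn": "grpo",
    "kl_penalty_fn": "none",
    "kl_loss_fn": "k2",
    "entropy_loss_fn": "default",
}

user_cfg = {"algorithm_type": "grpo", "repeat_times": 1}  # hypothetical user config

resolved = {**GRPO_DEFAULTS, **user_cfg}  # user-specified values take precedence
assert resolved["sample_strategy"] == "warmup"
assert resolved["repeat_times"] == 1
```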


---

3 changes: 3 additions & 0 deletions trinity/algorithm/__init__.py
@@ -2,6 +2,7 @@
from trinity.algorithm.entropy_loss_fn import ENTROPY_LOSS_FN, EntropyLossFn
from trinity.algorithm.kl_fn import KL_FN, KLFn
from trinity.algorithm.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
from trinity.algorithm.sample_strategy import SAMPLE_STRATEGY, SampleStrategy

__all__ = [
"AdvantageFn",
@@ -12,4 +13,6 @@
"KL_FN",
"EntropyLossFn",
"ENTROPY_LOSS_FN",
"SampleStrategy",
"SAMPLE_STRATEGY",
]
22 changes: 9 additions & 13 deletions trinity/algorithm/algorithm.py
@@ -7,7 +7,6 @@
from trinity.buffer.schema.sql_schema import DPODataModel, ExperienceModel, SFTDataModel
from trinity.common.config import Config
from trinity.common.constants import SyncMethod
from trinity.common.experience import Experience, Experiences
from trinity.utils.log import get_logger
from trinity.utils.registry import Registry

@@ -31,10 +30,6 @@ class AlgorithmType(ABC, metaclass=ConstantMeta):
    can_balance_batch: bool
    schema: type

    @classmethod
    def gather_experience(cls, exps: list[Experience], pad_token_id: int = 0) -> Experiences:
        return Experiences.gather_experiences(exps, pad_token_id)

    @classmethod
    def get_default_config(cls) -> Dict:
        raise NotImplementedError
@@ -62,6 +57,7 @@ class SFTAlgorithm(AlgorithmType):
    @classmethod
    def get_default_config(cls) -> Dict:
        return {
            "sample_strategy": "default",
            "policy_loss_fn": "sft",
            "kl_loss_fn": "none",
            "entropy_loss_fn": "none",
@@ -83,11 +79,12 @@
    def get_default_config(cls) -> Dict:
        return {
            "repeat_times": 1,
            "sample_strategy": "warmup",
            "policy_loss_fn": "ppo",
            "advantage_fn": "ppo",
            "kl_penalty_fn": "none",
            "kl_loss_fn": "k2",
            "entropy_loss_fn": "default",
        }


@@ -106,11 +103,12 @@
    def get_default_config(cls) -> Dict:
        return {
            "repeat_times": 2,
            "sample_strategy": "warmup",
            "policy_loss_fn": "ppo",
            "advantage_fn": "grpo",
            "kl_penalty_fn": "none",
            "kl_loss_fn": "k2",
            "entropy_loss_fn": "default",
        }


@@ -129,11 +127,12 @@
    def get_default_config(cls) -> Dict:
        return {
            "repeat_times": 2,
            "sample_strategy": "warmup",
            "policy_loss_fn": "opmd",
            "advantage_fn": "opmd",
            "kl_penalty_fn": "none",
            "kl_loss_fn": "k2",
            "entropy_loss_fn": "default",
        }


@@ -148,17 +147,14 @@ class DPOAlgorithm(AlgorithmType):
    can_balance_batch: bool = False
    schema: type = DPODataModel

    @classmethod
    def gather_experience(cls, exps: list[Experience], pad_token_id: int = 0) -> Experiences:
        return Experiences.gather_dpo_experiences(exps, pad_token_id)

    @classmethod
    def get_default_config(cls) -> Dict:
        return {
            "repeat_times": 2,  # fake repeat times
            "sample_strategy": "dpo",
            "policy_loss_fn": "dpo",
            "kl_loss_fn": "k2",
            "entropy_loss_fn": "default",
        }

    @classmethod
4 changes: 2 additions & 2 deletions trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py
@@ -40,8 +40,8 @@ def default_args(cls) -> Dict:
return {"entropy_coef": 0.0}


@ENTROPY_LOSS_FN.register_module("basic")
class BasicEntropyLossFn(EntropyLossFn):
@ENTROPY_LOSS_FN.register_module("default")
class DefaultEntropyLossFn(EntropyLossFn):
"""
Basic entropy loss function.
"""
13 changes: 13 additions & 0 deletions trinity/algorithm/sample_strategy/__init__.py
@@ -0,0 +1,13 @@
from trinity.algorithm.sample_strategy.sample_strategy import (
    SAMPLE_STRATEGY,
    DefaultSampleStrategy,
    SampleStrategy,
    WarmupSampleStrategy,
)

__all__ = [
    "SAMPLE_STRATEGY",
    "SampleStrategy",
    "DefaultSampleStrategy",
    "WarmupSampleStrategy",
]
114 changes: 114 additions & 0 deletions trinity/algorithm/sample_strategy/sample_strategy.py
@@ -0,0 +1,114 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Tuple

from trinity.algorithm.sample_strategy.utils import representative_sample, to_data_proto
from trinity.buffer import get_buffer_reader
from trinity.common.config import BufferConfig
from trinity.common.experience import Experiences
from trinity.utils.registry import Registry
from trinity.utils.timer import Timer

SAMPLE_STRATEGY = Registry("sample_strategy")


class SampleStrategy(ABC):
    def __init__(self, buffer_config: BufferConfig, trainer_type: str, **kwargs):
        self.pad_token_id = buffer_config.pad_token_id
        self.trainer_type = trainer_type

    @abstractmethod
    def sample(self, step: int) -> Tuple[Any, Dict, List]:
        """Sample experiences from the buffer.

        Args:
            step (`int`): The current training step.

        Returns:
            `Any`: The sampled experiences.
            `Dict`: Metrics for logging.
            `List`: Representative experiences for logging.
        """

    @classmethod
    def default_args(cls) -> dict:
        return {}


@SAMPLE_STRATEGY.register_module("warmup")
class WarmupSampleStrategy(SampleStrategy):
    """Sample strategy that reads from the SFT warmup dataset for the first
    `sft_warmup_steps` steps and from the experience buffer afterwards."""

    def __init__(self, buffer_config: BufferConfig, trainer_type: str, **kwargs):
        super().__init__(buffer_config, trainer_type)
        self.exp_buffer = get_buffer_reader(
            buffer_config.trainer_input.experience_buffer, buffer_config  # type: ignore
        )
        self.sft_warmup_steps = buffer_config.trainer_input.sft_warmup_steps
        if self.sft_warmup_steps > 0 and buffer_config.trainer_input.sft_warmup_dataset is None:
            raise ValueError("sft_warmup_dataset is required when sft_warmup_steps > 0")
        if buffer_config.trainer_input.sft_warmup_dataset is not None:
            self.sft_buffer = get_buffer_reader(
                buffer_config.trainer_input.sft_warmup_dataset, buffer_config
            )
        else:
            self.sft_buffer = None

    def sample(self, step: int, **kwargs) -> Tuple[Any, Dict, List]:
        metrics = {}
        with Timer(metrics, "read_time"):
            if step <= self.sft_warmup_steps:
                exp_list = self.sft_buffer.read()
            else:
                exp_list = self.exp_buffer.read()
        repr_samples = representative_sample(exp_list)
        with Timer(metrics, "gather_time"):
            exps = Experiences.gather_experiences(exp_list, self.pad_token_id)  # type: ignore
        if self.trainer_type == "verl":
            with Timer(metrics, "convert_time"):
                data = to_data_proto(exps)
            return data, metrics, repr_samples
        else:
            raise NotImplementedError(f"backend {self.trainer_type} is not supported")


@SAMPLE_STRATEGY.register_module("default")
class DefaultSampleStrategy(SampleStrategy):
    def __init__(self, buffer_config: BufferConfig, trainer_type: str, **kwargs):
        super().__init__(buffer_config, trainer_type)
        self.exp_buffer = get_buffer_reader(
            buffer_config.trainer_input.experience_buffer, buffer_config  # type: ignore
        )

    def sample(self, step: int, **kwargs) -> Tuple[Any, Dict, List]:
        metrics = {}
        with Timer(metrics, "read_time"):
            exp_list = self.exp_buffer.read()
        repr_samples = representative_sample(exp_list)
        with Timer(metrics, "gather_time"):
            exps = Experiences.gather_experiences(exp_list, self.pad_token_id)  # type: ignore
        if self.trainer_type == "verl":
            with Timer(metrics, "convert_time"):
                data = to_data_proto(exps)
            return data, metrics, repr_samples
        else:
            raise NotImplementedError(f"backend {self.trainer_type} is not supported")


@SAMPLE_STRATEGY.register_module("dpo")
class DPOSampleStrategy(WarmupSampleStrategy):
    def sample(self, step: int, **kwargs) -> Tuple[Any, Dict, List]:
        metrics = {}
        with Timer(metrics, "read_time"):
            if step <= self.sft_warmup_steps:
                exp_list = self.sft_buffer.read()
            else:
                exp_list = self.exp_buffer.read()
        repr_samples = representative_sample(exp_list)
        with Timer(metrics, "gather_time"):
            exps = Experiences.gather_dpo_experiences(exp_list, pad_token_id=self.pad_token_id)  # type: ignore
        if self.trainer_type == "verl":
            with Timer(metrics, "convert_time"):
                data = to_data_proto(exps)
            return data, metrics, repr_samples
        else:
            raise NotImplementedError(f"backend {self.trainer_type} is not supported")
78 changes: 78 additions & 0 deletions trinity/algorithm/sample_strategy/utils.py
@@ -0,0 +1,78 @@
import random
from typing import List

import numpy as np
import torch
from verl.trainer.ppo.ray_trainer import DataProto

from trinity.common.experience import Experience, Experiences


def to_data_proto(experiences: Experiences) -> DataProto:
    attention_mask = experiences.attention_masks
    cumsum = torch.cumsum(attention_mask, dim=-1)
    position_ids = torch.clip(cumsum - 1, 0, None).long()
    batch_dict = {
        "uid": np.array(experiences.run_ids),
        "position_ids": position_ids,
        "input_ids": experiences.tokens.long(),
        "responses": experiences.tokens[:, experiences.prompt_length :].long(),
        "attention_mask": attention_mask.long(),
        "response_mask": (
            experiences.action_masks[:, experiences.prompt_length :].long()
            if hasattr(experiences, "action_masks") and experiences.action_masks is not None
            else attention_mask[:, experiences.prompt_length :].long()
        ),
    }
    if experiences.rewards is not None:
        token_level_rewards = torch.zeros(attention_mask.shape, dtype=experiences.rewards.dtype)
        eos_mask_idx = cumsum.argmax(dim=-1)
        token_level_rewards[
            torch.arange(experiences.batch_size), eos_mask_idx
        ] = experiences.rewards
        token_level_rewards = token_level_rewards[:, experiences.prompt_length :]
        batch_dict.update(
            {
                "token_level_scores": token_level_rewards,
                "old_log_probs": experiences.logprobs[:, experiences.prompt_length :],  # type: ignore
            }
        )
    return DataProto.from_single_dict(batch_dict)


def representative_sample(experiences: List[Experience]) -> List[dict]:
    if experiences[0].reward is None:
        sample = random.choice(experiences)
        return [
            {
                "prompt": sample.prompt_text,
                "response": sample.response_text,
            }
        ]
    samples = []
    min_reward_sample = None
    max_reward_sample = None
    for exp in experiences:
        if exp.reward is None:
            continue
        if min_reward_sample is None or exp.reward < min_reward_sample.reward:
            min_reward_sample = exp
        if max_reward_sample is None or exp.reward > max_reward_sample.reward:
            max_reward_sample = exp
    if min_reward_sample is not None:
        samples.append(
            {
                "prompt": min_reward_sample.prompt_text,
                "response": min_reward_sample.response_text,
                "reward": min_reward_sample.reward,
            }
        )
    if max_reward_sample is not None:
        samples.append(
            {
                "prompt": max_reward_sample.prompt_text,
                "response": max_reward_sample.response_text,
                "reward": max_reward_sample.reward,
            }
        )
    return samples
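
The index arithmetic in `to_data_proto` is the subtle part of this file: position ids come from the attention-mask cumsum, and the scalar reward is written onto the last attended token before the response slice is taken. Below is a small self-contained illustration using plain `torch` tensors (values are made up; no Trinity imports).

```python
import torch

# One sequence: a left-padded 2-slot prompt (1 real token), 3 response tokens, right padding.
attention_mask = torch.tensor([[0, 1, 1, 1, 1, 0, 0]])
rewards = torch.tensor([0.5])
prompt_length = 2

cumsum = torch.cumsum(attention_mask, dim=-1)           # [[0, 1, 2, 3, 4, 4, 4]]
position_ids = torch.clip(cumsum - 1, 0, None).long()   # [[0, 0, 1, 2, 3, 3, 3]]

eos_mask_idx = cumsum.argmax(dim=-1)  # index 4: the last attended (EOS) token
token_level_rewards = torch.zeros(attention_mask.shape, dtype=rewards.dtype)
token_level_rewards[torch.arange(1), eos_mask_idx] = rewards

# The 0.5 reward lands on the last real response token; padding stays zero.
print(token_level_rewards[:, prompt_length:])
```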