
Commit dc8cb0c

Add Sample Strategy (#78)
1 parent 3c759d9 commit dc8cb0c

File tree

11 files changed (+292, -118 lines)


docs/sphinx_doc/source/tutorial/trinity_configs.md

Lines changed: 15 additions & 4 deletions
@@ -79,14 +79,25 @@ Specifies the algorithm type and its related hyperparameters.
 algorithm:
   algorithm_type: grpo
   repeat_times: 1
-  gamma: 1.0
-  lam: 1.0
+
+  # The following parameters are optional
+  # If not specified, they will automatically be set based on the `algorithm_type`
+  sample_strategy: "default"
+  advantage_fn: "ppo"
+  kl_penalty_fn: "none"
+  kl_loss_fn: "k2"
+  entropy_loss_fn: "default"
 ```
 
 
 - `algorithm_type`: Type of reinforcement learning algorithm. Supported types: `ppo`, `grpo`, `opmd`, `dpo`.
 - `repeat_times`: Number of times each task is repeated. Default is `1`. In `dpo`, this is automatically set to `2`.
-- `gamma`: Discount factor for future rewards. Default is `1.0`.
-- `lam`: Lambda value for Generalized Advantage Estimation (GAE). Default is `1.0`.
+
+- `sample_strategy`: The sample strategy used for loading experiences from the experience buffer.
+- `advantage_fn`: The advantage function used for computing advantages.
+- `kl_penalty_fn`: The KL penalty function used for computing the KL penalty.
+- `kl_loss_fn`: The KL loss function used for computing the KL loss.
+- `entropy_loss_fn`: The entropy loss function used for computing the entropy loss.
+
 
 ---
 

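The documentation change above says the optional fields are filled automatically from the `algorithm_type`. A minimal sketch of that fallback behaviour, using the GRPO defaults that appear later in this commit in `trinity/algorithm/algorithm.py`; `resolve_algorithm_config` is a hypothetical helper written for illustration, not Trinity's actual config code:

```python
# Hypothetical helper, for illustration only.
GRPO_DEFAULTS = {
    "repeat_times": 2,
    "sample_strategy": "warmup",
    "policy_loss_fn": "ppo",
    "advantage_fn": "grpo",
    "kl_penalty_fn": "none",
    "kl_loss_fn": "k2",
    "entropy_loss_fn": "default",
}


def resolve_algorithm_config(user_cfg: dict, defaults: dict) -> dict:
    """Fields the user left unset fall back to the algorithm's defaults."""
    return {**defaults, **user_cfg}


# The user only overrides the sample strategy; everything else is filled in.
print(resolve_algorithm_config({"sample_strategy": "default"}, GRPO_DEFAULTS))
```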
trinity/algorithm/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -2,6 +2,7 @@
 from trinity.algorithm.entropy_loss_fn import ENTROPY_LOSS_FN, EntropyLossFn
 from trinity.algorithm.kl_fn import KL_FN, KLFn
 from trinity.algorithm.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
+from trinity.algorithm.sample_strategy import SAMPLE_STRATEGY, SampleStrategy
 
 __all__ = [
     "AdvantageFn",
@@ -12,4 +13,6 @@
     "KL_FN",
     "EntropyLossFn",
     "ENTROPY_LOSS_FN",
+    "SampleStrategy",
+    "SAMPLE_STRATEGY",
 ]

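For completeness, this re-export makes the new symbols reachable from the package root (assuming an environment with this revision of trinity installed):

```python
# The sample-strategy registry and abstract base class, re-exported by
# trinity/algorithm/__init__.py in this commit.
from trinity.algorithm import SAMPLE_STRATEGY, SampleStrategy
```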
trinity/algorithm/algorithm.py

Lines changed: 9 additions & 13 deletions
@@ -7,7 +7,6 @@
 from trinity.buffer.schema.sql_schema import DPODataModel, ExperienceModel, SFTDataModel
 from trinity.common.config import Config
 from trinity.common.constants import SyncMethod
-from trinity.common.experience import Experience, Experiences
 from trinity.utils.log import get_logger
 from trinity.utils.registry import Registry
 
@@ -31,10 +30,6 @@ class AlgorithmType(ABC, metaclass=ConstantMeta):
     can_balance_batch: bool
     schema: type
 
-    @classmethod
-    def gather_experience(cls, exps: list[Experience], pad_token_id: int = 0) -> Experiences:
-        return Experiences.gather_experiences(exps, pad_token_id)
-
     @classmethod
     def get_default_config(cls) -> Dict:
         raise NotImplementedError
@@ -62,6 +57,7 @@ class SFTAlgorithm(AlgorithmType):
     @classmethod
     def get_default_config(cls) -> Dict:
         return {
+            "sample_strategy": "default",
             "policy_loss_fn": "sft",
             "kl_loss_fn": "none",
             "entropy_loss_fn": "none",
@@ -83,11 +79,12 @@ class PPOAlgorithm(AlgorithmType):
     def get_default_config(cls) -> Dict:
         return {
             "repeat_times": 1,
+            "sample_strategy": "warmup",
             "policy_loss_fn": "ppo",
             "advantage_fn": "ppo",
             "kl_penalty_fn": "none",
             "kl_loss_fn": "k2",
-            "entropy_loss_fn": "basic",
+            "entropy_loss_fn": "default",
         }
 
 
@@ -106,11 +103,12 @@ class GRPOAlgorithm(AlgorithmType):
     def get_default_config(cls) -> Dict:
         return {
             "repeat_times": 2,
+            "sample_strategy": "warmup",
             "policy_loss_fn": "ppo",
             "advantage_fn": "grpo",
             "kl_penalty_fn": "none",
             "kl_loss_fn": "k2",
-            "entropy_loss_fn": "basic",
+            "entropy_loss_fn": "default",
         }
 
 
@@ -129,11 +127,12 @@ class OPMDAlgorithm(AlgorithmType):
     def get_default_config(cls) -> Dict:
         return {
             "repeat_times": 2,
+            "sample_strategy": "warmup",
             "policy_loss_fn": "opmd",
             "advantage_fn": "opmd",
             "kl_penalty_fn": "none",
             "kl_loss_fn": "k2",
-            "entropy_loss_fn": "basic",
+            "entropy_loss_fn": "default",
         }
 
 
@@ -148,17 +147,14 @@ class DPOAlgorithm(AlgorithmType):
     can_balance_batch: bool = False
     schema: type = DPODataModel
 
-    @classmethod
-    def gather_experience(cls, exps: list[Experience], pad_token_id: int = 0) -> Experiences:
-        return Experiences.gather_dpo_experiences(exps, pad_token_id)
-
     @classmethod
     def get_default_config(cls) -> Dict:
         return {
             "repeat_times": 2,  # fake repeat times
+            "sample_strategy": "dpo",
             "policy_loss_fn": "dpo",
             "kl_loss_fn": "k2",
-            "entropy_loss_fn": "basic",
+            "entropy_loss_fn": "default",
         }
 
     @classmethod

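With `gather_experience` removed from `AlgorithmType`, an algorithm type now only declares class attributes and its default configuration, which names a sample strategy. A sketch of the resulting shape, assuming only what this diff shows (the class is hypothetical, skips whatever registration the real module applies, and would subclass `AlgorithmType` in practice):

```python
from typing import Dict

from trinity.buffer.schema.sql_schema import ExperienceModel


class MyAlgorithm:  # would subclass AlgorithmType in trinity/algorithm/algorithm.py
    can_balance_batch: bool = True
    schema: type = ExperienceModel

    @classmethod
    def get_default_config(cls) -> Dict:
        # Experience gathering is no longer declared here; it now lives inside
        # the sample strategy named below.
        return {
            "repeat_times": 1,
            "sample_strategy": "warmup",
            "policy_loss_fn": "ppo",
            "advantage_fn": "ppo",
            "kl_penalty_fn": "none",
            "kl_loss_fn": "k2",
            "entropy_loss_fn": "default",
        }
```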
trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py

Lines changed: 2 additions & 2 deletions
@@ -40,8 +40,8 @@ def default_args(cls) -> Dict:
         return {"entropy_coef": 0.0}
 
 
-@ENTROPY_LOSS_FN.register_module("basic")
-class BasicEntropyLossFn(EntropyLossFn):
+@ENTROPY_LOSS_FN.register_module("default")
+class DefaultEntropyLossFn(EntropyLossFn):
     """
     Basic entropy loss function.
     """

trinity/algorithm/sample_strategy/__init__.py

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+from trinity.algorithm.sample_strategy.sample_strategy import (
+    SAMPLE_STRATEGY,
+    DefaultSampleStrategy,
+    SampleStrategy,
+    WarmupSampleStrategy,
+)
+
+__all__ = [
+    "SAMPLE_STRATEGY",
+    "SampleStrategy",
+    "DefaultSampleStrategy",
+    "WarmupSampleStrategy",
+]

trinity/algorithm/sample_strategy/sample_strategy.py

Lines changed: 114 additions & 0 deletions
@@ -0,0 +1,114 @@
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List, Tuple
+
+from trinity.algorithm.sample_strategy.utils import representative_sample, to_data_proto
+from trinity.buffer import get_buffer_reader
+from trinity.common.config import BufferConfig
+from trinity.common.experience import Experiences
+from trinity.utils.registry import Registry
+from trinity.utils.timer import Timer
+
+SAMPLE_STRATEGY = Registry("sample_strategy")
+
+
+class SampleStrategy(ABC):
+    def __init__(self, buffer_config: BufferConfig, trainer_type: str, **kwargs):
+        self.pad_token_id = buffer_config.pad_token_id
+        self.trainer_type = trainer_type
+
+    @abstractmethod
+    def sample(self, step: int) -> Tuple[Any, Dict, List]:
+        """Sample experiences from the buffer.
+
+        Args:
+            step (`int`): The current step number.
+
+        Returns:
+            `Any`: The sampled experiences.
+            `Dict`: Metrics for logging.
+            `List`: Representative experiences for logging.
+        """
+
+    @classmethod
+    def default_args(cls) -> dict:
+        return {}
+
+
+@SAMPLE_STRATEGY.register_module("warmup")
+class WarmupSampleStrategy(SampleStrategy):
+    """Sample strategy that supports SFT warmup before switching to the experience buffer."""
+
+    def __init__(self, buffer_config: BufferConfig, trainer_type: str, **kwargs):
+        super().__init__(buffer_config, trainer_type)
+        self.exp_buffer = get_buffer_reader(
+            buffer_config.trainer_input.experience_buffer, buffer_config  # type: ignore
+        )
+        self.sft_warmup_steps = buffer_config.trainer_input.sft_warmup_steps
+        if self.sft_warmup_steps > 0 and buffer_config.trainer_input.sft_warmup_dataset is None:
+            raise ValueError("sft_warmup_dataset is required when sft_warmup_steps > 0")
+        if buffer_config.trainer_input.sft_warmup_dataset is not None:
+            self.sft_buffer = get_buffer_reader(
+                buffer_config.trainer_input.sft_warmup_dataset, buffer_config
+            )
+        else:
+            self.sft_buffer = None
+
+    def sample(self, step: int, **kwargs) -> Tuple[Any, Dict, List]:
+        metrics = {}
+        with Timer(metrics, "read_time"):
+            if step <= self.sft_warmup_steps:
+                exp_list = self.sft_buffer.read()
+            else:
+                exp_list = self.exp_buffer.read()
+        repr_samples = representative_sample(exp_list)
+        with Timer(metrics, "gather_time"):
+            exps = Experiences.gather_experiences(exp_list, self.pad_token_id)  # type: ignore
+        if self.trainer_type == "verl":
+            with Timer(metrics, "convert_time"):
+                data = to_data_proto(exps)
+            return data, metrics, repr_samples
+        else:
+            raise NotImplementedError(f"backend {self.trainer_type} is not supported")
+
+
+@SAMPLE_STRATEGY.register_module("default")
+class DefaultSampleStrategy(SampleStrategy):
+    def __init__(self, buffer_config: BufferConfig, trainer_type: str, **kwargs):
+        super().__init__(buffer_config, trainer_type)
+        self.exp_buffer = get_buffer_reader(
+            buffer_config.trainer_input.experience_buffer, buffer_config  # type: ignore
+        )
+
+    def sample(self, step: int, **kwargs) -> Tuple[Any, Dict, List]:
+        metrics = {}
+        with Timer(metrics, "read_time"):
+            exp_list = self.exp_buffer.read()
+        repr_samples = representative_sample(exp_list)
+        with Timer(metrics, "gather_time"):
+            exps = Experiences.gather_experiences(exp_list, self.pad_token_id)  # type: ignore
+        if self.trainer_type == "verl":
+            with Timer(metrics, "convert_time"):
+                data = to_data_proto(exps)
+            return data, metrics, repr_samples
+        else:
+            raise NotImplementedError(f"backend {self.trainer_type} is not supported")
+
+
+@SAMPLE_STRATEGY.register_module("dpo")
+class DPOSampleStrategy(WarmupSampleStrategy):
+    def sample(self, step: int, **kwargs) -> Tuple[Any, Dict, List]:
+        metrics = {}
+        with Timer(metrics, "read_time"):
+            if step <= self.sft_warmup_steps:
+                exp_list = self.sft_buffer.read()
+            else:
+                exp_list = self.exp_buffer.read()
+        repr_samples = representative_sample(exp_list)
+        with Timer(metrics, "gather_time"):
+            exps = Experiences.gather_dpo_experiences(exp_list, pad_token_id=self.pad_token_id)  # type: ignore
+        if self.trainer_type == "verl":
+            with Timer(metrics, "convert_time"):
+                data = to_data_proto(exps)
+            return data, metrics, repr_samples
+        else:
+            raise NotImplementedError(f"backend {self.trainer_type} is not supported")
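The `Registry` plus ABC pattern above is also the extension point for custom strategies. A sketch of how one might register a new strategy, assuming only the APIs visible in this file; the class and its `"my_strategy"` name are hypothetical and not part of the commit:

```python
from typing import Any, Dict, List, Tuple

from trinity.algorithm.sample_strategy import SAMPLE_STRATEGY, SampleStrategy
from trinity.algorithm.sample_strategy.utils import representative_sample, to_data_proto
from trinity.buffer import get_buffer_reader
from trinity.common.config import BufferConfig
from trinity.common.experience import Experiences
from trinity.utils.timer import Timer


@SAMPLE_STRATEGY.register_module("my_strategy")
class MySampleStrategy(SampleStrategy):
    """Reads from the experience buffer like `default`, leaving room for custom filtering."""

    def __init__(self, buffer_config: BufferConfig, trainer_type: str, **kwargs):
        super().__init__(buffer_config, trainer_type)
        self.exp_buffer = get_buffer_reader(
            buffer_config.trainer_input.experience_buffer, buffer_config
        )

    def sample(self, step: int, **kwargs) -> Tuple[Any, Dict, List]:
        metrics: Dict = {}
        with Timer(metrics, "read_time"):
            exp_list = self.exp_buffer.read()
        # Custom filtering or reweighting of exp_list would go here.
        repr_samples = representative_sample(exp_list)
        with Timer(metrics, "gather_time"):
            exps = Experiences.gather_experiences(exp_list, self.pad_token_id)
        if self.trainer_type == "verl":
            with Timer(metrics, "convert_time"):
                data = to_data_proto(exps)
            return data, metrics, repr_samples
        raise NotImplementedError(f"backend {self.trainer_type} is not supported")
```

Such a strategy would then presumably be selected via `sample_strategy: "my_strategy"` in the `algorithm` section of the config.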

trinity/algorithm/sample_strategy/utils.py

Lines changed: 78 additions & 0 deletions
@@ -0,0 +1,78 @@
+import random
+from typing import List
+
+import numpy as np
+import torch
+from verl.trainer.ppo.ray_trainer import DataProto
+
+from trinity.common.experience import Experience, Experiences
+
+
+def to_data_proto(experiences: Experiences) -> DataProto:
+    attention_mask = experiences.attention_masks
+    cumsum = torch.cumsum(attention_mask, dim=-1)
+    position_ids = torch.clip(cumsum - 1, 0, None).long()
+    batch_dict = {
+        "uid": np.array(experiences.run_ids),
+        "position_ids": position_ids,
+        "input_ids": experiences.tokens.long(),
+        "responses": experiences.tokens[:, experiences.prompt_length :].long(),
+        "attention_mask": attention_mask.long(),
+        "response_mask": (
+            experiences.action_masks[:, experiences.prompt_length :].long()
+            if hasattr(experiences, "action_masks") and experiences.action_masks is not None
+            else attention_mask[:, experiences.prompt_length :].long()
+        ),
+    }
+    if experiences.rewards is not None:
+        token_level_rewards = torch.zeros(attention_mask.shape, dtype=experiences.rewards.dtype)
+        eos_mask_idx = cumsum.argmax(dim=-1)
+        token_level_rewards[
+            torch.arange(experiences.batch_size), eos_mask_idx
+        ] = experiences.rewards
+        token_level_rewards = token_level_rewards[:, experiences.prompt_length :]
+        batch_dict.update(
+            {
+                "token_level_scores": token_level_rewards,
+                "old_log_probs": experiences.logprobs[:, experiences.prompt_length :],  # type: ignore
+            }
+        )
+    return DataProto.from_single_dict(batch_dict)
+
+
+def representative_sample(experiences: List[Experience]) -> List[dict]:
+    if experiences[0].reward is None:
+        sample = random.choice(experiences)
+        return [
+            {
+                "prompt": sample.prompt_text,
+                "response": sample.response_text,
+            }
+        ]
+    samples = []
+    min_reward_sample = None
+    max_reward_sample = None
+    for exp in experiences:
+        if exp.reward is None:
+            continue
+        if min_reward_sample is None or exp.reward < min_reward_sample.reward:
+            min_reward_sample = exp
+        if max_reward_sample is None or exp.reward > max_reward_sample.reward:
+            max_reward_sample = exp
+    if min_reward_sample is not None:
+        samples.append(
+            {
+                "prompt": min_reward_sample.prompt_text,
+                "response": min_reward_sample.response_text,
+                "reward": min_reward_sample.reward,
+            }
+        )
+    if max_reward_sample is not None:
+        samples.append(
+            {
+                "prompt": max_reward_sample.prompt_text,
+                "response": max_reward_sample.response_text,
+                "reward": max_reward_sample.reward,
+            }
+        )
+    return samples
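A quick illustration of what `representative_sample` returns. `Experience` construction is not shown in this commit, so a namedtuple stands in; the function only touches the `reward`, `prompt_text`, and `response_text` attributes used above:

```python
from collections import namedtuple

from trinity.algorithm.sample_strategy.utils import representative_sample

# Stand-in for trinity.common.experience.Experience (illustration only).
Exp = namedtuple("Exp", ["reward", "prompt_text", "response_text"])

exps = [
    Exp(0.1, "What is 2+2?", "5"),
    Exp(0.9, "What is 2+2?", "4"),
    Exp(0.5, "What is 2+2?", "four"),
]

# Returns the lowest- and highest-reward samples, in that order:
# [{'prompt': 'What is 2+2?', 'response': '5', 'reward': 0.1},
#  {'prompt': 'What is 2+2?', 'response': '4', 'reward': 0.9}]
print(representative_sample(exps))
```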
