
Commit e412fbe

Add staleness control (#445)
1 parent 85b8577 commit e412fbe

18 files changed, +367 −42 lines

docs/sphinx_doc/source/tutorial/trinity_configs.md

Lines changed: 1 addition & 1 deletion
@@ -112,7 +112,7 @@ algorithm:
 - `optimizer`: Optimizer configuration for actor.
   - `lr`: Learning rate for actor.
   - `warmup_style`: Warmup style for actor's learning rate.
-- `sample_strategy`: The sampling strategy used for loading experiences from experience buffer.
+- `sample_strategy`: The sampling strategy used for loading experiences from experience buffer. Supported types: `default`, `staleness_control`, `mix`.
 - `advantage_fn`: The advantage function used for computing advantages.
 - `kl_penalty_fn`: The KL penalty function used for computing KL penalty applied in reward.
 - `kl_loss_fn`: The KL loss function used for computing KL loss.
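
For reference, a minimal sketch of enabling the new strategy from Python, mirroring how the tests in this commit configure it (the `max_staleness` value here is an arbitrary example):

    from tests.tools import get_template_config

    config = get_template_config()
    config.algorithm.sample_strategy = "staleness_control"
    config.algorithm.sample_strategy_args = {"max_staleness": 3}
    config.check_and_update()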

docs/sphinx_doc/source_zh/tutorial/trinity_configs.md

Lines changed: 1 addition & 1 deletion
@@ -112,7 +112,7 @@ algorithm:
 - `optimizer`: Parameters of the actor's optimizer.
   - `lr`: Learning rate of the optimizer.
   - `warmup_style`: Warmup strategy for the learning rate.
-- `sample_strategy`: The sampling strategy used when loading experiences from the experience buffer.
+- `sample_strategy`: The sampling strategy used when loading experiences from the experience buffer. Supported types: `default`, `staleness_control`, `mix`.
 - `advantage_fn`: The function used to compute advantage values.
 - `kl_penalty_fn`: The function used to compute the KL penalty applied to the reward.
 - `kl_loss_fn`: The function used to compute the KL loss.

tests/buffer/experience_storage_test.py

Lines changed: 1 addition & 0 deletions
@@ -108,6 +108,7 @@ async def test_sql_experience_buffer(self):
                 prompt_length=i,
                 reward=float(i),
                 logprobs=torch.tensor([0.1]),
+                info={"model_version": 0},
             )
             for i in range(1, self.put_batch_size + 1)
         ]
Lines changed: 251 additions & 0 deletions
@@ -0,0 +1,251 @@
import asyncio
import shutil
from collections import deque

import torch
from parameterized import parameterized_class

from tests.tools import RayUnittestBaseAysnc, get_template_config
from trinity.algorithm.sample_strategy import SAMPLE_STRATEGY
from trinity.algorithm.sample_strategy.sample_strategy import SampleStrategy
from trinity.buffer.buffer import get_buffer_writer
from trinity.common.config import ExperienceBufferConfig, ReplayBufferConfig
from trinity.common.constants import StorageType
from trinity.common.experience import Experience


@parameterized_class(
    ("exp_write_batch_size",),
    [
        (3,),
        (6,),
    ],
)
class ExperienceStorageTest(RayUnittestBaseAysnc):
    def setUp(self):
        self.config = get_template_config()
        self.num_steps = 20

    def _default_exp_list(self):
        return [
            [
                Experience(
                    tokens=torch.tensor([float(k) for k in range(j + 3)]),
                    reward=float(i),  # using reward to carry model_version for testing
                    prompt_length=2,
                    info={"model_version": i, "use_count": 0},
                )
                for j in range(self.exp_write_batch_size)
            ]
            for i in range(self.num_steps)
        ]

    def _default_steps(self):
        return [0, 5, 10, 15]

    def _init_buffer_writer_and_sample_strategy(self):
        # Initialize buffer writer and sample strategy
        self.buffer_writer = get_buffer_writer(
            self.config.buffer.trainer_input.experience_buffer,  # type: ignore [arg-type]
        )
        self.sample_strategy: SampleStrategy = SAMPLE_STRATEGY.get(
            self.config.algorithm.sample_strategy
        )(
            buffer_config=self.config.buffer,
            **self.config.algorithm.sample_strategy_args,
        )

    async def _verify_model_version(self, step, expected_versions):
        batch, metrics, _ = await self.sample_strategy.sample(step=step)
        self.assertEqual(
            batch.rewards.tolist(), expected_versions, f"Model versions mismatch at step {step}"
        )
        self.assertEqual(
            metrics["sample/model_version/min"],
            min(expected_versions),
            f"Min model version mismatch at step {step}",
        )
        self.assertEqual(
            metrics["sample/model_version/max"],
            max(expected_versions),
            f"Max model version mismatch at step {step}",
        )
        self.assertEqual(
            metrics["sample/model_version/mean"],
            sum(expected_versions) / len(expected_versions),
            f"Mean model version mismatch at step {step}",
        )

    async def _verify_sampling_model_versions(self, exps_list, expected_model_versions_map):
        self._init_buffer_writer_and_sample_strategy()

        # Write experiences to the buffer, while sampling and validating model versions
        current_task = None
        for step, exps in enumerate(exps_list):
            await self.buffer_writer.write_async(exps)
            if step in expected_model_versions_map:
                if current_task:
                    await current_task
                current_task = asyncio.create_task(
                    self._verify_model_version(step, expected_model_versions_map[step])
                )
            await asyncio.sleep(0.1)

        if current_task:
            await current_task

    async def _flexible_verify_model_version(self, step, max_staleness):
        _, metrics, _ = await self.sample_strategy.sample(step=step)
        self.assertGreaterEqual(
            metrics["sample/model_version/min"],
            step - max_staleness,
            f"Min model version mismatch at step {step}",
        )

    async def _flexible_verify_sampling_model_versions(self, exps_list, check_steps, max_staleness):
        self._init_buffer_writer_and_sample_strategy()

        # Write experiences to the buffer, while sampling and validating model versions
        current_task = None
        for step, exps in enumerate(exps_list):
            await self.buffer_writer.write_async(exps)
            if step in check_steps:
                if current_task:
                    await current_task
                current_task = asyncio.create_task(
                    self._flexible_verify_model_version(step, max_staleness)
                )
            await asyncio.sleep(0.1)

        if current_task:
            await current_task

    async def test_default_queue_default_sample_strategy(self):
        self.config.buffer.trainer_input.experience_buffer = ExperienceBufferConfig(
            name="default_queue_default_strategy",
            storage_type=StorageType.QUEUE.value,
            replay_buffer=ReplayBufferConfig(enable=False),
        )
        self.config.check_and_update()

        # init testing data
        exps_list = self._default_exp_list()
        steps = self._default_steps()
        train_batch_size = self.config.buffer.train_batch_size
        expected_model_versions_map = {}
        for idx, step in enumerate(steps):
            start_idx = idx * train_batch_size
            batch_versions = [
                (start_idx + offset) // self.exp_write_batch_size
                for offset in range(train_batch_size)
            ]
            expected_model_versions_map[step] = batch_versions

        await self._verify_sampling_model_versions(exps_list, expected_model_versions_map)

    async def test_default_queue_staleness_control_sample_strategy(self):
        max_staleness = 3
        self.config.algorithm.sample_strategy = "staleness_control"
        self.config.algorithm.sample_strategy_args = {"max_staleness": max_staleness}
        self.config.buffer.trainer_input.experience_buffer = ExperienceBufferConfig(
            name="default_queue_staleness_control",
            storage_type=StorageType.QUEUE.value,
            replay_buffer=ReplayBufferConfig(enable=False),
        )
        self.config.check_and_update()

        # init testing data
        exps_list = self._default_exp_list()
        steps = self._default_steps()
        expected_model_versions_map = {}
        for step in steps:
            predict_version = max(step - max_staleness, 0)
            expected_model_versions_map[step] = [
                predict_version + i // self.exp_write_batch_size
                for i in range(self.config.buffer.train_batch_size)
            ]

        await self._verify_sampling_model_versions(exps_list, expected_model_versions_map)

    def _simulate_priority_queue(self, steps, max_staleness=float("inf")):
        expected_model_versions_map = {}
        buffer = deque()
        exp_pool = deque()
        step_idx = 0
        train_batch_size = self.config.buffer.train_batch_size
        for i in range(self.num_steps):
            buffer.append([i] * self.exp_write_batch_size)
            step = steps[step_idx]
            if i < step:
                continue
            batch_versions = expected_model_versions_map.get(step, [])
            if len(batch_versions) < train_batch_size:
                while len(buffer) > 0:
                    if len(exp_pool) == 0:
                        exp_pool.extend(buffer.pop())
                    while len(exp_pool) > 0 and len(batch_versions) < train_batch_size:
                        exp_version = exp_pool.popleft()
                        if exp_version < step - max_staleness:
                            continue
                        batch_versions.append(exp_version)
                    if len(batch_versions) >= train_batch_size:
                        step_idx += 1
                        break
                expected_model_versions_map[step] = batch_versions
            if step_idx >= len(steps):
                break
        return expected_model_versions_map

    async def test_priority_queue_default_sample_strategy(self):
        self.config.buffer.trainer_input.experience_buffer = ExperienceBufferConfig(
            name="priority_queue_default_strategy",
            storage_type=StorageType.QUEUE.value,
            replay_buffer=ReplayBufferConfig(enable=True),
        )
        self.config.check_and_update()

        # init testing data
        exps_list = self._default_exp_list()
        steps = self._default_steps()
        expected_model_versions_map = self._simulate_priority_queue(steps)

        await self._verify_sampling_model_versions(exps_list, expected_model_versions_map)

    async def test_priority_queue_staleness_control_sample_strategy(self):
        max_staleness = 2
        self.config.algorithm.sample_strategy = "staleness_control"
        self.config.algorithm.sample_strategy_args = {"max_staleness": max_staleness}
        self.config.buffer.trainer_input.experience_buffer = ExperienceBufferConfig(
            name="priority_queue_staleness_control",
            storage_type=StorageType.QUEUE.value,
            replay_buffer=ReplayBufferConfig(enable=True),
        )
        self.config.check_and_update()

        # init testing data
        exps_list = self._default_exp_list()
        steps = self._default_steps()
        expected_model_versions_map = self._simulate_priority_queue(steps, max_staleness)

        await self._verify_sampling_model_versions(exps_list, expected_model_versions_map)

    async def test_sql_staleness_control_sample_strategy(self):
        max_staleness = 2
        self.config.algorithm.sample_strategy = "staleness_control"
        self.config.algorithm.sample_strategy_args = {"max_staleness": max_staleness}
        self.config.buffer.trainer_input.experience_buffer = ExperienceBufferConfig(
            name="sql_staleness_control",
            storage_type=StorageType.SQL.value,
        )
        self.config.check_and_update()

        # init testing data
        exps_list = self._default_exp_list()
        steps = self._default_steps()

        await self._flexible_verify_sampling_model_versions(exps_list, steps, max_staleness)

    def tearDown(self):
        asyncio.run(self.buffer_writer.release())
        shutil.rmtree(self.config.checkpoint_job_dir)
        return super().tearDown()
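
A note on the helper above: `_simulate_priority_queue` models the replay-buffer-enabled queue as newest-first (`buffer.pop()` drains the most recently appended write batch) and skips any version older than `step - max_staleness`, mirroring the filtering the staleness-control strategy applies when reading; this is how the expected model versions for each checked step are derived.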

tests/buffer/sql_test.py

Lines changed: 2 additions & 0 deletions
@@ -34,6 +34,7 @@ async def test_sql_exp_buffer_read_write(self) -> None:
                 prompt_length=i,
                 reward=float(i),
                 logprobs=torch.tensor([0.1]),
+                info={"model_version": i},
             )
             for i in range(1, put_batch_size + 1)
         ]
@@ -52,6 +53,7 @@ async def test_sql_exp_buffer_read_write(self) -> None:
                 reward=float(i),
                 logprobs=torch.tensor([0.1]),
                 action_mask=torch.tensor([j % 2 for j in range(i + 1)]),
+                info={"model_version": i},
             )
             for i in range(1, put_batch_size * 2 + 1)
         ]

tests/common/experience_test.py

Lines changed: 1 addition & 0 deletions
@@ -253,6 +253,7 @@ def test_experience_model_experience_conversion(self):
             reward=reward,
             prompt_length=prompt_length,
             logprobs=logprobs,
+            info={"model_version": 0},
         )

         model = ExperienceModel.from_experience(experience)

tests/explorer/explorer_test.py

Lines changed: 2 additions & 0 deletions
@@ -5,6 +5,7 @@
 import os
 import random
 import shutil
+import unittest
 from datetime import datetime

 import httpx
@@ -200,6 +201,7 @@ def setUp(self):
         if multiprocessing.get_start_method(allow_none=True) != "spawn":
             multiprocessing.set_start_method("spawn", force=True)

+    @unittest.skip("Require improvement for agent mode")
     async def test_serve(self):  # noqa: C901
         serve_process = multiprocessing.Process(target=run_serve, args=(self.config,))
         serve_process.start()

tests/service/data_juicer_test.py

Lines changed: 4 additions & 0 deletions
@@ -182,24 +182,28 @@ async def test_data_juicer_operators(self):
                 prompt_length=3,
                 prompt_text="Hello, how are you?",
                 response_text="Hi, I am fine.",
+                info={"model_version": 0},
             ),
             Experience(  # too short response
                 tokens=torch.tensor([1, 2, 3, 4, 5]),
                 prompt_length=3,
                 prompt_text="What is your name?",
                 response_text="Trinity.",
+                info={"model_version": 0},
             ),
             Experience(  # repeated words
                 tokens=torch.tensor([1, 2, 3, 4, 5]),
                 prompt_length=3,
                 prompt_text="What day is it today?",
                 response_text="Today is Sunday Sunday Sunday Sunday Sunday and it's a happy day!",
+                info={"model_version": 0},
             ),
             Experience(
                 tokens=torch.tensor([1, 2, 3, 4, 5]),
                 prompt_length=3,
                 prompt_text="What is your favorite color?",
                 response_text="My favorite color is blue.",
+                info={"model_version": 0},
             ),
         ]
         metrics = await pipeline.process.remote(exps)

trinity/algorithm/sample_strategy/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@
     default_mapping={
         "default": "trinity.algorithm.sample_strategy.sample_strategy.DefaultSampleStrategy",
         "warmup": "trinity.algorithm.sample_strategy.sample_strategy.WarmupSampleStrategy",
+        "staleness_control": "trinity.algorithm.sample_strategy.sample_strategy.StalenessControlSampleStrategy",
         "mix": "trinity.algorithm.sample_strategy.mix_sample_strategy.MixSampleStrategy",
     },
 )
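
The new mapping entry is consumed through the registry lookup used elsewhere in this commit; a minimal sketch, assuming a `config` prepared as in the configuration example above:

    from trinity.algorithm.sample_strategy import SAMPLE_STRATEGY

    # Resolve the class registered under "staleness_control" and instantiate
    # it with the buffer config plus strategy-specific kwargs, as the new
    # sample-strategy tests do.
    strategy = SAMPLE_STRATEGY.get(config.algorithm.sample_strategy)(
        buffer_config=config.buffer,
        **config.algorithm.sample_strategy_args,
    )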

trinity/algorithm/sample_strategy/sample_strategy.py

Lines changed: 17 additions & 0 deletions
@@ -76,6 +76,23 @@ def load_state_dict(self, state_dict: dict) -> None:
         self.exp_buffer.load_state_dict(state_dict)


+class StalenessControlSampleStrategy(DefaultSampleStrategy):
+    def __init__(self, buffer_config: BufferConfig, **kwargs):
+        super().__init__(buffer_config)
+        self.max_staleness = kwargs.get("max_staleness", float("inf"))
+
+    async def sample(self, step: int, **kwargs) -> Tuple[Experiences, Dict, List]:
+        min_model_version = max(step - self.max_staleness, 0)
+        metrics = {}
+        with Timer(metrics, "time/read_experience"):
+            exp_list = await self.exp_buffer.read_async(min_model_version=min_model_version)
+        repr_samples = representative_sample(exp_list)
+        self.set_model_version_metric(exp_list, metrics)
+        with Timer(metrics, "time/gather_experience"):
+            exps = Experiences.gather_experiences(exp_list, self.pad_token_id)  # type: ignore
+        return exps, metrics, repr_samples
+
+
 @Deprecated
 class WarmupSampleStrategy(DefaultSampleStrategy):
     """The warmup sample strategy.
