Skip to content

Commit 3d12bd9

Browse files
authored
Enhance experience replay for priority queue buffer (agentscope-ai#306)
1 parent d5db95a commit 3d12bd9

File tree

3 files changed

+162
-13
lines changed

3 files changed

+162
-13
lines changed

tests/buffer/queue_test.py

Lines changed: 107 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -93,14 +93,14 @@ def thread_read(reader, result_queue):
9393
self.assertRaises(StopIteration, reader.read, batch_size=1)
9494

9595
async def test_priority_queue_capacity(self):
96-
# test queue capacity
96+
# test priority queue capacity
9797
self.config.train_batch_size = 4
9898
meta = StorageConfig(
9999
name="test_buffer_small",
100100
schema_type="experience",
101101
storage_type=StorageType.QUEUE,
102102
max_read_timeout=1,
103-
capacity=100, # priority will use 2 * train_batch_size as capacity (8)
103+
capacity=8,
104104
path=BUFFER_FILE_PATH,
105105
use_priority_queue=True,
106106
replay_buffer_kwargs={"priority_fn": "linear_decay", "decay": 0.6},
@@ -177,13 +177,13 @@ def write_blocking_call():
177177
self.assertFalse(thread.is_alive())
178178

179179
async def test_priority_queue_buffer_reuse(self):
180-
# test queue reuse
180+
# test experience replay
181181
meta = StorageConfig(
182182
name="test_buffer_small",
183183
schema_type="experience",
184184
storage_type=StorageType.QUEUE,
185185
max_read_timeout=3,
186-
capacity=4,
186+
capacity=4, # max total number of items; each item is List[Experience]
187187
path=BUFFER_FILE_PATH,
188188
use_priority_queue=True,
189189
reuse_cooldown_time=0.5,
@@ -300,6 +300,109 @@ def replace_call():
300300
# use_count 5, 4, 2, 1
301301
# priority 1.0, 0.6, 0.8, 0.4
302302

303+
async def test_priority_queue_reuse_count_control(self):
304+
# test experience replay with linear decay and use count control
305+
meta = StorageConfig(
306+
name="test_buffer_small",
307+
schema_type="experience",
308+
storage_type=StorageType.QUEUE,
309+
max_read_timeout=3,
310+
capacity=4, # max total number of items; each item is List[Experience]
311+
path=BUFFER_FILE_PATH,
312+
use_priority_queue=True,
313+
reuse_cooldown_time=0.5,
314+
replay_buffer_kwargs={
315+
"priority_fn": "linear_decay_use_count_control_randomization",
316+
"decay": 1.2,
317+
"use_count_limit": 2,
318+
"sigma": 0.0,
319+
},
320+
)
321+
writer = QueueWriter(meta, self.config)
322+
reader = QueueReader(meta, self.config)
323+
for i in range(4):
324+
writer.write(
325+
[
326+
Experience(
327+
tokens=torch.tensor([1, 2, 3]),
328+
prompt_length=2,
329+
info={"model_version": i, "use_count": 0},
330+
),
331+
Experience(
332+
tokens=torch.tensor([1, 2, 3]),
333+
prompt_length=2,
334+
info={"model_version": i, "use_count": 0},
335+
),
336+
]
337+
)
338+
339+
# should not be blocked
340+
def replace_call():
341+
writer.write(
342+
[
343+
Experience(
344+
tokens=torch.tensor([1, 2, 3]),
345+
prompt_length=2,
346+
info={"model_version": 4, "use_count": 0},
347+
),
348+
Experience(
349+
tokens=torch.tensor([1, 2, 3]),
350+
prompt_length=2,
351+
info={"model_version": 4, "use_count": 0},
352+
),
353+
]
354+
)
355+
356+
thread = threading.Thread(target=replace_call)
357+
thread.start()
358+
thread.join(timeout=2)
359+
self.assertFalse(thread.is_alive())
360+
361+
exps = reader.read(batch_size=4)
362+
self.assertEqual(len(exps), 4)
363+
self.assertEqual(exps[0].info["model_version"], 4)
364+
self.assertEqual(exps[0].info["use_count"], 1)
365+
self.assertEqual(exps[2].info["model_version"], 3)
366+
self.assertEqual(exps[2].info["use_count"], 1)
367+
368+
# model_version 4, 3, 2, 1
369+
# use_count 1, 1, 0, 0
370+
# priority 2.8, 1.8, 2.0, 1.0
371+
# in queue Y, Y, Y, Y
372+
373+
time.sleep(1)
374+
self.assertEqual(ray.get(reader.queue.length.remote()), 4)
375+
exps = reader.read(batch_size=4)
376+
self.assertEqual(len(exps), 4)
377+
self.assertEqual(exps[0].info["model_version"], 4)
378+
self.assertEqual(exps[0].info["use_count"], 2)
379+
self.assertEqual(exps[2].info["model_version"], 2)
380+
self.assertEqual(exps[2].info["use_count"], 1)
381+
382+
# model_version 4, 3, 2, 1
383+
# use_count 2, 1, 1, 0
384+
# priority 1.6, 1.8, 0.8, 1.0
385+
# in queue N, Y, Y, Y
386+
# model_version = 4 item is discarded for reaching use_count_limit
387+
388+
time.sleep(1)
389+
self.assertEqual(ray.get(reader.queue.length.remote()), 3)
390+
exps = reader.read(batch_size=4)
391+
self.assertEqual(len(exps), 4)
392+
self.assertEqual(exps[0].info["model_version"], 3)
393+
self.assertEqual(exps[0].info["use_count"], 2)
394+
self.assertEqual(exps[2].info["model_version"], 1)
395+
self.assertEqual(exps[2].info["use_count"], 1)
396+
397+
# model_version 3, 2, 1
398+
# use_count 2, 1, 1
399+
# priority 0.6, 0.8, -0.2
400+
# in queue N, Y, Y
401+
# model_version = 3 item is discarded for reaching use_count_limit
402+
403+
time.sleep(1)
404+
self.assertEqual(ray.get(reader.queue.length.remote()), 2)
405+
303406
def setUp(self):
304407
self.total_num = 8
305408
self.put_batch_size = 2

trinity/buffer/storage/queue.py

Lines changed: 54 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55
from collections import deque
66
from copy import deepcopy
77
from functools import partial
8-
from typing import List, Optional
8+
from typing import List, Optional, Tuple
99

10+
import numpy as np
1011
import ray
1112
from sortedcontainers import SortedDict
1213

@@ -26,11 +27,48 @@ def is_json_file(path: str) -> bool:
2627

2728

2829
PRIORITY_FUNC = Registry("priority_fn")
30+
"""
31+
Each priority_fn has the following interface:
32+
Args:
33+
item: List[Experience], assume that all experiences in it have the same model_version and use_count
34+
kwargs: storage_config.replay_buffer_kwargs (except priority_fn)
35+
Returns:
36+
priority: float
37+
put_into_queue: bool, decide whether to put item into queue
38+
Note that put_into_queue takes effect both for new items from the explorer and for items sampled from the buffer.
39+
"""
2940

3041

3142
@PRIORITY_FUNC.register_module("linear_decay")
32-
def linear_decay_priority(item: List[Experience], decay: float = 0.1):
33-
return item[0].info["model_version"] - decay * item[0].info["use_count"] # type: ignore
43+
def linear_decay_priority(
44+
item: List[Experience],
45+
decay: float = 2.0,
46+
) -> Tuple[float, bool]:
47+
"""Calculate priority by linear decay.
48+
49+
Priority is calculated as `model_version - decay * use_count`. The item is always put back into the queue for reuse (as long as `reuse_cooldown_time` is not None).
50+
"""
51+
priority = float(item[0].info["model_version"] - decay * item[0].info["use_count"])
52+
put_into_queue = True
53+
return priority, put_into_queue
54+
55+
56+
@PRIORITY_FUNC.register_module("linear_decay_use_count_control_randomization")
57+
def linear_decay_use_count_control_priority(
58+
item: List[Experience],
59+
decay: float = 2.0,
60+
use_count_limit: int = 3,
61+
sigma: float = 0.0,
62+
) -> Tuple[float, bool]:
63+
"""Calculate priority by linear decay, use count control, and randomization.
64+
65+
Priority is calculated as `model_version - decay * use_count`; if `sigma` is non-zero, priority is further perturbed by random Gaussian noise with standard deviation `sigma`. The item will be put back into the queue only if use count does not exceed `use_count_limit`.
66+
"""
67+
priority = float(item[0].info["model_version"] - decay * item[0].info["use_count"])
68+
if sigma > 0.0:
69+
priority += float(np.random.randn() * sigma)
70+
put_into_queue = item[0].info["use_count"] < use_count_limit if use_count_limit > 0 else True
71+
return priority, put_into_queue
3472

3573

3674
class QueueBuffer(ABC):
@@ -61,7 +99,7 @@ def get_queue(cls, storage_config: StorageConfig, config: BufferConfig) -> "Queu
6199
if storage_config.use_priority_queue:
62100
reuse_cooldown_time = storage_config.reuse_cooldown_time
63101
replay_buffer_kwargs = storage_config.replay_buffer_kwargs
64-
capacity = min(storage_config.capacity, config.train_batch_size * 2)
102+
capacity = storage_config.capacity
65103
logger.info(
66104
f"Using AsyncPriorityQueue with capacity {capacity}, reuse_cooldown_time {reuse_cooldown_time}."
67105
)
@@ -124,6 +162,7 @@ def __init__(
124162
kwargs: Additional keyword arguments for the priority function.
125163
"""
126164
self.capacity = capacity
165+
self.item_count = 0
127166
self.priority_groups = SortedDict() # Maps priority -> deque of items
128167
self.priority_fn = partial(PRIORITY_FUNC.get(priority_fn), **kwargs)
129168
self.reuse_cooldown_time = reuse_cooldown_time
@@ -142,22 +181,28 @@ async def _put(self, item: List[Experience], delay: float = 0) -> None:
142181
await asyncio.sleep(delay)
143182
if len(item) == 0:
144183
return
145-
priority = self.priority_fn(item=item)
184+
185+
priority, put_into_queue = self.priority_fn(item=item)
186+
if not put_into_queue:
187+
return
188+
146189
async with self._condition:
147-
if len(self.priority_groups) == self.capacity:
190+
if self.item_count == self.capacity:
148191
# If full, only insert if new item has higher or equal priority than the lowest
149192
lowest_priority, item_queue = self.priority_groups.peekitem(index=0)
150193
if lowest_priority > priority:
151194
return # Skip insertion if lower priority
152195
# Remove the lowest priority item
153196
item_queue.popleft()
197+
self.item_count -= 1
154198
if not item_queue:
155199
self.priority_groups.popitem(index=0)
156200

157201
# Add the new item
158202
if priority not in self.priority_groups:
159203
self.priority_groups[priority] = deque()
160204
self.priority_groups[priority].append(item)
205+
self.item_count += 1
161206
self._condition.notify()
162207

163208
async def put(self, item: List[Experience]) -> None:
@@ -181,19 +226,20 @@ async def get(self) -> List[Experience]:
181226

182227
_, item_queue = self.priority_groups.peekitem(index=-1)
183228
item = item_queue.popleft()
229+
self.item_count -= 1
184230
if not item_queue:
185231
self.priority_groups.popitem(index=-1)
186232

187233
for exp in item:
188234
exp.info["use_count"] += 1
189235
# Optionally resubmit the item after a cooldown
190236
if self.reuse_cooldown_time is not None:
191-
asyncio.create_task(self._put(item, self.reuse_cooldown_time))
237+
asyncio.create_task(self._put(item, delay=self.reuse_cooldown_time))
192238

193239
return item
194240

195241
def qsize(self):
196-
return len(self.priority_groups)
242+
return self.item_count
197243

198244
async def close(self) -> None:
199245
"""

trinity/common/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ class StorageConfig:
125125
use_priority_queue: bool = False
126126
reuse_cooldown_time: Optional[float] = None
127127
replay_buffer_kwargs: dict = field(
128-
default_factory=lambda: {"priority_fn": "linear_decay", "decay": 0.1}
128+
default_factory=lambda: {"priority_fn": "linear_decay", "decay": 2.0}
129129
)
130130

131131
# used for StorageType.SQL

0 commit comments

Comments
 (0)