@@ -93,14 +93,14 @@ def thread_read(reader, result_queue):
         self.assertRaises(StopIteration, reader.read, batch_size=1)

     async def test_priority_queue_capacity(self):
-        # test queue capacity
+        # test priority queue capacity
         self.config.train_batch_size = 4
         meta = StorageConfig(
             name="test_buffer_small",
             schema_type="experience",
             storage_type=StorageType.QUEUE,
             max_read_timeout=1,
-            capacity=100,  # priority will use 2 * train_batch_size as capacity (8)
+            capacity=8,
             path=BUFFER_FILE_PATH,
             use_priority_queue=True,
             replay_buffer_kwargs={"priority_fn": "linear_decay", "decay": 0.6},
@@ -177,13 +177,13 @@ def write_blocking_call():
         self.assertFalse(thread.is_alive())

     async def test_priority_queue_buffer_reuse(self):
-        # test queue reuse
+        # test experience replay
         meta = StorageConfig(
             name="test_buffer_small",
             schema_type="experience",
             storage_type=StorageType.QUEUE,
             max_read_timeout=3,
-            capacity=4,
+            capacity=4,  # max total number of items; each item is List[Experience]
             path=BUFFER_FILE_PATH,
             use_priority_queue=True,
             reuse_cooldown_time=0.5,
@@ -300,6 +300,109 @@ def replace_call():
         # use_count 5, 4, 2, 1
         # priority 1.0, 0.6, 0.8, 0.4

+    async def test_priority_queue_reuse_count_control(self):
+        # test experience replay with linear decay and use count control
+        meta = StorageConfig(
+            name="test_buffer_small",
+            schema_type="experience",
+            storage_type=StorageType.QUEUE,
+            max_read_timeout=3,
+            capacity=4,  # max total number of items; each item is List[Experience]
+            path=BUFFER_FILE_PATH,
+            use_priority_queue=True,
+            reuse_cooldown_time=0.5,
+            replay_buffer_kwargs={
+                "priority_fn": "linear_decay_use_count_control_randomization",
+                "decay": 1.2,
+                "use_count_limit": 2,
+                "sigma": 0.0,
+            },
+        )
+        writer = QueueWriter(meta, self.config)
+        reader = QueueReader(meta, self.config)
+        for i in range(4):
+            writer.write(
+                [
+                    Experience(
+                        tokens=torch.tensor([1, 2, 3]),
+                        prompt_length=2,
+                        info={"model_version": i, "use_count": 0},
+                    ),
+                    Experience(
+                        tokens=torch.tensor([1, 2, 3]),
+                        prompt_length=2,
+                        info={"model_version": i, "use_count": 0},
+                    ),
+                ]
+            )
+
+        # should not be blocked
+        def replace_call():
+            writer.write(
+                [
+                    Experience(
+                        tokens=torch.tensor([1, 2, 3]),
+                        prompt_length=2,
+                        info={"model_version": 4, "use_count": 0},
+                    ),
+                    Experience(
+                        tokens=torch.tensor([1, 2, 3]),
+                        prompt_length=2,
+                        info={"model_version": 4, "use_count": 0},
+                    ),
+                ]
+            )
+
+        thread = threading.Thread(target=replace_call)
+        thread.start()
+        thread.join(timeout=2)
+        self.assertFalse(thread.is_alive())
+
+        exps = reader.read(batch_size=4)
+        self.assertEqual(len(exps), 4)
+        self.assertEqual(exps[0].info["model_version"], 4)
+        self.assertEqual(exps[0].info["use_count"], 1)
+        self.assertEqual(exps[2].info["model_version"], 3)
+        self.assertEqual(exps[2].info["use_count"], 1)
+
+        # model_version 4, 3, 2, 1
+        # use_count 1, 1, 0, 0
+        # priority 2.8, 1.8, 2.0, 1.0
+        # in queue Y, Y, Y, Y
+
+        time.sleep(1)
+        self.assertEqual(ray.get(reader.queue.length.remote()), 4)
+        exps = reader.read(batch_size=4)
+        self.assertEqual(len(exps), 4)
+        self.assertEqual(exps[0].info["model_version"], 4)
+        self.assertEqual(exps[0].info["use_count"], 2)
+        self.assertEqual(exps[2].info["model_version"], 2)
+        self.assertEqual(exps[2].info["use_count"], 1)
+
+        # model_version 4, 3, 2, 1
+        # use_count 2, 1, 1, 0
+        # priority 1.6, 1.8, 0.8, 1.0
+        # in queue N, Y, Y, Y
+        # model_version = 4 item is discarded for reaching use_count_limit
+
+        time.sleep(1)
+        self.assertEqual(ray.get(reader.queue.length.remote()), 3)
+        exps = reader.read(batch_size=4)
+        self.assertEqual(len(exps), 4)
+        self.assertEqual(exps[0].info["model_version"], 3)
+        self.assertEqual(exps[0].info["use_count"], 2)
+        self.assertEqual(exps[2].info["model_version"], 1)
+        self.assertEqual(exps[2].info["use_count"], 1)
+
+        # model_version 3, 2, 1
+        # use_count 2, 1, 1
+        # priority 0.6, 0.8, -0.2
+        # in queue N, Y, Y
+        # model_version = 3 item is discarded for reaching use_count_limit
+
+        time.sleep(1)
+        self.assertEqual(ray.get(reader.queue.length.remote()), 2)
+
     def setUp(self):
         self.total_num = 8
         self.put_batch_size = 2
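
Note on the priority arithmetic traced in the test comments above: the values are consistent with a simple linear-decay rule. The sketch below is only an inference from the traced numbers, not the library's actual "linear_decay_use_count_control_randomization" implementation; the function name, signature, and formula here are assumptions made for illustration.

    import random


    def sketch_priority(model_version: int, use_count: int, decay: float, sigma: float) -> float:
        # Hypothetical helper (not from the library): newer experiences rank higher,
        # each past use lowers the priority by `decay`, and `sigma` scales optional
        # Gaussian jitter (sigma=0.0 keeps the ordering deterministic, as in the test).
        return model_version - decay * use_count + random.gauss(0.0, sigma)


    # Reproduces the first trace above with decay=1.2, sigma=0.0:
    # model_version 4, 3, 2, 1 with use_count 1, 1, 0, 0 -> priorities 2.8, 1.8, 2.0, 1.0
    for mv, uc in [(4, 1), (3, 1), (2, 0), (1, 0)]:
        print(mv, uc, round(sketch_priority(mv, uc, decay=1.2, sigma=0.0), 2))

On top of this ranking, the test's comments show that an item is dropped once its use_count reaches use_count_limit (2 here), which is why the queue length shrinks from 4 to 3 to 2 across the successive reads.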