88
99from tests .tools import get_template_config
1010from trinity .buffer .reader .queue_reader import QueueReader
11- from trinity .common .config import StorageConfig
11+ from trinity .common .config import GenerationConfig , StorageConfig
1212from trinity .common .constants import StorageType
1313from trinity .common .experience import Experience
1414from trinity .common .models .model import InferenceModel
@@ -23,6 +23,7 @@ def __init__(self, model, task, auxiliary_models):
2323 super ().__init__ (model , task , auxiliary_models )
2424 self .error_type = task .raw_task .get ("error_type" , "" )
2525 self .seconds = None
26+ self .repeat_times = task .rollout_args .n
2627 if "timeout" in self .error_type :
2728 parts = self .error_type .split ("_" )
2829 if len (parts ) > 1 :
@@ -42,8 +43,12 @@ def run(self) -> List[Experience]:
4243
4344 return [
4445 Experience (
45- tokens = torch .zeros (5 ), prompt_length = 2 , prompt_text = self .error_type or "success"
46+ tokens = torch .zeros (5 ),
47+ prompt_length = 2 ,
48+ prompt_text = self .error_type or "success" ,
49+ info = {"repeat_times" : self .repeat_times },
4650 )
51+ for _ in range (self .repeat_times )
4752 ]
4853
4954
@@ -98,7 +103,11 @@ def api_server_ready(self) -> Tuple[str, str]:
98103
99104
100105def generate_tasks (
101- total_num : int , timeout_num : int = 0 , exception_num : int = 0 , timeout_seconds : int = 10
106+ total_num : int ,
107+ timeout_num : int = 0 ,
108+ exception_num : int = 0 ,
109+ timeout_seconds : int = 10 ,
110+ repeat_times : int = 1 ,
102111):
103112 """Generate some tasks for testing
104113
@@ -108,7 +117,10 @@ def generate_tasks(
108117 exception_num: number of exception tasks
109118 timeout_seconds: the timeout for timeout tasks
110119 """
111- tasks = [Task (workflow = DummyWorkflow , raw_task = {}) for _ in range (total_num )]
120+ tasks = [
121+ Task (workflow = DummyWorkflow , raw_task = {}, rollout_args = GenerationConfig (n = repeat_times ))
122+ for _ in range (total_num )
123+ ]
112124
113125 tasks .extend (
114126 [
@@ -150,6 +162,9 @@ def setUp(self):
150162 algorithm_type = "ppo" ,
151163 path = "" ,
152164 )
165+ self .config .buffer .trainer_input .experience_buffer .max_read_timeout = 1
166+ self .config .algorithm .repeat_times = 1
167+ self .config .check_and_update ()
153168 self .queue = QueueReader (
154169 self .config .buffer .trainer_input .experience_buffer , self .config .buffer
155170 )
@@ -163,6 +178,9 @@ async def test_get_results(self):
163178
164179 results = await scheduler .get_results (batch_id = 0 , min_num = 8 , timeout = 20 )
165180 self .assertEqual (len (results ), 8 )
181+ self .assertEqual (len (self .queue .read (batch_size = 8 )), 8 )
182+ with self .assertRaises (TimeoutError ):
183+ self .queue .read (batch_size = 1 )
166184
167185 for result in results :
168186 self .assertTrue (result .ok )
@@ -176,13 +194,17 @@ async def test_get_results(self):
176194 results = await scheduler .get_results (batch_id = batch_id , min_num = 4 , timeout = 10 )
177195 self .assertEqual (len (results ), 4 )
178196 self .assertFalse (scheduler .has_step (batch_id ))
197+ self .assertEqual (len (self .queue .read (batch_size = 4 )), 4 )
198+ with self .assertRaises (TimeoutError ):
199+ self .queue .read (batch_size = 1 )
179200
180201 tasks = generate_tasks (3 )
181202 scheduler .schedule (tasks , batch_id = 4 )
182203 self .assertTrue (scheduler .has_step (4 ))
183204 results = await scheduler .get_results (batch_id = 4 )
184205 self .assertEqual (len (results ), 3 )
185206 self .assertFalse (scheduler .has_step (4 ))
207+ self .assertEqual (len (self .queue .read (batch_size = 3 )), 3 )
186208
187209 # test timeout
188210 tasks = generate_tasks (2 , timeout_num = 2 , timeout_seconds = 10 )
@@ -194,6 +216,7 @@ async def test_get_results(self):
194216
195217 self .assertLessEqual (end_time - start_time , 5 )
196218 self .assertEqual (len (results ), 2 )
219+ self .assertEqual (len (self .queue .read (batch_size = 2 )), 2 )
197220
198221 # test run tasks after timeout
199222 tasks = generate_tasks (4 )
@@ -204,8 +227,10 @@ async def test_get_results(self):
204227 self .assertEqual (len (results ), 4 )
205228
206229 success_count = sum (1 for r in results if r .ok )
207-
208- self .assertEqual (success_count , sum (1 for r in results if r .ok ))
230+ self .assertEqual (success_count , 4 )
231+ self .assertEqual (len (self .queue .read (batch_size = 4 )), 4 )
232+ with self .assertRaises (TimeoutError ):
233+ self .queue .read (batch_size = 1 )
209234
210235 # test exception tasks
211236 tasks = generate_tasks (1 , exception_num = 3 )
@@ -215,14 +240,21 @@ async def test_get_results(self):
215240
216241 success_count = sum (1 for r in results if r .ok )
217242 self .assertEqual (success_count , 1 )
243+ self .assertEqual (len (self .queue .read (batch_size = 1 )), 1 )
244+ with self .assertRaises (TimeoutError ):
245+ self .queue .read (batch_size = 1 )
218246
219247 # test clear_timeout_tasks
220248 tasks = generate_tasks (3 , timeout_num = 1 , timeout_seconds = 3 )
221249 scheduler .schedule (tasks , batch_id = 2 )
222250 results = await scheduler .get_results (batch_id = 2 , timeout = 2 , clear_timeout_tasks = False )
223251 self .assertEqual (len (results ), 3 )
252+ self .assertEqual (len (self .queue .read (batch_size = 3 )), 3 )
224253 results = await scheduler .get_results (batch_id = 2 , timeout = 2 , clear_timeout_tasks = False )
225254 self .assertEqual (len (results ), 1 )
255+ self .assertEqual (len (self .queue .read (batch_size = 1 )), 1 )
256+ with self .assertRaises (TimeoutError ):
257+ self .queue .read (batch_size = 1 )
226258
227259 await scheduler .stop ()
228260
@@ -366,6 +398,38 @@ async def test_scheduler_all_methods(self):
366398 self .assertFalse (scheduler .has_step (2 ))
367399 await scheduler .stop ()
368400
401+ async def test_split_tasks (self ):
402+ self .config .explorer .max_repeat_times_per_runner = 2
403+ self .config .check_and_update ()
404+ scheduler = Scheduler (self .config , [DummyModel .remote (), DummyModel .remote ()])
405+ await scheduler .start ()
406+
407+ tasks = generate_tasks (4 , repeat_times = 8 ) # ceil(8 / 2) == 4
408+ scheduler .schedule (tasks , batch_id = 1 )
409+ results = await scheduler .get_results (batch_id = 1 )
410+ self .assertEqual (len (results ), 4 * 4 )
411+ self .assertEqual (len (self .queue .read (batch_size = 4 * 8 )), 4 * 8 )
412+ with self .assertRaises (TimeoutError ):
413+ self .queue .read (batch_size = 1 )
414+
415+ tasks = generate_tasks (4 , repeat_times = 5 ) # ceil(5 / 2) == 3
416+ scheduler .schedule (tasks , batch_id = 1 )
417+ results = await scheduler .get_results (batch_id = 1 )
418+ self .assertEqual (len (results ), 4 * 3 )
419+ self .assertEqual (len (self .queue .read (batch_size = 4 * 5 )), 4 * 5 )
420+ with self .assertRaises (TimeoutError ):
421+ self .queue .read (batch_size = 1 )
422+
423+ tasks = generate_tasks (3 , repeat_times = 1 ) # ceil(1 / 2) == 1
424+ scheduler .schedule (tasks , batch_id = 1 )
425+ results = await scheduler .get_results (batch_id = 1 )
426+ self .assertEqual (len (results ), 3 * 1 )
427+ self .assertEqual (len (self .queue .read (batch_size = 3 * 1 )), 3 * 1 )
428+ with self .assertRaises (TimeoutError ):
429+ self .queue .read (batch_size = 1 )
430+
431+ await scheduler .stop ()
432+
369433 def tearDown (self ):
370434 try :
371435 ray .shutdown ()
0 commit comments