@@ -403,31 +403,44 @@ async def test_split_tasks(self):
403403 self .config .check_and_update ()
404404 scheduler = Scheduler (self .config , [DummyModel .remote (), DummyModel .remote ()])
405405 await scheduler .start ()
406+ exp_list = []
406407
407408 tasks = generate_tasks (4 , repeat_times = 8 ) # ceil(8 / 2) == 4
408409 scheduler .schedule (tasks , batch_id = 1 )
409410 results = await scheduler .get_results (batch_id = 1 )
410411 self .assertEqual (len (results ), 4 * 4 )
411- self .assertEqual (len (self .queue .read (batch_size = 4 * 8 )), 4 * 8 )
412+ exps = self .queue .read (batch_size = 4 * 8 )
413+ self .assertEqual (len (exps ), 4 * 8 )
414+ exp_list .extend (exps )
412415 with self .assertRaises (TimeoutError ):
413416 self .queue .read (batch_size = 1 )
414417
415418 tasks = generate_tasks (4 , repeat_times = 5 ) # ceil(5 / 2) == 3
416- scheduler .schedule (tasks , batch_id = 1 )
417- results = await scheduler .get_results (batch_id = 1 )
419+ scheduler .schedule (tasks , batch_id = 2 )
420+ results = await scheduler .get_results (batch_id = 2 )
418421 self .assertEqual (len (results ), 4 * 3 )
419- self .assertEqual (len (self .queue .read (batch_size = 4 * 5 )), 4 * 5 )
422+ exps = self .queue .read (batch_size = 4 * 5 )
423+ self .assertEqual (len (exps ), 4 * 5 )
424+ exp_list .extend (exps )
420425 with self .assertRaises (TimeoutError ):
421426 self .queue .read (batch_size = 1 )
422427
423428 tasks = generate_tasks (3 , repeat_times = 1 ) # ceil(1 / 2) == 1
424- scheduler .schedule (tasks , batch_id = 1 )
425- results = await scheduler .get_results (batch_id = 1 )
429+ scheduler .schedule (tasks , batch_id = 3 )
430+ results = await scheduler .get_results (batch_id = 3 )
426431 self .assertEqual (len (results ), 3 * 1 )
427- self .assertEqual (len (self .queue .read (batch_size = 3 * 1 )), 3 * 1 )
432+ exps = self .queue .read (batch_size = 3 * 1 )
433+ self .assertEqual (len (exps ), 3 * 1 )
434+ exp_list .extend (exps )
428435 with self .assertRaises (TimeoutError ):
429436 self .queue .read (batch_size = 1 )
430437
438+ # test group_id and unique_id
439+ group_ids = [exp .group_id for exp in exp_list ]
440+ self .assertEqual (len (set (group_ids )), 11 ) # 4 + 4 + 3
441+ unique_ids = [exp .unique_id for exp in exp_list ]
442+ self .assertEqual (len (unique_ids ), len (set (unique_ids )))
443+
431444 await scheduler .stop ()
432445
433446 def tearDown (self ):
0 commit comments