@@ -41,6 +41,54 @@ def _check_batch_tasks(self, batch_tasks: List[Task], indices: List[Dict[str, in
4141 @parameterized .expand (
4242 [
4343 (
44+ {"batch_size" : 5 , "total_steps" : 3 },
45+ {"selector_type" : "sequential" },
46+ [
47+ {"index" : 0 , "taskset_id" : 1 },
48+ {"index" : 1 , "taskset_id" : 1 },
49+ {"index" : 2 , "taskset_id" : 1 },
50+ {"index" : 0 , "taskset_id" : 0 },
51+ {"index" : 1 , "taskset_id" : 0 },
52+ {"index" : 3 , "taskset_id" : 1 },
53+ {"index" : 4 , "taskset_id" : 1 },
54+ {"index" : 5 , "taskset_id" : 1 },
55+ {"index" : 2 , "taskset_id" : 0 },
56+ {"index" : 3 , "taskset_id" : 0 },
57+ {"index" : 6 , "taskset_id" : 1 },
58+ {"index" : 0 , "taskset_id" : 1 },
59+ {"index" : 1 , "taskset_id" : 1 },
60+ {"index" : 4 , "taskset_id" : 0 },
61+ {"index" : 0 , "taskset_id" : 0 },
62+ ],
63+ ),
64+ (
65+ {"batch_size" : 5 , "total_epochs" : 2 },
66+ {"selector_type" : "sequential" },
67+ [
68+ {"index" : 0 , "taskset_id" : 1 },
69+ {"index" : 1 , "taskset_id" : 1 },
70+ {"index" : 2 , "taskset_id" : 1 },
71+ {"index" : 0 , "taskset_id" : 0 },
72+ {"index" : 1 , "taskset_id" : 0 },
73+ {"index" : 3 , "taskset_id" : 1 },
74+ {"index" : 4 , "taskset_id" : 1 },
75+ {"index" : 5 , "taskset_id" : 1 },
76+ {"index" : 2 , "taskset_id" : 0 },
77+ {"index" : 3 , "taskset_id" : 0 },
78+ {"index" : 6 , "taskset_id" : 1 },
79+ {"index" : 0 , "taskset_id" : 1 },
80+ {"index" : 1 , "taskset_id" : 1 },
81+ {"index" : 4 , "taskset_id" : 0 },
82+ {"index" : 0 , "taskset_id" : 0 },
83+ {"index" : 2 , "taskset_id" : 1 },
84+ {"index" : 3 , "taskset_id" : 1 },
85+ {"index" : 4 , "taskset_id" : 1 },
86+ {"index" : 1 , "taskset_id" : 0 },
87+ {"index" : 2 , "taskset_id" : 0 },
88+ ],
89+ ),
90+ (
91+ {"batch_size" : 2 , "total_epochs" : 2 },
4492 {"selector_type" : "sequential" },
4593 [
4694 {"index" : 0 , "taskset_id" : 1 },
@@ -70,6 +118,7 @@ def _check_batch_tasks(self, batch_tasks: List[Task], indices: List[Dict[str, in
70118 ],
71119 ),
72120 (
121+ {"batch_size" : 2 , "total_epochs" : 2 },
73122 {"selector_type" : "shuffle" , "seed" : 42 },
74123 [
75124 {"index" : 3 , "taskset_id" : 1 },
@@ -99,6 +148,7 @@ def _check_batch_tasks(self, batch_tasks: List[Task], indices: List[Dict[str, in
99148 ],
100149 ),
101150 (
151+ {"batch_size" : 2 , "total_epochs" : 2 },
102152 {"selector_type" : "random" , "seed" : 42 },
103153 [
104154 {"index" : 0 , "taskset_id" : 1 },
@@ -128,6 +178,7 @@ def _check_batch_tasks(self, batch_tasks: List[Task], indices: List[Dict[str, in
128178 ],
129179 ),
130180 (
181+ {"batch_size" : 2 , "total_epochs" : 2 },
131182 {"selector_type" : "offline_easy2hard" , "feature_keys" : ["feature_offline" ]},
132183 [
133184 {"index" : 3 , "taskset_id" : 1 },
@@ -157,6 +208,7 @@ def _check_batch_tasks(self, batch_tasks: List[Task], indices: List[Dict[str, in
157208 ],
158209 ),
159210 (
211+ {"batch_size" : 2 , "total_epochs" : 2 },
160212 {"selector_type" : "difficulty_based" , "feature_keys" : ["feat_1" , "feat_2" ]},
161213 [
162214 {"index" : 3 , "taskset_id" : 1 },
@@ -187,10 +239,13 @@ def _check_batch_tasks(self, batch_tasks: List[Task], indices: List[Dict[str, in
187239 ),
188240 ]
189241 )
190- async def test_task_scheduler (self , task_selector_kwargs , batch_tasks_orders ) -> None :
242+ async def test_task_scheduler (
243+ self , buffer_config_kwargs , task_selector_kwargs , batch_tasks_orders
244+ ) -> None :
191245 config = get_template_config ()
192- config .buffer .batch_size = 2
193- config .buffer .total_epochs = 2
246+ config .mode = "explore"
247+ for key , value in buffer_config_kwargs .items ():
248+ setattr (config .buffer , key , value )
194249 config .buffer .explorer_input .taskset = None
195250 config .buffer .explorer_input .tasksets = [
196251 TasksetConfig (
0 commit comments