22"""The taskset scheduler."""
33
44from collections import Counter
5+ from copy import deepcopy
56from typing import Dict , List
67
78import numpy as np
89
910from trinity .buffer .buffer import get_buffer_reader
10- from trinity .buffer .selector import SELECTORS
1111from trinity .common .config import Config
1212from trinity .common .constants import SELECTOR_METRIC
1313from trinity .utils .annotations import Experimental
@@ -47,7 +47,7 @@ def state_dict(self) -> List[Dict]:
4747 """
4848 raise NotImplementedError
4949
50- def update (self , pipeline_metrics : Dict ) -> None :
50+ def feedback (self , pipeline_metrics : Dict ) -> None :
5151 """Update selectors using feedback from the training pipeline."""
5252 raise NotImplementedError
5353
@@ -68,16 +68,18 @@ def __init__(self, explorer_state: Dict, config: Config):
6868 index = self .explorer_state .get ("taskset_states" , [{"current_index" : 0 }])[0 ].get (
6969 "current_index" , 0
7070 )
71- self .config .buffer .explorer_input .tasksets [0 ].index = index
72- self .reader = get_buffer_reader (config .buffer .explorer_input .tasksets [0 ])
71+ taskset_config = deepcopy (self .config .buffer .explorer_input .tasksets [0 ])
72+ taskset_config .index = index
73+ taskset_config .task_selector = None # disable selection
74+ self .reader = get_buffer_reader (taskset_config )
7375
7476 async def read_async (self ) -> List :
7577 return await self .reader .read_async ()
7678
7779 def state_dict (self ) -> List [Dict ]:
7880 return [self .reader .state_dict ()]
7981
80- def update (self , pipeline_metrics : Dict ) -> None :
82+ def feedback (self , pipeline_metrics : Dict ) -> None :
8183 # do nothing here
8284 return
8385
@@ -127,7 +129,6 @@ def __init__(self, explorer_state: Dict, config: Config):
127129 "taskset_states" , [{"current_index" : 0 }] * len (taskset_configs )
128130 )
129131 self .tasksets = []
130- self .selectors = []
131132 for taskset_config , taskset_state in zip (taskset_configs , taskset_states ):
132133 assert not taskset_config .is_eval # assume drop last
133134 taskset = get_buffer_reader (taskset_config )
@@ -136,15 +137,8 @@ def __init__(self, explorer_state: Dict, config: Config):
136137 f"Taskset '{ taskset_config .name } ' has an unsupported type '{ type (taskset ).__name__ } '."
137138 f"Currently, only 'FileReader' is supported by TasksetScheduler."
138139 )
139-
140- # Create selector based on type specified in config (e.g., 'sequential', 'shuffle')
141- selector = SELECTORS .get (taskset_config .task_selector .selector_type )(
142- taskset .reader .dataset , taskset_config .task_selector
143- )
144- selector .load_state_dict (taskset_state ) # Restore any prior state
145-
140+ taskset .load_state_dict (taskset_state ) # Restore any prior state
146141 self .tasksets .append (taskset )
147- self .selectors .append (selector )
148142
149143 # Each explorer step calls read_async once → track step globally
150144 self .step = explorer_state .get ("latest_iteration" , 0 )
@@ -224,8 +218,7 @@ async def read_async(self) -> List:
224218 counter = Counter (taskset_ids )
225219 batch = []
226220 for taskset_id , count in counter .items ():
227- indices = self .selectors [taskset_id ].get_indices (batch_size = count )
228- tasks = await self .tasksets [taskset_id ].read_with_indices_async (indices )
221+ tasks = await self .tasksets [taskset_id ].read_async (batch_size = count )
229222 # Annotate each task with its origin
230223 for task in tasks :
231224 task .index ["taskset_id" ] = taskset_id
@@ -239,13 +232,13 @@ def state_dict(self) -> List[Dict]:
239232 Save persistent state for checkpointing.
240233
241234 Returns:
242- List[Dict]: State dicts for all selectors (one per taskset)
235+ List[Dict]: State dicts for all tasksets
243236 """
244- return [selector .state_dict () for selector in self .selectors ]
237+ return [taskset .state_dict () for taskset in self .tasksets ]
245238
246- def update (self , pipeline_metrics : Dict ) -> None :
239+ def feedback (self , pipeline_metrics : Dict ) -> None :
247240 """
248- Update selectors using feedback from the training pipeline.
241+ Update selectors in tasksets using feedback from the training pipeline.
249242
250243 Expected format:
251244 pipeline_metrics = {
@@ -265,5 +258,5 @@ def update(self, pipeline_metrics: Dict) -> None:
265258 return
266259 selector_metric = pipeline_metrics .pop (SELECTOR_METRIC , {})
267260 for taskset_id , taskset_kwargs in selector_metric .items ():
268- selector = self .selectors [taskset_id ]
269- selector . update (** taskset_kwargs )
261+ taskset = self .tasksets [taskset_id ]
262+ taskset . feedback (** taskset_kwargs )
0 commit comments