@@ -159,8 +159,9 @@ def run(self):
159159 traceback .print_exc ()
160160 return 7 , "Tracking lineage failed."
161161
162- # step 8. export the result to the output buffer
162+ # step 8. sort and export the result to the output buffer
163163 try :
164+ res_dataset .sort_by ("priority" , reverse = True )
164165 res_dataset .write_to_buffer ()
165166 except Exception :
166167 traceback .print_exc ()
@@ -246,7 +247,7 @@ def _compute_combined_score(
246247 difficulty = stats .get ("difficulty_score" , 0.5 )
247248 score += self .priority_weights ["difficulty" ] * difficulty
248249
249- sample ["priority" ] = [ score ]
250+ sample ["priority" ] = score
250251 return sample
251252
252253 def _compute_diversity_score (self ) -> float :
@@ -258,10 +259,6 @@ def _compute_priority_scores(self, dataset: RftDataset) -> RftDataset:
258259 dataset .data = dataset .data .map (self ._compute_combined_score )
259260 return dataset
260261
261- def _select_top_k (self , dataset : RftDataset , k : int ) -> List :
262- """Select top-k samples based on utility scores"""
263- return dataset .data .sort ("priority" , reverse = True ).take (k ).to_list ()
264-
265262 @ray .method (num_returns = 1 )
266263 def select_batch (self , dataset : RftDataset , batch_size : int ) -> List [Dict [str , Any ]]:
267264 """Select a batch of samples for training"""
@@ -273,7 +270,8 @@ def select_batch(self, dataset: RftDataset, batch_size: int) -> List[Dict[str, A
273270 dataset .data = dataset .data .filter (lambda s : s ["priority" ] >= self .min_priority_score )
274271
275272 # Select top-k samples
276- selected_samples = self ._select_top_k (dataset , batch_size )
273+ dataset .sort_by ("priority" , reverse = True , top_k = batch_size )
274+ selected_samples = dataset .data .to_list ()
277275
278276 # Update state
279277 self ._update_state (selected_samples , dataset .data ["priority" ])
0 commit comments