Skip to content

Commit 2c7fd58

Browse files
committed
* sort by priority before writing to buffer
1 parent a63b5f4 commit 2c7fd58

File tree

2 files changed

+10
-7
lines changed

2 files changed

+10
-7
lines changed

trinity/data/controllers/active_iterator.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -159,8 +159,9 @@ def run(self):
159159
traceback.print_exc()
160160
return 7, "Tracking lineage failed."
161161

162-
# step 8. export the result to the output buffer
162+
# step 8. sort and export the result to the output buffer
163163
try:
164+
res_dataset.sort_by("priority", reverse=True)
164165
res_dataset.write_to_buffer()
165166
except Exception:
166167
traceback.print_exc()
@@ -246,7 +247,7 @@ def _compute_combined_score(
246247
difficulty = stats.get("difficulty_score", 0.5)
247248
score += self.priority_weights["difficulty"] * difficulty
248249

249-
sample["priority"] = [score]
250+
sample["priority"] = score
250251
return sample
251252

252253
def _compute_diversity_score(self) -> float:
@@ -258,10 +259,6 @@ def _compute_priority_scores(self, dataset: RftDataset) -> RftDataset:
258259
dataset.data = dataset.data.map(self._compute_combined_score)
259260
return dataset
260261

261-
def _select_top_k(self, dataset: RftDataset, k: int) -> List:
262-
"""Select top-k samples based on utility scores"""
263-
return dataset.data.sort("priority", reverse=True).take(k).to_list()
264-
265262
@ray.method(num_returns=1)
266263
def select_batch(self, dataset: RftDataset, batch_size: int) -> List[Dict[str, Any]]:
267264
"""Select a batch of samples for training"""
@@ -273,7 +270,8 @@ def select_batch(self, dataset: RftDataset, batch_size: int) -> List[Dict[str, A
273270
dataset.data = dataset.data.filter(lambda s: s["priority"] >= self.min_priority_score)
274271

275272
# Select top-k samples
276-
selected_samples = self._select_top_k(dataset, batch_size)
273+
dataset.sort_by("priority", reverse=True, top_k=batch_size)
274+
selected_samples = dataset.data.to_list()
277275

278276
# Update state
279277
self._update_state(selected_samples, dataset.data["priority"])

trinity/data/core/dataset.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,11 @@ def format(
6464
for formatter in formatters:
6565
self.data = formatter(self.data, num_proc)
6666

67+
def sort_by(self, key: str, reverse: bool = False, top_k: int = -1):
68+
if top_k == -1:
69+
top_k = len(self.data)
70+
self.data = self.data.sort(key, reverse=reverse).take(top_k)
71+
6772
def read_from_buffer(self):
6873
datasets = []
6974
for buffer in self.buffers:

0 commit comments

Comments
 (0)