|
11 | 11 | group_scores, |
12 | 12 | parse_config, |
13 | 13 | ) |
| 14 | +from trinity.utils.log import get_logger |
14 | 15 |
|
15 | 16 |
|
16 | 17 | def extract_metrics(dataset: Dataset) -> Dict: |
@@ -39,6 +40,12 @@ def __init__(self, config: DJConfig): |
39 | 40 | "usage_frequency": -0.5, |
40 | 41 | "quality": 1.0, |
41 | 42 | } |
| 43 | + self.order_method = self.config.order_method |
| 44 | + self.order_args = self.config.order_args or { |
| 45 | + "folding_layers": 3, |
| 46 | + } |
| 47 | + |
| 48 | + self.logger = get_logger(__name__) |
42 | 49 |
|
43 | 50 | def process_experience(self, ds: Dataset) -> Tuple[Dataset, Dict]: |
44 | 51 | """Process a batch of experiences. |
@@ -74,11 +81,45 @@ def process_task(self) -> Dict: |
74 | 81 | ) |
75 | 82 | ds = ds.map(compute_priority_scores_func) |
76 | 83 | # sort the output dataset in priority |
77 | | - if "priority" in ds.features: |
78 | | - top_k = self.config.top_k |
79 | | - if top_k == -1: |
80 | | - top_k = ds.num_rows |
81 | | - ds = ds.sort("priority", reverse=True).take(top_k) |
| 84 | + ds = self.order_task(ds) |
82 | 85 | # export to the target directory |
83 | 86 | ds.to_json(os.path.join(self.config.output_dir, "output.jsonl")) # type: ignore [arg-type] |
84 | 87 | return {"sample_num": ds.num_rows} |
| 88 | + |
def order_task(self, dataset: Dataset) -> Dataset:
    """Order the dataset with the configured ordering method.

    The method is read from ``self.order_method`` and must be one of:

    * ``"keep"``: return the dataset unchanged.
    * ``"shuffle"``: return ``dataset.shuffle()``.
    * ``"sort"``: sort by the ``"priority"`` field (``reverse=True``) and
      keep the first ``top_k`` rows.
    * ``"folding"``: sort and truncate like ``"sort"``, then interleave the
      result into ``folding_layers`` strided passes so that the sorted
      ordering is repeated once per layer
      (reference: https://arxiv.org/abs/2506.21545).

    ``self.config.top_k`` is only applied by ``"sort"`` and ``"folding"``;
    a value of ``-1`` means "all rows".

    Args:
        dataset: The dataset to order.

    Returns:
        The ordered (and, for ``"sort"``/``"folding"``, possibly truncated)
        dataset.

    Raises:
        ValueError: If the order method is unknown, or if ``folding_layers``
            is smaller than 1 when folding.
    """
    # Resolve the effective method in a local variable. Falling back to
    # "keep" for one dataset without a "priority" field must not overwrite
    # the configured method for later calls on datasets that do have it.
    order_method = self.order_method
    if order_method in {"sort", "folding"} and "priority" not in dataset.features:
        self.logger.warning(
            f'"priority" field not found for {self.order_method}. Use "keep" instead.'
        )
        order_method = "keep"

    if order_method == "keep":
        # keep the original order
        return dataset
    if order_method == "shuffle":
        # shuffle the dataset
        return dataset.shuffle()
    if order_method not in {"sort", "folding"}:
        raise ValueError(f"Invalid order method: {self.order_method}")

    # get top-k: -1 means "everything"; never ask for more rows than exist
    top_k = self.config.top_k
    if top_k == -1:
        top_k = dataset.num_rows
    top_k = min(top_k, dataset.num_rows)

    # sort the dataset according to priority
    sorted_dataset = dataset.sort("priority", reverse=True).take(top_k)
    if order_method == "sort":
        return sorted_dataset

    # folding the dataset to repeat the curriculum learning
    # Reference: https://arxiv.org/abs/2506.21545
    folding_layers = self.order_args.get("folding_layers", 3)
    if folding_layers < 1:
        raise ValueError(f"folding_layers must be >= 1, got {folding_layers}")

    # Indices must be generated over the *truncated* sorted dataset, not the
    # original one; otherwise any top_k < dataset.num_rows yields
    # out-of-range indices in select().
    num_sorted = sorted_dataset.num_rows
    folding_indices = []
    for layer in range(folding_layers):
        # layer j visits positions j, j + L, j + 2L, ... of the sorted rows
        folding_indices.extend(range(layer, num_sorted, folding_layers))
    return sorted_dataset.select(folding_indices)
0 commit comments