Skip to content

Commit 412497f

Browse files
authored
Support new task ordering methods (agentscope-ai#265)
1 parent 11e5763 commit 412497f

File tree

2 files changed

+48
-5
lines changed

2 files changed

+48
-5
lines changed

trinity/service/data_juicer/server/session.py

Lines changed: 46 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
group_scores,
1212
parse_config,
1313
)
14+
from trinity.utils.log import get_logger
1415

1516

1617
def extract_metrics(dataset: Dataset) -> Dict:
@@ -39,6 +40,12 @@ def __init__(self, config: DJConfig):
3940
"usage_frequency": -0.5,
4041
"quality": 1.0,
4142
}
43+
self.order_method = self.config.order_method
44+
self.order_args = self.config.order_args or {
45+
"folding_layers": 3,
46+
}
47+
48+
self.logger = get_logger(__name__)
4249

4350
def process_experience(self, ds: Dataset) -> Tuple[Dataset, Dict]:
4451
"""Process a batch of experiences.
@@ -74,11 +81,45 @@ def process_task(self) -> Dict:
7481
)
7582
ds = ds.map(compute_priority_scores_func)
7683
# sort the output dataset in priority
77-
if "priority" in ds.features:
78-
top_k = self.config.top_k
79-
if top_k == -1:
80-
top_k = ds.num_rows
81-
ds = ds.sort("priority", reverse=True).take(top_k)
84+
ds = self.order_task(ds)
8285
# export to the target directory
8386
ds.to_json(os.path.join(self.config.output_dir, "output.jsonl")) # type: ignore [arg-type]
8487
return {"sample_num": ds.num_rows}
88+
89+
def order_task(self, dataset: Dataset) -> Dataset:
90+
"""
91+
Order the dataset with specified method.
92+
"""
93+
# check if priority field exists
94+
if "priority" not in dataset.features and self.order_method in {"sort", "folding"}:
95+
self.logger.warning(
96+
f'"priority" field not found for {self.order_method}. Use "keep" instead.'
97+
)
98+
self.order_method = "keep"
99+
100+
# get top-k
101+
top_k = self.config.top_k
102+
if top_k == -1:
103+
top_k = dataset.num_rows
104+
105+
if self.order_method == "keep":
106+
# keep the original order
107+
return dataset
108+
elif self.order_method == "shuffle":
109+
# shuffle the dataset
110+
return dataset.shuffle()
111+
elif self.order_method == "sort":
112+
# sort the dataset acording to priority
113+
return dataset.sort("priority", reverse=True).take(top_k)
114+
elif self.order_method == "folding":
115+
# folding the dataset to repeat the curriculum learning
116+
# Reference: https://arxiv.org/abs/2506.21545
117+
sorted_dataset = dataset.sort("priority", reverse=True).take(top_k)
118+
folding_layers = self.order_args.get("folding_layers", 3)
119+
folding_indices = []
120+
for j in range(folding_layers):
121+
partition = list(range(j, dataset.num_rows, folding_layers))
122+
folding_indices.extend(partition)
123+
return sorted_dataset.select(folding_indices)
124+
else:
125+
raise ValueError(f"Invalid order method: {self.order_method}")

trinity/service/data_juicer/server/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ class DJConfig(BaseModel):
2424
target_fields: List[str] = [] # fields in the output dataset
2525
priority_weights: Dict[str, float] = {} # weights for priority computing
2626
top_k: int = -1 # number of samples to select after task pipeline. -1 means all
27+
order_method: Literal["keep", "shuffle", "sort", "folding"] = "sort"
28+
order_args: Dict = {}
2729

2830
@model_validator(mode="after")
2931
def check_dj_config(self):

0 commit comments

Comments
 (0)