Commit d5e46f3

* after pre-commit
1 parent d16f0a8 commit d5e46f3

9 files changed: +95 additions, -47 deletions


examples/grpo_gsm8k_experience_pipeline/dj_scoring_exp.yaml

Lines changed: 1 addition & 1 deletion
@@ -8,4 +8,4 @@ process:
     api_or_hf_model: "qwen2.5-32b-instruct"  # use "qwen2.5-32b-instruct" to calculate the quality scores.
     min_score: 0.0
     input_keys: ["prompt_text", "prompt_text"]  # set input_keys and field_names to the existing key names in gsm-8k. Here calculating the difficulty scores according to both questions and answers.
-    field_names: ["prompt", "response"]
+    field_names: ["prompt", "response"]

(The removed and added lines appear identical in this view; the difference is presumably invisible whitespace, such as a trailing space or end-of-file newline, normalized by pre-commit.)

tests/common/experience_test.py

Lines changed: 1 addition & 1 deletion
@@ -5,8 +5,8 @@

 import torch

-from trinity.common.experience import Experience, Experiences
 from trinity.buffer.schema.sql_schema import ExperienceModel
+from trinity.common.experience import Experience, Experiences

 db_url = os.path.join(os.path.dirname(__file__), "tmp", "test.db")
 dataset_path = os.path.join(os.path.dirname(__file__), "data")

trinity/cli/launcher.py

Lines changed: 24 additions & 14 deletions
@@ -161,16 +161,21 @@ def activate_data_processor(data_processor_url: str, config_path: str):
         logger.error(f"Failed to activate data module: {res['return_msg']}.")
         return

+
 def stop_data_processor(base_data_processor_url: str):
     """Stop all pipelines in the data processor"""
     from trinity.cli.client import request
+
     logger.info(f"Stopping all pipelines in {base_data_processor_url}...")
-    res = request(url=f'{base_data_processor_url}/stop_all')
+    res = request(url=f"{base_data_processor_url}/stop_all")
     if res["return_code"] != 0:
         logger.error(f"Failed to stop all data pipelines: {res['return_msg']}.")
     return

-def validate_data_pipeline(data_pipeline_config: DataPipelineConfig, pipeline_type: DataProcessorPipelineType):
+
+def validate_data_pipeline(
+    data_pipeline_config: DataPipelineConfig, pipeline_type: DataProcessorPipelineType
+):
     """
     Check if the data pipeline is valid. The config should:
     1. Non-empty input buffer

@@ -205,9 +210,7 @@ def validate_data_pipeline(data_pipeline_config: DataPipelineConfig, pipeline_type: DataProcessorPipelineType):
         # No special items need to be checked.
         pass
     else:
-        logger.warning(
-            f'Invalid pipeline type: {pipeline_type}..'
-        )
+        logger.warning(f"Invalid pipeline type: {pipeline_type}..")
         return False
     return True

@@ -220,21 +223,27 @@ def run(config_path: str, dlc: bool = False, plugin_dir: str = None):
     # try to activate task pipeline for raw data
     data_processor_config = config.data_processor
     if (
-        data_processor_config.data_processor_url
-        and data_processor_config.task_pipeline
-        and validate_data_pipeline(data_processor_config.task_pipeline, DataProcessorPipelineType.TASK)
+        data_processor_config.data_processor_url is not None
+        and data_processor_config.task_pipeline is not None
+        and validate_data_pipeline(
+            data_processor_config.task_pipeline, DataProcessorPipelineType.TASK
+        )
     ):
         activate_data_processor(
-            f"{data_processor_config.data_processor_url}/{DataProcessorPipelineType.TASK.value}", config_path
+            f"{data_processor_config.data_processor_url}/{DataProcessorPipelineType.TASK.value}",
+            config_path,
         )
     # try to activate experience pipeline for experiences
     if (
-        data_processor_config.data_processor_url
-        and data_processor_config.experience_pipeline
-        and validate_data_pipeline(data_processor_config.experience_pipeline, DataProcessorPipelineType.EXPERIENCE)
+        data_processor_config.data_processor_url is not None
+        and data_processor_config.experience_pipeline is not None
+        and validate_data_pipeline(
+            data_processor_config.experience_pipeline, DataProcessorPipelineType.EXPERIENCE
+        )
     ):
         activate_data_processor(
-            f"{data_processor_config.data_processor_url}/{DataProcessorPipelineType.EXPERIENCE.value}", config_path
+            f"{data_processor_config.data_processor_url}/{DataProcessorPipelineType.EXPERIENCE.value}",
+            config_path,
        )
     if dlc:
         from trinity.utils.dlc_utils import setup_ray_cluster

@@ -268,7 +277,8 @@ def run(config_path: str, dlc: bool = False, plugin_dir: str = None):
         stop_ray_cluster(namespace=config.ray_namespace)

     # stop all pipelines
-    stop_data_processor(data_processor_config.data_processor_url)
+    if data_processor_config.data_processor_url is not None:
+        stop_data_processor(data_processor_config.data_processor_url)


 def studio(port: int = 8501):
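A note on the condition rewrites above: a truthiness check and an explicit None check are not equivalent, since an empty string or empty config object is falsy but not None. A minimal standalone sketch of the difference, using a hypothetical url variable rather than trinity code:

    url = ""  # hypothetical value: configured, but empty

    if url:  # truthiness: also rejects "", 0, [], {}
        print("truthy check passed")  # skipped here

    if url is not None:  # rejects only the literal None
        print("explicit None check passed")  # printed here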

trinity/common/config.py

Lines changed: 25 additions & 9 deletions
@@ -9,12 +9,12 @@
 from trinity.common.constants import (
     EXPLORER_NAME,
     TRAINER_NAME,
+    OpType,
     PromptType,
     ReadStrategy,
     StorageType,
     SyncMethod,
     TaskType,
-    OpType
 )
 from trinity.utils.log import get_logger

@@ -103,6 +103,7 @@ class StorageConfig:
     # ! DO NOT SET, automatically set corresponding to train/eval
     task_type: TaskType = TaskType.EXPLORE

+
 @dataclass
 class RewardShapingConfig:
     """Config for reward shaping."""

@@ -111,6 +112,7 @@ class RewardShapingConfig:
     op_type: OpType = OpType.ADD
     weight: float = 1.0

+
 @dataclass
 class DataPipelineConfig:
     """Config for data pipeline."""

@@ -513,23 +515,35 @@ def _check_buffer(self) -> None:  # noqa: C901
         output_buffers = {}
         # - taskset
         if self.buffer.explorer_input.taskset.name:
-            input_buffers[self.buffer.explorer_input.taskset.name] = self.buffer.explorer_input.taskset
+            input_buffers[
+                self.buffer.explorer_input.taskset.name
+            ] = self.buffer.explorer_input.taskset
         # - explorer output
         if self.buffer.explorer_output and self.buffer.explorer_output.name:
             output_buffers[self.buffer.explorer_output.name] = self.buffer.explorer_output
         # - trainer input: experience buffer
-        if self.buffer.trainer_input.experience_buffer and self.buffer.trainer_input.experience_buffer.name:
-            input_buffers[self.buffer.trainer_input.experience_buffer.name] = self.buffer.trainer_input.experience_buffer
+        if (
+            self.buffer.trainer_input.experience_buffer
+            and self.buffer.trainer_input.experience_buffer.name
+        ):
+            input_buffers[
+                self.buffer.trainer_input.experience_buffer.name
+            ] = self.buffer.trainer_input.experience_buffer
         # - trainer input: sft warmup dataset
-        if self.buffer.trainer_input.sft_warmup_dataset and self.buffer.trainer_input.sft_warmup_dataset.name:
-            input_buffers[self.buffer.trainer_input.sft_warmup_dataset.name] = self.buffer.trainer_input.sft_warmup_dataset
+        if (
+            self.buffer.trainer_input.sft_warmup_dataset
+            and self.buffer.trainer_input.sft_warmup_dataset.name
+        ):
+            input_buffers[
+                self.buffer.trainer_input.sft_warmup_dataset.name
+            ] = self.buffer.trainer_input.sft_warmup_dataset

         # when experience pipeline is on, the explorer output and the
         # experience buffer of trainer input should be different
         if self.buffer.explorer_output == self.buffer.trainer_input.experience_buffer:
             raise ValueError(
-                f"The explorer output buffer should be different from the experience buffer of the trainer input "
-                f"when experience pipeline is provided."
+                "The explorer output buffer should be different from the experience buffer of the trainer input "
+                "when experience pipeline is provided."
             )

         # NOTICE: For now, input/output buffers for data processors should come from output/input buffers of trinity

@@ -552,7 +566,9 @@ def _check_buffer(self) -> None:  # noqa: C901
                     f"input buffers of trinity."
                 )
             else:
-                self.data_processor.experience_pipeline.output_buffer = input_buffers[exp_pipeline_output_buffers.name]
+                self.data_processor.experience_pipeline.output_buffer = input_buffers[
+                    exp_pipeline_output_buffers.name
+                ]

         # set read_batch_size / pad_token_id / tokenizer_path
         self.buffer.read_batch_size = self.buffer.batch_size * self.algorithm.repeat_times
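The _check_buffer hunk above registers trinity's buffers in name-keyed dicts and then resolves the experience pipeline's output buffer against the input-buffer registry. A minimal sketch of that pattern, with an illustrative Buffer stand-in rather than the real config classes:

    from dataclasses import dataclass
    from typing import Dict, Optional

    @dataclass
    class Buffer:
        name: str  # illustrative stand-in for trinity's buffer configs

    def resolve_output_buffer(
        taskset: Buffer,
        experience_buffer: Optional[Buffer],
        pipeline_output: Buffer,
    ) -> Buffer:
        # Register known input buffers by name, as _check_buffer does.
        input_buffers: Dict[str, Buffer] = {}
        if taskset.name:
            input_buffers[taskset.name] = taskset
        if experience_buffer and experience_buffer.name:
            input_buffers[experience_buffer.name] = experience_buffer
        # A pipeline's output buffer must be one of trinity's input buffers.
        if pipeline_output.name not in input_buffers:
            raise ValueError(f"unknown output buffer: {pipeline_output.name}")
        return input_buffers[pipeline_output.name]

    print(resolve_output_buffer(Buffer("tasks"), Buffer("exp"), Buffer("exp")))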

trinity/common/constants.py

Lines changed: 2 additions & 0 deletions
@@ -104,12 +104,14 @@ class RunningStatus(Enum):
     WAITING_SYNC = "waiting_sync"
     STOPPED = "stopped"

+
 class DataProcessorPipelineType(Enum):
     """Data processor pipeline type."""

     EXPERIENCE = "experience_pipeline"
     TASK = "task_pipeline"

+
 class OpType(Enum):
     """Operator type for reward shaping."""
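The blank-line fixes here touch DataProcessorPipelineType, whose string values tie the launcher to the data processor server: launcher.py appends the enum's .value to the data processor URL, and server.py matches it with its /<pipeline_type> Flask route. A small sketch of that round trip; the base URL is a hypothetical address:

    from enum import Enum

    class DataProcessorPipelineType(Enum):
        """Mirrors trinity/common/constants.py."""
        EXPERIENCE = "experience_pipeline"
        TASK = "task_pipeline"

    base_url = "http://localhost:5000/data_processor"  # hypothetical host/port
    print(f"{base_url}/{DataProcessorPipelineType.TASK.value}")
    # -> http://localhost:5000/data_processor/task_pipeline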

trinity/data/controllers/active_iterator.py

Lines changed: 30 additions & 16 deletions
@@ -1,12 +1,12 @@
 import os
-import traceback
 import threading
+import traceback
+from functools import partial
 from numbers import Number
 from typing import Any, Dict, List, Union
-from functools import partial
-from data_juicer.utils.constant import Fields

 import ray
+from data_juicer.utils.constant import Fields

 from trinity.common.config import BufferConfig, DataPipelineConfig, RewardShapingConfig
 from trinity.common.constants import DataProcessorPipelineType, OpType

@@ -102,7 +102,7 @@ def __init__(
     def run(self, thread_event: threading.Event = None):
         """Run the active iterator."""
         # step 1. parse the dj config
-        logger.info('Parsing the Data-Juicer config...')
+        logger.info("Parsing the Data-Juicer config...")
         try:
             (
                 dj_config,

@@ -115,15 +115,15 @@ def run(self, thread_event: threading.Event = None):
             return 1, "config parsing failed."

         # step 2. prepare rft-dataset from the input buffers
-        logger.info('Preparing Rft-Dataset from input buffers...')
+        logger.info("Preparing Rft-Dataset from input buffers...")
         try:
             dataset = RftDataset(self.config, self.buffer_config)
         except Exception:
             traceback.print_exc()
             return 2, "RftDataset loading failed."

         # step 3. load processor
-        logger.info('Loading data processors...')
+        logger.info("Loading data processors...")
         try:
             if hit_cleaner:
                 cleaner = DataCleaner(

@@ -151,7 +151,7 @@ def run(self, thread_event: threading.Event = None):
                 break

             # step 4. load data from the input buffers for the next batch
-            logger.info('Loading data from input buffers for the next batch...')
+            logger.info("Loading data from input buffers for the next batch...")
             try:
                 dataset.read_from_buffer()
             except StopIteration:

@@ -161,7 +161,7 @@ def run(self, thread_event: threading.Event = None):
                 return 4, "RftDataset loading from buffers failed."

             # step 5. apply processors to calculate scores of different dimensions
-            logger.info('Applying data processors to calculate stats...')
+            logger.info("Applying data processors to calculate stats...")
             try:
                 res_dataset = dataset
                 if hit_cleaner:

@@ -177,7 +177,7 @@ def run(self, thread_event: threading.Event = None):
             # step 6. calculate the average and final scores, including priority
             try:
                 if hit_cleaner:
-                    logger.info('Calculating the average and final scores...')
+                    logger.info("Calculating the average and final scores...")
                     scored_dataset = self._group_scores(res_dataset)
                     scored_dataset = self._compute_priority_scores(scored_dataset)
                 else:

@@ -188,7 +188,11 @@ def run(self, thread_event: threading.Event = None):

             # step 7. reward shaping. Only available for experience pipeline and the reward shaping config is set
             try:
-                if self.pipeline_type == DataProcessorPipelineType.EXPERIENCE and len(self.config.reward_shaping) > 0:
+                if (
+                    self.pipeline_type == DataProcessorPipelineType.EXPERIENCE
+                    and self.config.reward_shaping is not None
+                    and len(self.config.reward_shaping) > 0
+                ):
                     logger.info("Rewarding shaping...")
                     reshaped_dataset = self._reward_shaping(scored_dataset)
                 else:

@@ -215,7 +219,7 @@ def run(self, thread_event: threading.Event = None):

             # step 10. export the result to the output buffer
             try:
-                logger.info('Writing processed data to output buffer...')
+                logger.info("Writing processed data to output buffer...")
                 res_dataset.write_to_buffer()
             except Exception:
                 traceback.print_exc()

@@ -325,13 +329,21 @@ def _reward_shaping_single(self, sample, reward_shaping_config: RewardShapingConfig):
         if tgt_stats not in sample[Fields.stats]:
             return sample
         if op_type == OpType.ADD:
-            sample[self.config.format.reward_key] += reward_shaping_config.weight * sample[Fields.stats][tgt_stats]
+            sample[self.config.format.reward_key] += (
+                reward_shaping_config.weight * sample[Fields.stats][tgt_stats]
+            )
         elif op_type == OpType.MUL:
-            sample[self.config.format.reward_key] *= reward_shaping_config.weight * sample[Fields.stats][tgt_stats]
+            sample[self.config.format.reward_key] *= (
+                reward_shaping_config.weight * sample[Fields.stats][tgt_stats]
+            )
         elif op_type == OpType.SUB:
-            sample[self.config.format.reward_key] -= reward_shaping_config.weight * sample[Fields.stats][tgt_stats]
+            sample[self.config.format.reward_key] -= (
+                reward_shaping_config.weight * sample[Fields.stats][tgt_stats]
+            )
         elif op_type == OpType.DIV:
-            sample[self.config.format.reward_key] /= reward_shaping_config.weight * sample[Fields.stats][tgt_stats]
+            sample[self.config.format.reward_key] /= (
+                reward_shaping_config.weight * sample[Fields.stats][tgt_stats]
+            )
         return sample

     def _reward_shaping(self, rft_dataset: RftDataset) -> RftDataset:

@@ -342,7 +354,9 @@ def _reward_shaping(self, rft_dataset: RftDataset) -> RftDataset:
         # get reward shaping configs
         reward_shaping_configs = self.config.reward_shaping
         for reward_shaping_config in reward_shaping_configs:
-            dataset = dataset.map(partial(self._reward_shaping_single, reward_shaping_config=reward_shaping_config))
+            dataset = dataset.map(
+                partial(self._reward_shaping_single, reward_shaping_config=reward_shaping_config)
+            )

         rft_dataset.data = dataset
         return rft_dataset
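The parenthesized arithmetic in _reward_shaping_single is pure reformatting; the operator dispatch it wraps is easier to see as a standalone function. A minimal sketch, with member names taken from the usage above (the concrete enum values are not shown in this diff, so auto() stands in):

    from enum import Enum, auto

    class OpType(Enum):
        ADD = auto()  # member names follow active_iterator.py; values assumed
        MUL = auto()
        SUB = auto()
        DIV = auto()

    def shape_reward(reward: float, stat: float, weight: float, op: OpType) -> float:
        # Mirrors _reward_shaping_single: combine the reward with weight * stat.
        term = weight * stat
        if op == OpType.ADD:
            return reward + term
        if op == OpType.MUL:
            return reward * term
        if op == OpType.SUB:
            return reward - term
        if op == OpType.DIV:
            return reward / term  # no zero guard, matching the original
        return reward

    print(shape_reward(1.0, 0.5, 2.0, OpType.ADD))  # 2.0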

trinity/data/core/dataset.py

Lines changed: 4 additions & 2 deletions
@@ -1,22 +1,24 @@
 from abc import ABC
-from dataclasses import dataclass, fields, asdict
+from dataclasses import asdict, dataclass, fields
 from typing import Any, Dict, List, Optional, Union

 import networkx as nx
 from datasets import Dataset, concatenate_datasets

 from trinity.buffer import get_buffer_reader, get_buffer_writer
-from trinity.common.config import BufferConfig, DataPipelineConfig, StorageConfig
+from trinity.common.config import BufferConfig, DataPipelineConfig
 from trinity.data.core.formatter import BaseDataFormatter
 from trinity.utils.log import get_logger

 logger = get_logger(__name__)

+
 def dict_to_dataclass(cls, d):
     valid_keys = {f.name for f in fields(cls)}
     filtered = {k: v for k, v in d.items() if k in valid_keys}
     return cls(**filtered)

+
 @dataclass
 class RewardSchema:
     """Schema for reward related fields"""

trinity/data/processors/cleaner.py

Lines changed: 1 addition & 1 deletion
@@ -35,7 +35,7 @@ def __init__(
         dj_cfg: Optional[Namespace],
         clean_strategy: str = "iterative",
         min_size_ratio: PositiveFloat = None,
-        data_dist: str = "gaussian",
+        data_dist: Optional[str] = "gaussian",
         op_weights: dict = None,
         **kwargs,
     ):
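The annotation fix above matters to type checkers: under PEP 484, a plain str annotation does not admit None, so a caller passing data_dist=None would be flagged even if the code tolerates it. A minimal sketch with a hypothetical function:

    from typing import Optional

    def fit(data_dist: Optional[str] = "gaussian") -> str:
        # With a plain "str" annotation, mypy would reject fit(data_dist=None).
        return data_dist if data_dist is not None else "none"

    print(fit())                # gaussian
    print(fit(data_dist=None))  # none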

trinity/data/server.py

Lines changed: 7 additions & 3 deletions
@@ -1,16 +1,18 @@
-import fire
 import threading
+from typing import List
+
+import fire
 import ray
 from flask import Flask, jsonify, request
 from markupsafe import escape
-from typing import List

 app = Flask(__name__)

 APP_NAME = "data_processor"

 EVNET_POOL: List[threading.Event] = []

+
 @app.route(f"/{APP_NAME}/<pipeline_type>", methods=["GET"])
 def data_processor(pipeline_type):
     from trinity.common.config import load_config

@@ -57,13 +59,15 @@ def data_processor(pipeline_type):
     EVNET_POOL.append(event)
     return jsonify({"return_code": 0, "message": "Experience pipeline starts successfully."})

+
 @app.route(f"/{APP_NAME}/stop_all", methods=["GET"])
 def stop_all():
     try:
         for event in EVNET_POOL:
             event.set()
-    except:
+    except Exception:
         import traceback
+
         traceback.print_exc()
         return jsonify({"return_code": 1, "message": traceback.format_exc()})
     return jsonify({"return_code": 0, "message": "All data pipelines are stopped."})
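The except fix above is more than style: a bare except also catches SystemExit and KeyboardInterrupt, which derive directly from BaseException, so a Ctrl-C during shutdown would be swallowed and reported as a pipeline error. A standalone sketch of the narrowed handler:

    import threading
    import traceback
    from typing import List

    def stop_all_events(events: List[threading.Event]) -> int:
        # Mirrors stop_all above; KeyboardInterrupt/SystemExit now propagate.
        try:
            for event in events:
                event.set()
        except Exception:
            traceback.print_exc()
            return 1
        return 0

    print(stop_all_events([threading.Event(), threading.Event()]))  # 0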
