Skip to content

Commit a1ce540

Browse files
committed
* after pre-commit
1 parent e0874db commit a1ce540

File tree

9 files changed

+119
-97
lines changed

9 files changed

+119
-97
lines changed

tests/data/core/dataset_test.py

Lines changed: 30 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@
44
import unittest
55

66
from trinity.common.config import DataPipelineConfig, FormatConfig, StorageConfig
7-
from trinity.common.rewards import AccuracyReward
8-
from trinity.common.workflows import MathWorkflow, SimpleWorkflow
97
from trinity.data.core.dataset import RewardSchema, RftDataset
108
from trinity.data.core.formatter import BoxedMathAnswerFormatter, RLHFFormatter
119

@@ -15,33 +13,37 @@ class TestRftDataset(unittest.TestCase):
1513

1614
def setUp(self) -> None:
1715
self.data_pipeline_config = DataPipelineConfig(
18-
input_buffers=[StorageConfig(
19-
path=os.path.join(
20-
os.path.dirname(os.path.realpath(__file__)),
21-
"..",
22-
"..",
23-
"test_data",
24-
"test_10",
25-
),
26-
raw=True,
27-
)],
16+
input_buffers=[
17+
StorageConfig(
18+
path=os.path.join(
19+
os.path.dirname(os.path.realpath(__file__)),
20+
"..",
21+
"..",
22+
"test_data",
23+
"test_10",
24+
),
25+
raw=True,
26+
)
27+
],
2828
format=FormatConfig(
2929
prompt_key="problem",
3030
response_key="solution",
3131
solution_key="solution",
3232
),
3333
)
3434
self.data_pipeline_config_sample_level_setting = DataPipelineConfig(
35-
input_buffers=[StorageConfig(
36-
path=os.path.join(
37-
os.path.dirname(os.path.realpath(__file__)),
38-
"..",
39-
"..",
40-
"test_data",
41-
"test_10_with_rewfn_workflow",
42-
),
43-
raw=True,
44-
)],
35+
input_buffers=[
36+
StorageConfig(
37+
path=os.path.join(
38+
os.path.dirname(os.path.realpath(__file__)),
39+
"..",
40+
"..",
41+
"test_data",
42+
"test_10_with_rewfn_workflow",
43+
),
44+
raw=True,
45+
)
46+
],
4547
format=FormatConfig(
4648
prompt_key="problem",
4749
response_key="solution",
@@ -52,13 +54,17 @@ def setUp(self) -> None:
5254
)
5355

5456
def test_rft_dataset_init(self):
55-
dataset = RftDataset(data_pipeline_config=self.data_pipeline_config, reward_schema="default")
57+
dataset = RftDataset(
58+
data_pipeline_config=self.data_pipeline_config, reward_schema="default"
59+
)
5660

5761
self.assertEqual(len(dataset), 10)
5862
self.assertIsInstance(dataset.reward_schema, RewardSchema)
5963

6064
def test_format_dataset(self):
61-
dataset = RftDataset(data_pipeline_config=self.data_pipeline_config, reward_schema="default")
65+
dataset = RftDataset(
66+
data_pipeline_config=self.data_pipeline_config, reward_schema="default"
67+
)
6268
original_data = dataset.data
6369
# no formatter
6470
dataset.format(formatters=[])

tests/data/core/formatter_test.py

Lines changed: 60 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,18 @@ class TestBoxedMathDataset(unittest.TestCase):
1919

2020
def setUp(self) -> None:
2121
self.data_config = DataPipelineConfig(
22-
input_buffers=[StorageConfig(
23-
path=os.path.join(
24-
os.path.dirname(os.path.realpath(__file__)),
25-
"..",
26-
"..",
27-
"test_data",
28-
"test_10",
29-
),
30-
raw=True,
31-
)],
22+
input_buffers=[
23+
StorageConfig(
24+
path=os.path.join(
25+
os.path.dirname(os.path.realpath(__file__)),
26+
"..",
27+
"..",
28+
"test_data",
29+
"test_10",
30+
),
31+
raw=True,
32+
)
33+
],
3234
format=FormatConfig(
3335
prompt_key="problem",
3436
response_key="answer",
@@ -63,16 +65,18 @@ class TestRLHFFormatter(unittest.TestCase):
6365

6466
def setUp(self) -> None:
6567
self.data_config = DataPipelineConfig(
66-
input_buffers=[StorageConfig(
67-
path=os.path.join(
68-
os.path.dirname(os.path.realpath(__file__)),
69-
"..",
70-
"..",
71-
"test_data",
72-
"test_10",
73-
),
74-
raw=True,
75-
)],
68+
input_buffers=[
69+
StorageConfig(
70+
path=os.path.join(
71+
os.path.dirname(os.path.realpath(__file__)),
72+
"..",
73+
"..",
74+
"test_data",
75+
"test_10",
76+
),
77+
raw=True,
78+
)
79+
],
7680
format=FormatConfig(
7781
prompt_key="problem",
7882
chat_template="User: {}\nAssistant: ",
@@ -114,16 +118,18 @@ class TestRewardFormatter(unittest.TestCase):
114118

115119
def setUp(self) -> None:
116120
self.data_config = DataPipelineConfig(
117-
input_buffers=[StorageConfig(
118-
path=os.path.join(
119-
os.path.dirname(os.path.realpath(__file__)),
120-
"..",
121-
"..",
122-
"test_data",
123-
"test_10",
124-
),
125-
raw=True,
126-
)],
121+
input_buffers=[
122+
StorageConfig(
123+
path=os.path.join(
124+
os.path.dirname(os.path.realpath(__file__)),
125+
"..",
126+
"..",
127+
"test_data",
128+
"test_10",
129+
),
130+
raw=True,
131+
)
132+
],
127133
format=FormatConfig(
128134
prompt_key="problem",
129135
chosen_key="chosen",
@@ -174,16 +180,18 @@ class TestSFTFormatter(unittest.TestCase):
174180

175181
def setUp(self) -> None:
176182
self.data_config = DataPipelineConfig(
177-
input_buffers=[StorageConfig(
178-
path=os.path.join(
179-
os.path.dirname(os.path.realpath(__file__)),
180-
"..",
181-
"..",
182-
"test_data",
183-
"test_10",
184-
),
185-
raw=True,
186-
)],
183+
input_buffers=[
184+
StorageConfig(
185+
path=os.path.join(
186+
os.path.dirname(os.path.realpath(__file__)),
187+
"..",
188+
"..",
189+
"test_data",
190+
"test_10",
191+
),
192+
raw=True,
193+
)
194+
],
187195
format=FormatConfig(
188196
prompt_key="problem",
189197
response_key="answer",
@@ -230,16 +238,18 @@ class TestComposedFormatter(unittest.TestCase):
230238

231239
def setUp(self) -> None:
232240
self.data_config = DataPipelineConfig(
233-
input_buffers=[StorageConfig(
234-
path=os.path.join(
235-
os.path.dirname(os.path.realpath(__file__)),
236-
"..",
237-
"..",
238-
"test_data",
239-
"test_10",
240-
),
241-
raw=True,
242-
)],
241+
input_buffers=[
242+
StorageConfig(
243+
path=os.path.join(
244+
os.path.dirname(os.path.realpath(__file__)),
245+
"..",
246+
"..",
247+
"test_data",
248+
"test_10",
249+
),
250+
raw=True,
251+
)
252+
],
243253
format=FormatConfig(
244254
prompt_key="problem",
245255
response_key="answer",

trinity/buffer/reader/file_reader.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -231,13 +231,11 @@ def read(self, strategy: Optional[ReadStrategy] = None):
231231
self.index = 0
232232
return task
233233

234+
234235
@FILE_READERS.register_module("raw")
235236
class RawDataReader(BufferReader):
236-
237237
def __init__(self, meta: StorageConfig, config: BufferConfig):
238-
self.dataset = load_dataset(
239-
meta.path, name=meta.subset_name, split=meta.split
240-
)
238+
self.dataset = load_dataset(meta.path, name=meta.subset_name, split=meta.split)
241239

242240
def __len__(self):
243241
return len(self.dataset)

trinity/buffer/writer/file_writer.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Writer of the File buffer."""
22
import os
33
from typing import List
4+
45
import jsonlines as jl
56

67
from trinity.buffer.buffer_writer import BufferWriter
@@ -16,13 +17,15 @@ class RawFileWriter(BufferWriter):
1617

1718
def __init__(self, meta: StorageConfig, config: BufferConfig):
1819
assert meta.storage_type == StorageType.FILE
19-
ext = os.path.splitext(meta.path)
20-
if ext != '.jsonl':
20+
if meta.path is None:
21+
raise ValueError("File path cannot be None for RawFileWriter")
22+
ext = os.path.splitext(meta.path)[-1]
23+
if ext != ".jsonl":
2124
raise ValueError(f"File path must end with .json or .jsonl, got {meta.path}")
22-
self.writer = jl.open(meta.path, mode='w')
25+
self.writer = jl.open(meta.path, mode="w")
2326

2427
def write(self, data: List) -> None:
2528
self.writer.write_all(data)
2629

2730
def finish(self):
28-
self.writer.close()
31+
self.writer.close()

trinity/cli/launcher.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,9 @@ def run(config_path: str, dlc: bool = False, plugin_dir: str = None):
166166
# try to activate task pipeline for raw data
167167
data_processor_config = config.data_processor
168168
if data_processor_config.data_workflow_url and data_processor_config.task_pipeline:
169-
activate_data_module(f'{data_processor_config.data_workflow_url}/task_pipeline', config_path)
169+
activate_data_module(
170+
f"{data_processor_config.data_workflow_url}/task_pipeline", config_path
171+
)
170172
ray_namespace = f"{config.project}-{config.name}"
171173
if dlc:
172174
from trinity.utils.dlc_utils import setup_ray_cluster

trinity/common/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ class StorageConfig:
9898
# ! DO NOT SET, automatically set corresponding to train/eval
9999
task_type: TaskType = TaskType.EXPLORE
100100

101+
101102
@dataclass
102103
class DataPipelineConfig:
103104
"""Config for data pipeline."""
@@ -122,6 +123,7 @@ class DataPipelineConfig:
122123
priority_weights: Optional[Dict[str, float]] = None
123124
data_dist: Optional[str] = "gaussian" # one of ["gaussian", "uniform"]
124125

126+
125127
@dataclass
126128
class DataProcessorConfig:
127129
"""Data-Juicer config"""

trinity/data/controllers/active_iterator.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import ray
66

7-
from trinity.common.config import DataPipelineConfig, BufferConfig
7+
from trinity.common.config import BufferConfig, DataPipelineConfig
88
from trinity.data.controllers.default_ops import DIMENSION_STATS_KEYS
99
from trinity.data.controllers.task_parser import DataTaskParser
1010
from trinity.data.core.dataset import RftDataset
@@ -25,10 +25,7 @@ def __init__(
2525
):
2626
self.config = config
2727
self.buffer_config = buffer_config
28-
if (
29-
self.config.agent_model_name is not None
30-
and self.config.agent_model_config is not None
31-
):
28+
if self.config.agent_model_name is not None and self.config.agent_model_config is not None:
3229
# get the api key
3330
api_key = os.environ.get("OPENAI_API_KEY")
3431
# initialize the agent

trinity/data/core/dataset.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
import networkx as nx
66
from datasets import Dataset, concatenate_datasets
77

8-
from trinity.common.config import DataPipelineConfig, BufferConfig, StorageConfig
98
from trinity.buffer import get_buffer_reader, get_buffer_writer
9+
from trinity.common.config import BufferConfig, DataPipelineConfig, StorageConfig
1010
from trinity.data.core.formatter import BaseDataFormatter
1111

1212

@@ -45,11 +45,11 @@ def __init__(
4545
input_buffer_configs = self.config.input_buffers
4646
if len(input_buffer_configs) == 0:
4747
raise ValueError("input_buffers is empty in data pipeline config")
48-
self.data = []
48+
datasets = []
4949
for input_buffer_config in input_buffer_configs:
5050
input_buffer = get_buffer_reader(input_buffer_config, self.buffer_config)
51-
self.data.append(Dataset.from_list(input_buffer.read()))
52-
self.data = concatenate_datasets(self.data)
51+
datasets.append(Dataset.from_list(input_buffer.read()))
52+
self.data = concatenate_datasets(datasets)
5353

5454
self.reward_schema = self._init_reward_schema(reward_schema)
5555
self.stats: Dict[str, Any] = {}
@@ -65,7 +65,9 @@ def format(
6565
for formatter in formatters:
6666
self.data = formatter(self.data, num_proc)
6767

68-
def write_to_buffer(self, output_storage_config: StorageConfig = None, buffer_config: BufferConfig = None):
68+
def write_to_buffer(
69+
self, output_storage_config: StorageConfig = None, buffer_config: BufferConfig = None
70+
):
6971
if output_storage_config is None:
7072
output_storage_config = self.config.output_buffer
7173
if buffer_config is None:

trinity/data/server.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,12 @@ def data_workflow(pipeline_type):
1818

1919
pipeline_config = getattr(config, pipeline_type)
2020
if pipeline_config is None:
21-
return jsonify({
22-
"return_code": -1,
23-
"message": f"{pipeline_type} is not supported or the corresponding config is empty"
24-
})
21+
return jsonify(
22+
{
23+
"return_code": -1,
24+
"message": f"{pipeline_type} is not supported or the corresponding config is empty",
25+
}
26+
)
2527

2628
iterator = DataActiveIterator(pipeline_config, config.buffer)
2729
ret, msg = iterator.run()

0 commit comments

Comments
 (0)