
Commit a63b5f4

modified according to the latest buffer implementation.

1 parent: a1ce540

File tree: 8 files changed, +66, -47 lines


tests/data/controllers/task_parser_test.py

Lines changed: 0 additions & 5 deletions
@@ -43,11 +43,6 @@ def _run_test(self, rft_config, return_none=False):
             logger.info("None dj config.")
         else:
             self.assertIsNotNone(dj_config)
-            op_weights = {}
-            for op in dj_config.process:
-                op_name = list(op.keys())[0]
-                op_weights[op_name] = op[op_name]["op_weight"]
-            logger.info(op_weights)
 
     def test_instruction1(self):
         rft_config = DataPipelineConfig()

tests/data/core/dataset_test.py

Lines changed: 2 additions & 0 deletions
@@ -57,6 +57,7 @@ def test_rft_dataset_init(self):
         dataset = RftDataset(
             data_pipeline_config=self.data_pipeline_config, reward_schema="default"
         )
+        dataset.read_from_buffer()
 
         self.assertEqual(len(dataset), 10)
         self.assertIsInstance(dataset.reward_schema, RewardSchema)
@@ -65,6 +66,7 @@ def test_format_dataset(self):
         dataset = RftDataset(
             data_pipeline_config=self.data_pipeline_config, reward_schema="default"
         )
+        dataset.read_from_buffer()
         original_data = dataset.data
         # no formatter
         dataset.format(formatters=[])
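Note: with the new buffer implementation, constructing an RftDataset no longer reads any data, so tests must call read_from_buffer() before touching dataset.data or len(dataset); the same one-line fix recurs in formatter_test.py and cleaner_test.py below. A minimal sketch of the pattern, with `pipeline_config` standing in for the test fixture's config:

    # __init__ only registers buffer readers; no I/O happens yet
    dataset = RftDataset(data_pipeline_config=pipeline_config, reward_schema="default")
    dataset.read_from_buffer()   # explicitly pull the first batch from the input buffers
    assert len(dataset) > 0      # dataset.data is populated only after the read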

tests/data/core/formatter_test.py

Lines changed: 1 addition & 0 deletions
@@ -54,6 +54,7 @@ def test_init(self):
 
     def test_transform(self):
         dataset = RftDataset(data_pipeline_config=self.data_config, reward_schema="default")
+        dataset.read_from_buffer()
         formatter = BoxedMathAnswerFormatter(config=self.data_config.format)
         self.assertNotIn(formatter.config.response_key, dataset.data.column_names)
         dataset.format(formatter)

tests/data/processor/cleaner_test.py

Lines changed: 1 addition & 0 deletions
@@ -41,6 +41,7 @@ def _run_test(self, tgt_list, weight=1, data_dist="gaussian"):
         )
 
         dataset = RftDataset(self.rft_config.data_processor.task_pipeline)
+        dataset.read_from_buffer()
         dataset = cleaner.process([dataset])
 
         res_list = dataset.data.select_columns("text").to_list()

trinity/buffer/reader/file_reader.py

Lines changed: 4 additions & 0 deletions
@@ -235,10 +235,14 @@ def read(self, strategy: Optional[ReadStrategy] = None):
 @FILE_READERS.register_module("raw")
 class RawDataReader(BufferReader):
     def __init__(self, meta: StorageConfig, config: BufferConfig):
+        self.returned = False
         self.dataset = load_dataset(meta.path, name=meta.subset_name, split=meta.split)
 
     def __len__(self):
         return len(self.dataset)
 
     def read(self, strategy: Optional[ReadStrategy] = None) -> List:
+        if self.returned:
+            raise StopIteration
+        self.returned = True
         return self.dataset.to_list()
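Note: RawDataReader now acts as a one-shot iterator: the first read() returns the full dataset as a list, and every later call raises StopIteration, which callers use as the loop-termination signal (see active_iterator.py below). A sketch of a consumer under that contract, where `reader` is any reader built by get_buffer_reader and `handle_batch` is a hypothetical downstream step:

    while True:
        try:
            batch = reader.read()   # list of samples; raises StopIteration once drained
        except StopIteration:
            break                   # normal exit: the buffer has been fully consumed
        handle_batch(batch)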

trinity/buffer/writer/file_writer.py

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@ def __init__(self, meta: StorageConfig, config: BufferConfig):
         ext = os.path.splitext(meta.path)[-1]
         if ext != ".jsonl":
             raise ValueError(f"File path must end with .json or .jsonl, got {meta.path}")
-        self.writer = jl.open(meta.path, mode="w")
+        self.writer = jl.open(meta.path, mode="a")
 
     def write(self, data: List) -> None:
         self.writer.write_all(data)
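Note: switching jsonlines from mode="w" to mode="a" matters because write_to_buffer() is now called once per batch inside the processing loop; with "w" each batch would truncate the file and clobber the previous one, while "a" accumulates them. A small standalone sketch using the jsonlines package directly (the file name is illustrative):

    import jsonlines as jl

    for batch in ([{"text": "a"}], [{"text": "b"}]):
        with jl.open("out.jsonl", mode="a") as writer:
            writer.write_all(batch)   # appends; out.jsonl ends up with both batches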

trinity/data/controllers/active_iterator.py

Lines changed: 47 additions & 37 deletions
@@ -118,43 +118,53 @@ def run(self):
             traceback.print_exc()
             return 3, "DataCleaner loading failed."
 
-        # step 4. apply processors to calculate scores of different dimensions
-        try:
-            res_dataset = dataset
-            if hit_cleaner:
-                res_dataset = cleaner.process([res_dataset])
-            if hit_synthesizer:
-                res_dataset = synthesizer.process([res_dataset])
-            if hit_human_annotator:
-                res_dataset = human_annotator.process([res_dataset])
-        except Exception:
-            traceback.print_exc()
-            return 4, "DataProcessors processing failed."
-
-        # step 5. calculate the average and final scores, including priority
-        try:
-            if hit_cleaner:
-                scored_dataset = self._group_scores(res_dataset)
-                scored_dataset = self._compute_priority_scores(scored_dataset)
-            else:
-                scored_dataset = res_dataset
-        except Exception:
-            traceback.print_exc()
-            return 5, "Grouping and computing priority score failed."
-
-        # step 6. track lineage if they are changed
-        try:
-            res_dataset = scored_dataset
-        except Exception:
-            traceback.print_exc()
-            return 6, "Tracking lineage failed."
-
-        # step 7. export the result to the output buffer
-        try:
-            res_dataset.write_to_buffer()
-        except Exception:
-            traceback.print_exc()
-            return 7, "Exporting result to output buffer failed."
+        while True:
+            # step 4. load data from the input buffers for the next batch
+            try:
+                dataset.read_from_buffer()
+            except StopIteration:
+                break
+            except Exception:
+                traceback.print_exc()
+                return 4, "RftDataset loading from buffers failed."
+
+            # step 5. apply processors to calculate scores of different dimensions
+            try:
+                res_dataset = dataset
+                if hit_cleaner:
+                    res_dataset = cleaner.process([res_dataset])
+                if hit_synthesizer:
+                    res_dataset = synthesizer.process([res_dataset])
+                if hit_human_annotator:
+                    res_dataset = human_annotator.process([res_dataset])
+            except Exception:
+                traceback.print_exc()
+                return 5, "DataProcessors processing failed."
+
+            # step 6. calculate the average and final scores, including priority
+            try:
+                if hit_cleaner:
+                    scored_dataset = self._group_scores(res_dataset)
+                    scored_dataset = self._compute_priority_scores(scored_dataset)
+                else:
+                    scored_dataset = res_dataset
+            except Exception:
+                traceback.print_exc()
+                return 6, "Grouping and computing priority score failed."
+
+            # step 7. track lineage if they are changed
+            try:
+                res_dataset = scored_dataset
+            except Exception:
+                traceback.print_exc()
+                return 7, "Tracking lineage failed."
+
+            # step 8. export the result to the output buffer
+            try:
+                res_dataset.write_to_buffer()
+            except Exception:
+                traceback.print_exc()
+                return 8, "Exporting result to output buffer failed."
 
         return 0, "success"
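Note: run() is now a drain loop: each pass pulls one batch from the input buffers (step 4), processes and scores it (steps 5-7), and exports it (step 8); StopIteration from the readers is the normal exit, while any other exception aborts with a numbered error code. A condensed, self-contained sketch of that control flow, with process_batch standing in for steps 5-7:

    def process_batch(batch):
        return batch  # placeholder for cleaning, scoring, and prioritizing

    def run_loop(reader, writer):
        # mirrors the refactored run(): drain the reader batch by batch
        while True:
            try:
                batch = reader.read()           # raises StopIteration when exhausted
            except StopIteration:
                break                           # buffers drained: normal termination
            except Exception:
                return 4, "loading failed"      # numbered error codes, as in run()
            writer.write(process_batch(batch))  # steps 5-8 collapsed into one call
        return 0, "success"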

trinity/data/core/dataset.py

Lines changed: 10 additions & 4 deletions
@@ -45,11 +45,10 @@ def __init__(
         input_buffer_configs = self.config.input_buffers
         if len(input_buffer_configs) == 0:
             raise ValueError("input_buffers is empty in data pipeline config")
-        datasets = []
+        self.buffers = []
         for input_buffer_config in input_buffer_configs:
-            input_buffer = get_buffer_reader(input_buffer_config, self.buffer_config)
-            datasets.append(Dataset.from_list(input_buffer.read()))
-        self.data = concatenate_datasets(datasets)
+            self.buffers.append(get_buffer_reader(input_buffer_config, self.buffer_config))
+        self.data = Dataset.from_list([])
 
         self.reward_schema = self._init_reward_schema(reward_schema)
         self.stats: Dict[str, Any] = {}
@@ -65,6 +64,12 @@ def format(
         for formatter in formatters:
             self.data = formatter(self.data, num_proc)
 
+    def read_from_buffer(self):
+        datasets = []
+        for buffer in self.buffers:
+            datasets.append(Dataset.from_list(buffer.read()))
+        self.data = concatenate_datasets(datasets)
+
     def write_to_buffer(
         self, output_storage_config: StorageConfig = None, buffer_config: BufferConfig = None
     ):
@@ -75,6 +80,7 @@ def write_to_buffer(
         output_buffer = get_buffer_writer(output_storage_config, buffer_config)
         output_buffer.write(self.data.to_list())
         output_buffer.finish()
+        self.data = Dataset.from_list([])
 
     def to_parquet(self, path: str):
         self.data.to_parquet(path)
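Note: taken together, these changes make RftDataset a streaming container: __init__ only stores buffer readers, read_from_buffer() materializes one batch by concatenating a read() from each buffer (letting StopIteration propagate once they are empty), and write_to_buffer() flushes the batch and resets self.data so the next iteration starts empty. A usage sketch under those assumptions, with `cfg` eliding the pipeline config:

    dataset = RftDataset(data_pipeline_config=cfg)  # readers registered, no I/O yet
    while True:
        try:
            dataset.read_from_buffer()   # one batch, concatenated across input buffers
        except StopIteration:
            break                        # every input buffer is exhausted
        dataset.format(formatters=[])    # optional per-batch transforms
        dataset.write_to_buffer()        # flush the batch; self.data reset to empty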
