Skip to content

Commit c93ff10

Browse files
committed
* fix wrong way to check the types of processors in task parser
* fix: make dirs when initializing file writer
1 parent 7511c96 commit c93ff10

File tree

7 files changed

+28
-13
lines changed

7 files changed

+28
-13
lines changed

tests/test_configs/active_iterator_test_cfg.yaml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,14 @@ data_processor:
22
# basic info
33
task_pipeline:
44
input_buffers:
5-
- path: 'tests/test_data/test_10/'
5+
- name: 'raw_input'
6+
path: 'tests/test_data/test_10/'
7+
storage_type: 'file'
68
raw: true
9+
output_buffer:
10+
name: 'raw_output'
11+
path: './outputs/task_pipelien_output/processed.jsonl'
12+
storage_type: 'file'
713
format:
814
prompt_key: 'problem'
915
response_key: 'solution'

tests/test_configs/active_iterator_test_dj_cfg.yaml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
project_name: 'demo-process'
22

3-
export_path: './outputs/demo-process/demo-processed.jsonl'
4-
53
text_keys: 'solution'
64

75
process:

trinity/buffer/writer/file_writer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ def __init__(self, meta: StorageConfig, config: BufferConfig):
2626
ext = os.path.splitext(meta.path)[-1]
2727
if ext != ".jsonl" and ext != ".json":
2828
raise ValueError(f"File path must end with .json or .jsonl, got {meta.path}")
29+
path_dir = os.path.dirname(meta.path)
30+
os.makedirs(path_dir, exist_ok=True)
2931
self.file = open(meta.path, "a", encoding="utf-8")
3032
self.encoder = _Encoder(ensure_ascii=False)
3133

trinity/data/controllers/active_iterator.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -159,13 +159,20 @@ def run(self):
159159
traceback.print_exc()
160160
return 7, "Tracking lineage failed."
161161

162-
# step 8. sort and export the result to the output buffer
162+
# step 8
163+
try:
164+
if "priority" in res_dataset.data.features:
165+
res_dataset.sort_by("priority", reverse=True)
166+
except Exception:
167+
traceback.print_exc()
168+
return 8, "Sorting results by priority failed."
169+
170+
# step 9. sort and export the result to the output buffer
163171
try:
164-
res_dataset.sort_by("priority", reverse=True)
165172
res_dataset.write_to_buffer()
166173
except Exception:
167174
traceback.print_exc()
168-
return 8, "Exporting result to output buffer failed."
175+
return 9, "Exporting result to output buffer failed."
169176

170177
return 0, "success"
171178

@@ -247,7 +254,7 @@ def _compute_combined_score(
247254
difficulty = stats.get("difficulty_score", 0.5)
248255
score += self.priority_weights["difficulty"] * difficulty
249256

250-
sample["priority"] = score
257+
sample["priority"] = [score]
251258
return sample
252259

253260
def _compute_diversity_score(self) -> float:

trinity/data/controllers/task_parser.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -170,12 +170,15 @@ def _check_types_of_processors(self, dj_config):
170170
process_list = dj_config.get("process", [])
171171
for op in process_list:
172172
op_name = list(op.keys())[0]
173-
if op_name in DEFAULT_CLEANER:
174-
hit_cleaner = True
175-
elif op_name in DEFAULT_SYNTHESIZER:
173+
if op_name in DEFAULT_SYNTHESIZER:
176174
hit_synthesizer = True
177175
elif op_name in DEFAULT_HUMAN_ANNOTATOR:
178176
hit_human_annotator = True
177+
else:
178+
for dimension in DEFAULT_CLEANER:
179+
if op_name in DEFAULT_CLEANER[dimension]:
180+
hit_cleaner = True
181+
break
179182
return hit_cleaner, hit_synthesizer, hit_human_annotator
180183

181184
def _update_common_op_args(self, dj_config: Namespace, extra_op_args: Dict) -> Namespace:

trinity/data/processors/cleaner.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,15 +166,14 @@ def process(
166166
else:
167167
logger.info("Executing Data-Juicer analyzer...")
168168
analyzer = Analyzer(self.dj_cfg)
169-
analyzer.run(dataset)
169+
analyzer.run(dataset, skip_export=True)
170170
df = analyzer.overall_result
171171
mean_series = df[df.index == "mean"]
172172
stats_key_to_mean = mean_series.iloc[0, :].to_dict()
173173
std_series = df[df.index == "std"]
174174
stats_key_to_std = std_series.iloc[0, :].to_dict()
175175

176176
tmp_cfg = copy.deepcopy(self.dj_cfg)
177-
print(tmp_cfg)
178177
self.op_name_to_stats_key = StatsKeys.get_access_log(dj_cfg=tmp_cfg, dataset=dataset)
179178

180179
for try_idx in range(max_tries):

trinity/data/server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def data_workflow(pipeline_type):
1616
pipeline_type = escape(pipeline_type)
1717
config = load_config(config_path)
1818

19-
pipeline_config = getattr(config, pipeline_type)
19+
pipeline_config = getattr(config.data_processor, pipeline_type)
2020
if pipeline_config is None:
2121
return jsonify(
2222
{

0 commit comments

Comments
 (0)