
Commit 49b8a82

[pdf2vqa] One operator per file
1 parent 8427fbd commit 49b8a82

File tree

5 files changed: +140 additions, -129 deletions


dataflow/operators/pdf2vqa/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -1,7 +1,9 @@
 from typing import TYPE_CHECKING
 
 if TYPE_CHECKING:
-    from .generate.pdf2vqa_formatter import MinerU2LLMInputOperator, LLMOutputParser, QA_Merger
+    from .generate.mineru_to_llm_input_operator import MinerU2LLMInputOperator
+    from .generate.llm_output_parser import LLMOutputParser
+    from .generate.qa_merger import QA_Merger
 
 
 else:
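Only the type-checking branch is touched here; the `else:` branch that resolves these names at runtime lies outside the hunk. For orientation, a minimal sketch of the usual lazy-import pattern such an `__init__.py` tends to use. The `_MODULE_MAP` and module-level `__getattr__` below are illustrative assumptions, not this repository's actual code, which may instead route lookups through `OPERATOR_REGISTRY`:

# Sketch of a lazy-loading __init__.py (PEP 562 module __getattr__); illustrative only.
import importlib
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from .generate.mineru_to_llm_input_operator import MinerU2LLMInputOperator
    from .generate.llm_output_parser import LLMOutputParser
    from .generate.qa_merger import QA_Merger
else:
    # Hypothetical mapping from exported name to the submodule that defines it.
    _MODULE_MAP = {
        "MinerU2LLMInputOperator": ".generate.mineru_to_llm_input_operator",
        "LLMOutputParser": ".generate.llm_output_parser",
        "QA_Merger": ".generate.qa_merger",
    }

    def __getattr__(name):
        # Import the operator's module only when the name is first accessed.
        if name in _MODULE_MAP:
            module = importlib.import_module(_MODULE_MAP[name], __package__)
            return getattr(module, name)
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")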

dataflow/operators/pdf2vqa/generate/pdf2vqa_formatter.py renamed to dataflow/operators/pdf2vqa/generate/llm_output_parser.py

Lines changed: 0 additions & 126 deletions
@@ -8,73 +8,7 @@
 from dataflow.utils.registry import OPERATOR_REGISTRY
 from dataflow.utils.storage import DataFlowStorage
 from dataflow import get_logger
-from dataflow.utils.pdf2vqa.format_utils import merge_qa_pair, jsonl_to_md
 
-@OPERATOR_REGISTRY.register()
-class MinerU2LLMInputOperator(OperatorABC):
-    def __init__(self):
-        pass
-
-    @staticmethod
-    def get_desc(lang: str = "zh") -> str:
-        if lang == 'zh':
-            return (
-                "MinerU格式转换为LLM输入格式算子。"
-                "将MinerU生成的内容列表JSON文件转换为适合LLM处理的格式,"
-                "包括展平列表项并重新编号。"
-            )
-        else:
-            return (
-                "Convert MinerU format to LLM input format operator."
-                "Transforms the content list JSON file generated by MinerU into a format suitable for LLM processing,"
-                "including flattening list items and re-indexing."
-            )
-
-    def _convert_json(self, input_file, output_file):
-        with open(input_file, 'r') as infile:
-            data = list(json.load(infile))
-
-        new_data = []
-        id = 0
-        for item in data:
-            item['id'] = id
-            item.pop('bbox', None)
-            item.pop('page_idx', None)
-            if item.get('type','') == 'list':
-                if item['sub_type'] == 'text':
-                    for idx, list_item in enumerate(item.get('list_items', [])):
-                        new_item = {
-                            'type': 'text',
-                            'text': list_item,
-                            'id': id + idx,
-                        }
-                        new_data.append(new_item)
-                    id += len(item.get('list_items', []))
-            else:
-                new_data.append(item)
-                id += 1
-
-        with open(output_file, 'w') as outfile:
-            json.dump(new_data, outfile, ensure_ascii=False)
-
-    def run(self, storage: DataFlowStorage,
-            input_markdown_path_key,
-            output_converted_layout_key,
-            ):
-        dataframe = storage.read("dataframe")
-
-        for index, row in dataframe.iterrows():
-            input_json_path = row[input_markdown_path_key].replace('.md', '_content_list.json')
-            converted_path = input_json_path.replace('.json', '_converted.json')
-            self._convert_json(input_json_path, converted_path)
-            dataframe.at[index, output_converted_layout_key] = converted_path
-
-            with open(converted_path, 'r') as infile:
-                data = json.load(infile)
-                assert isinstance(data, list), f"Expected list, got {type(data)} for {input_json_path}"
-
-        storage.write(dataframe)
-
 @OPERATOR_REGISTRY.register()
 class LLMOutputParser(OperatorABC):
     def __init__(self,
@@ -204,64 +138,4 @@ def run(self, storage: DataFlowStorage,
 
             dataframe.loc[idx, output_qalist_path_key] = output_qalist_path
 
-        storage.write(dataframe)
-
-@OPERATOR_REGISTRY.register()
-class QA_Merger(OperatorABC):
-    def __init__(self, output_dir, strict_title_match=False):
-        self.output_dir = output_dir
-        self.strict_title_match = strict_title_match
-
-    @staticmethod
-    def get_desc(lang: str = "zh") -> str:
-        if lang == 'zh':
-            return (
-                "QA对合并算子。"
-                "将问题和答案的QA列表进行合并,生成最终的QA对文件,"
-                "并转换为Markdown格式。"
-            )
-        else:
-            return (
-                "QA pair merging operator."
-                "Merges question and answer QA lists to generate final QA pair files,"
-                "and converts them to Markdown format."
-            )
-
-    def run(self, storage: DataFlowStorage,
-            input_question_qalist_path_key,
-            input_answer_qalist_path_key,
-            input_name_key,
-            output_merged_qalist_path_key,
-            output_merged_md_path_key,
-            output_qa_item_key="qa_item"  # new: column name for the exploded QA items
-            ):
-        dataframe = storage.read("dataframe")
-
-        # initialize the column as object dtype so it can hold list values
-        dataframe[output_qa_item_key] = None
-        dataframe[output_qa_item_key] = dataframe[output_qa_item_key].astype(object)
-
-        for idx, row in dataframe.iterrows():
-            question_qalist_path = row[input_question_qalist_path_key]
-            answer_qalist_path = row[input_answer_qalist_path_key]
-            name = row[input_name_key]
-
-            output_merged_qalist_path = os.path.join(self.output_dir, name, "merged_qa_pairs.jsonl")
-            merge_qa_pair(question_qalist_path, answer_qalist_path, output_merged_qalist_path, strict_title_match=self.strict_title_match)
-
-            output_merged_md_path = os.path.join(self.output_dir, name, "merged_qa_pairs.md")
-            jsonl_to_md(output_merged_qalist_path, output_merged_md_path)
-
-            qa_pairs = []
-            if os.path.exists(output_merged_qalist_path):
-                with open(output_merged_qalist_path, 'r', encoding='utf-8') as f:
-                    qa_pairs = [json.loads(line) for line in f]
-
-            dataframe.at[idx, output_qa_item_key] = qa_pairs
-
-            dataframe.loc[idx, output_merged_qalist_path_key] = output_merged_qalist_path
-            dataframe.loc[idx, output_merged_md_path_key] = output_merged_md_path
-
-        dataframe = dataframe.explode(output_qa_item_key).reset_index(drop=True)
-
         storage.write(dataframe)

dataflow/operators/pdf2vqa/generate/mineru_to_llm_input_operator.py

Lines changed: 69 additions & 0 deletions
@@ -0,0 +1,69 @@
+import json
+from dataflow.core import OperatorABC
+from dataflow.utils.registry import OPERATOR_REGISTRY
+from dataflow.utils.storage import DataFlowStorage
+
+@OPERATOR_REGISTRY.register()
+class MinerU2LLMInputOperator(OperatorABC):
+    def __init__(self):
+        pass
+
+    @staticmethod
+    def get_desc(lang: str = "zh") -> str:
+        if lang == 'zh':
+            return (
+                "MinerU格式转换为LLM输入格式算子。"
+                "将MinerU生成的内容列表JSON文件转换为适合LLM处理的格式,"
+                "包括展平列表项并重新编号。"
+            )
+        else:
+            return (
+                "Convert MinerU format to LLM input format operator."
+                "Transforms the content list JSON file generated by MinerU into a format suitable for LLM processing,"
+                "including flattening list items and re-indexing."
+            )
+
+    def _convert_json(self, input_file, output_file):
+        with open(input_file, 'r') as infile:
+            data = list(json.load(infile))
+
+        new_data = []
+        id = 0
+        for item in data:
+            item['id'] = id
+            item.pop('bbox', None)
+            item.pop('page_idx', None)
+            if item.get('type','') == 'list':
+                if item['sub_type'] == 'text':
+                    for idx, list_item in enumerate(item.get('list_items', [])):
+                        new_item = {
+                            'type': 'text',
+                            'text': list_item,
+                            'id': id + idx,
+                        }
+                        new_data.append(new_item)
+                    id += len(item.get('list_items', []))
+            else:
+                new_data.append(item)
+                id += 1
+
+        with open(output_file, 'w') as outfile:
+            json.dump(new_data, outfile, ensure_ascii=False)
+
+    def run(self, storage: DataFlowStorage,
+            input_markdown_path_key,
+            output_converted_layout_key,
+            ):
+        dataframe = storage.read("dataframe")
+
+        for index, row in dataframe.iterrows():
+            input_json_path = row[input_markdown_path_key].replace('.md', '_content_list.json')
+            converted_path = input_json_path.replace('.json', '_converted.json')
+            self._convert_json(input_json_path, converted_path)
+            dataframe.at[index, output_converted_layout_key] = converted_path
+
+            with open(converted_path, 'r') as infile:
+                data = json.load(infile)
+                assert isinstance(data, list), f"Expected list, got {type(data)} for {input_json_path}"
+
+        storage.write(dataframe)
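
For reference, the flattening performed by `_convert_json` above turns a text-type `list` block from a MinerU content list into flat `text` items with consecutive ids, dropping `bbox` and `page_idx`. A standalone sketch of that step with made-up sample values (the real operator additionally reads from and writes back to JSON files on disk):

def flatten_content_list(data):
    # Mirrors MinerU2LLMInputOperator._convert_json: re-index items, drop layout
    # keys, and expand text-type list blocks into individual text items.
    new_data, next_id = [], 0
    for item in data:
        item["id"] = next_id
        item.pop("bbox", None)
        item.pop("page_idx", None)
        if item.get("type", "") == "list":
            if item.get("sub_type") == "text":
                for offset, text in enumerate(item.get("list_items", [])):
                    new_data.append({"type": "text", "text": text, "id": next_id + offset})
                next_id += len(item.get("list_items", []))
        else:
            new_data.append(item)
            next_id += 1
    return new_data


# Illustrative sample; key names follow the code above, values are made up.
sample = [
    {"type": "text", "text": "Chapter 1", "bbox": [0, 0, 100, 20], "page_idx": 0},
    {"type": "list", "sub_type": "text", "list_items": ["item A", "item B"], "page_idx": 0},
]
print(flatten_content_list(sample))
# [{'type': 'text', 'text': 'Chapter 1', 'id': 0},
#  {'type': 'text', 'text': 'item A', 'id': 1},
#  {'type': 'text', 'text': 'item B', 'id': 2}]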

dataflow/operators/pdf2vqa/generate/qa_merger.py

Lines changed: 66 additions & 0 deletions
@@ -0,0 +1,66 @@
+import os
+import json
+from dataflow.core import OperatorABC
+from dataflow.utils.registry import OPERATOR_REGISTRY
+from dataflow.utils.storage import DataFlowStorage
+from dataflow.utils.pdf2vqa.format_utils import merge_qa_pair, jsonl_to_md
+
+@OPERATOR_REGISTRY.register()
+class QA_Merger(OperatorABC):
+    def __init__(self, output_dir, strict_title_match=False):
+        self.output_dir = output_dir
+        self.strict_title_match = strict_title_match
+
+    @staticmethod
+    def get_desc(lang: str = "zh") -> str:
+        if lang == 'zh':
+            return (
+                "QA对合并算子。"
+                "将问题和答案的QA列表进行合并,生成最终的QA对文件,"
+                "并转换为Markdown格式。"
+            )
+        else:
+            return (
+                "QA pair merging operator."
+                "Merges question and answer QA lists to generate final QA pair files,"
+                "and converts them to Markdown format."
+            )
+
+    def run(self, storage: DataFlowStorage,
+            input_question_qalist_path_key,
+            input_answer_qalist_path_key,
+            input_name_key,
+            output_merged_qalist_path_key,
+            output_merged_md_path_key,
+            output_qa_item_key="qa_item"  # new: column name for the exploded QA items
+            ):
+        dataframe = storage.read("dataframe")
+
+        # initialize the column as object dtype so it can hold list values
+        dataframe[output_qa_item_key] = None
+        dataframe[output_qa_item_key] = dataframe[output_qa_item_key].astype(object)
+
+        for idx, row in dataframe.iterrows():
+            question_qalist_path = row[input_question_qalist_path_key]
+            answer_qalist_path = row[input_answer_qalist_path_key]
+            name = row[input_name_key]
+
+            output_merged_qalist_path = os.path.join(self.output_dir, name, "merged_qa_pairs.jsonl")
+            merge_qa_pair(question_qalist_path, answer_qalist_path, output_merged_qalist_path, strict_title_match=self.strict_title_match)
+
+            output_merged_md_path = os.path.join(self.output_dir, name, "merged_qa_pairs.md")
+            jsonl_to_md(output_merged_qalist_path, output_merged_md_path)
+
+            qa_pairs = []
+            if os.path.exists(output_merged_qalist_path):
+                with open(output_merged_qalist_path, 'r', encoding='utf-8') as f:
+                    qa_pairs = [json.loads(line) for line in f]
+
+            dataframe.at[idx, output_qa_item_key] = qa_pairs
+
+            dataframe.loc[idx, output_merged_qalist_path_key] = output_merged_qalist_path
+            dataframe.loc[idx, output_merged_md_path_key] = output_merged_md_path
+
+        dataframe = dataframe.explode(output_qa_item_key).reset_index(drop=True)
+
+        storage.write(dataframe)
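
The final `explode` call above turns each document's list of merged QA dicts into one dataframe row per QA pair. A small illustration of that step; the sample row and QA dicts are made up, and only the default `qa_item` column name comes from the operator:

import pandas as pd

# One input row holding a list of two merged QA pairs (illustrative values).
df = pd.DataFrame({
    "name": ["math1"],
    "qa_item": [[{"question": "Q1?", "answer": "A1"},
                 {"question": "Q2?", "answer": "A2"}]],
})

# Same pattern as the operator's last step: one row per QA pair, index rebuilt.
df = df.explode("qa_item").reset_index(drop=True)
print(len(df), df.loc[0, "qa_item"])
# 2 {'question': 'Q1?', 'answer': 'A1'}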

dataflow/statics/pipelines/api_pipelines/pdf_vqa_extract_pipeline.py

Lines changed: 2 additions & 2 deletions
@@ -8,7 +8,7 @@
 from dataflow.pipeline import PipelineABC
 from dataflow.prompts.pdf2vqa import QAExtractPrompt
 
-class VQA_extract_optimized_pipeline(PipelineABC):
+class PDF_VQA_extract_optimized_pipeline(PipelineABC):
     def __init__(self):
         super().__init__()
         self.storage = FileStorage(
@@ -102,6 +102,6 @@ def forward(self):
 if __name__ == "__main__":
     # each line of the jsonl contains question_pdf_path, answer_pdf_path, name (math1, math2, physics1, chemistry1, ...)
     # if the questions and answers are in the same pdf, set question_pdf_path and answer_pdf_path to the same path; the pipeline switches to interleaved mode automatically
-    pipeline = VQA_extract_optimized_pipeline()
+    pipeline = PDF_VQA_extract_optimized_pipeline()
     pipeline.compile()
     pipeline.forward()
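
As the comments above state, the pipeline expects a JSONL file with one line per document containing `question_pdf_path`, `answer_pdf_path`, and `name`. A sketch of building such a file; the paths and the output filename are placeholders, not values taken from this repository:

import json

rows = [
    # Separate question and answer PDFs (illustrative paths).
    {"question_pdf_path": "data/math1_questions.pdf",
     "answer_pdf_path": "data/math1_answers.pdf",
     "name": "math1"},
    # Same PDF for both -> the pipeline switches to interleaved mode.
    {"question_pdf_path": "data/physics1.pdf",
     "answer_pdf_path": "data/physics1.pdf",
     "name": "physics1"},
]

with open("pdf_vqa_input.jsonl", "w", encoding="utf-8") as f:
    for row in rows:
        f.write(json.dumps(row, ensure_ascii=False) + "\n")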
