
Commit d250586

Merge pull request #443 from fatty-belly/pdf2vqa_dev
PDF2VQA refactor
2 parents 2754519 + 49b8a82 commit d250586

12 files changed: +538 additions, -569 deletions

Lines changed: 2 additions & 2 deletions
@@ -1,2 +1,2 @@
-{"question_pdf_path": "./example_data/PDF2VQAPipeline/questionextract_test.pdf", "answer_pdf_path": "./example_data/PDF2VQAPipeline/questionextract_test.pdf", "subject": "math", "output_dir": "../vqa_output_test/math1"}
-{"question_pdf_path": "./example_data/PDF2VQAPipeline/math_question.pdf", "answer_pdf_path": "./example_data/PDF2VQAPipeline/math_answer.pdf", "subject": "math", "output_dir": "../vqa_output_test/math2"}
+{"question_pdf_path": "./example_data/PDF2VQAPipeline/questionextract_test.pdf", "answer_pdf_path": "./example_data/PDF2VQAPipeline/questionextract_test.pdf", "name": "math1"}
+{"question_pdf_path": "./example_data/PDF2VQAPipeline/math_question.pdf", "answer_pdf_path": "./example_data/PDF2VQAPipeline/math_answer.pdf", "name": "math2"}

dataflow/operators/core_text/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -2,6 +2,7 @@
 
 if TYPE_CHECKING:
     from .generate.prompted_generator import PromptedGenerator
+    from .generate.chunked_prompted_generator import ChunkedPromptedGenerator
     from .generate.format_str_prompted_generator import FormatStrPromptedGenerator
     from .generate.random_domain_knowledge_row_generator import RandomDomainKnowledgeRowGenerator
     from .generate.text2qa_generator import Text2QAGenerator
dataflow/operators/core_text/generate/chunked_prompted_generator.py

Lines changed: 147 additions & 0 deletions
@@ -0,0 +1,147 @@
import pandas as pd
import tiktoken
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow import get_logger
from pathlib import Path

from dataflow.utils.storage import DataFlowStorage
from dataflow.core import OperatorABC
from dataflow.core import LLMServingABC

@OPERATOR_REGISTRY.register()
class ChunkedPromptedGenerator(OperatorABC):
    """
    Prompt-based generation operator that automatically chunks its input.
    - Token counts are computed with tiktoken or a HuggingFace AutoTokenizer;
    - Inputs longer than max_chunk_len are split by recursive bisection;
    - Content is read from the input file path of each row and the generated
      result is written to a corresponding output file path;
    - The result is a single string joined with the separator.
    """

    def __init__(
        self,
        llm_serving: LLMServingABC,
        system_prompt: str = "You are a helpful agent.",
        json_schema: dict = None,
        max_chunk_len: int = 128000,
        enc=tiktoken.get_encoding("cl100k_base"),  # any tokenizer supporting len(enc.encode(text)) works, e.g. tiktoken or a HuggingFace AutoTokenizer
        separator: str = "\n",
    ):
        self.logger = get_logger()
        self.llm_serving = llm_serving
        self.system_prompt = system_prompt
        self.json_schema = json_schema
        self.max_chunk_len = max_chunk_len
        self.enc = enc
        self.separator = separator

    @staticmethod
    def get_desc(lang: str = "zh"):
        if lang == "zh":
            return (
                "基于提示词的生成算子,支持长文本自动分chunk。"
                "采用递归二分方式进行chunk切分,确保每段不超过max_chunk_len tokens。"
                "从给定的输入文件路径读取内容,生成结果保存至指定输出文件路径。"
                "多个生成结果以separator拼接成最终输出字符串。"
                "输入参数:\n"
                "- llm_serving:LLM服务对象,需实现LLMServingABC接口\n"
                "- system_prompt:系统提示词,定义模型行为,默认为'You are a helpful agent.'\n"
                "- max_chunk_len:单个chunk的最大token长度,默认为128000\n"
                "- input_path_key:输入文件路径字段名,默认为'input_path'\n"
                "- output_path_key:输出文件路径字段名,默认为'output_path'\n"
                "- json_schema:可选,生成结果的JSON Schema约束\n"
                "- enc:用于token计算的编码器,需要实现encode方法,默认为tiktoken的cl100k_base编码器,也可以使用HuggingFace的AutoTokenizer\n"
                "- separator:chunk结果拼接分隔符,默认为换行符\n"
            )
        else:
            return (
                "Prompt-based generator with recursive chunk splitting. "
                "Splits long text inputs into chunks using recursive bisection to ensure each chunk does not exceed max_chunk_len tokens. "
                "Reads content from specified input file paths and saves generated results to designated output file paths. "
                "Multiple generated results are joined into a single string using the specified separator.\n"
                "Input Parameters:\n"
                "- llm_serving: LLM serving object implementing the LLMServingABC interface\n"
                "- system_prompt: System prompt to define model behavior, default is 'You are a helpful agent.'\n"
                "- max_chunk_len: Maximum token length per chunk, default is 128000\n"
                "- input_path_key: Field name for the input file path, default is 'input_path'\n"
                "- output_path_key: Field name for the output file path, default is 'output_path'\n"
                "- json_schema: Optional JSON Schema constraint for generated results\n"
                "- enc: Encoder for token counting, default is tiktoken's cl100k_base encoder; a HuggingFace AutoTokenizer also works\n"
                "- separator: Separator used to join chunk results, default is the newline character\n"
            )

    # === Token counting ===
    def _count_tokens(self, text: str) -> int:
        return len(self.enc.encode(text))

    # === Recursive bisection into chunks ===
    def _split_recursive(self, text: str) -> list[str]:
        """Recursively split the text into chunks of at most max_chunk_len tokens."""
        token_len = self._count_tokens(text)
        if token_len <= self.max_chunk_len:
            return [text]
        else:
            mid = len(text) // 2
            left, right = text[:mid], text[mid:]
            return self._split_recursive(left) + self._split_recursive(right)

    def run(
        self,
        storage: DataFlowStorage,
        input_path_key,
        output_path_key,
    ):
        self.logger.info("Running ChunkedPromptedGenerator...")
        dataframe = storage.read("dataframe")
        self.logger.info(f"Loaded DataFrame with {len(dataframe)} rows.")

        all_generated_results = []

        all_llm_inputs = []
        row_chunk_map = []  # number of chunks produced for each row

        # === Collect all chunks first ===
        for i, row in dataframe.iterrows():
            raw_content = Path(row[input_path_key]).read_text(encoding='utf-8')

            chunks = self._split_recursive(raw_content)
            self.logger.info(f"Row {i}: split into {len(chunks)} chunks")

            system_prompt = self.system_prompt + "\n"
            llm_inputs = [system_prompt + chunk for chunk in chunks]
            all_llm_inputs.extend(llm_inputs)
            row_chunk_map.append(len(chunks))

        # === Single batched call for all chunks ===
        self.logger.info(f"Total {len(all_llm_inputs)} chunks to generate")

        try:
            if self.json_schema:
                all_responses = self.llm_serving.generate_from_input(
                    all_llm_inputs, json_schema=self.json_schema
                )
            else:
                all_responses = self.llm_serving.generate_from_input(all_llm_inputs)
        except Exception as e:
            self.logger.error(f"Global generation failed: {e}")
            all_generated_results = [[] for _ in range(len(dataframe))]
        else:
            # === Re-partition the responses by row ===
            all_generated_results = []
            idx = 0
            for num_chunks in row_chunk_map:
                if num_chunks == 0:
                    all_generated_results.append([])
                else:
                    all_generated_results.append(all_responses[idx:idx + num_chunks])
                    idx += num_chunks

        for (i, row), gen_results in zip(dataframe.iterrows(), all_generated_results):
            # Write the joined chunk results next to the input file
            output_path = str(Path(row[input_path_key]).with_suffix('')) + '_llm_output.txt'
            with open(output_path, 'w', encoding='utf-8') as f:
                f.write(self.separator.join(gen_results))
            dataframe.at[i, output_path_key] = output_path

        output_file = storage.write(dataframe)
        self.logger.info(f"Generation complete. Output saved to {output_file}")
        return output_path_key
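
To make the chunking strategy concrete, here is a minimal, self-contained sketch of the recursive bisection that _split_recursive performs. It only needs tiktoken (which the operator requires anyway); the tiny MAX_CHUNK_LEN and the sample string are illustrative values chosen for this example, not defaults of the operator.

import tiktoken

enc = tiktoken.get_encoding("cl100k_base")
MAX_CHUNK_LEN = 8  # deliberately tiny so the sample text actually gets split

def split_recursive(text: str) -> list[str]:
    # A piece that already fits within the token budget is kept as a single chunk.
    if len(enc.encode(text)) <= MAX_CHUNK_LEN:
        return [text]
    # Otherwise cut the text in half by character position and recurse on both halves.
    mid = len(text) // 2
    return split_recursive(text[:mid]) + split_recursive(text[mid:])

sample = "One two three four five six seven eight nine ten eleven twelve."
chunks = split_recursive(sample)
print(len(chunks), [len(enc.encode(c)) for c in chunks])

Because the cut point is chosen by character position rather than by token or sentence boundaries, a chunk boundary can fall mid-word; the only guarantee is that every chunk stays within the token budget.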

dataflow/operators/core_text/generate/prompted_generator.py

Lines changed: 1 addition & 0 deletions
@@ -1,6 +1,7 @@
 import pandas as pd
 from dataflow.utils.registry import OPERATOR_REGISTRY
 from dataflow import get_logger
+from pathlib import Path
 
 from dataflow.utils.storage import DataFlowStorage
 from dataflow.core import OperatorABC

dataflow/operators/pdf2vqa/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -1,7 +1,9 @@
 from typing import TYPE_CHECKING
 
 if TYPE_CHECKING:
-    from .generate.vqa_extractor import VQAExtractor
+    from .generate.mineru_to_llm_input_operator import MinerU2LLMInputOperator
+    from .generate.llm_output_parser import LLMOutputParser
+    from .generate.qa_merger import QA_Merger
 
 
 else:
dataflow/operators/pdf2vqa/generate/llm_output_parser.py

Lines changed: 141 additions & 0 deletions
@@ -0,0 +1,141 @@
import os
import json
import re
import shutil
from pathlib import Path
from typing import Literal
from dataflow.core import OperatorABC
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow.utils.storage import DataFlowStorage
from dataflow import get_logger

@OPERATOR_REGISTRY.register()
class LLMOutputParser(OperatorABC):
    def __init__(self,
                 mode: Literal['question', 'answer'],
                 output_dir,
                 intermediate_dir: str = "intermediate",
                 ):
        self.logger = get_logger()
        self.mode = mode
        self.output_dir = output_dir
        self.intermediate_dir = intermediate_dir

    @staticmethod
    def get_desc(lang: str = "zh") -> str:
        if lang == 'zh':
            return (
                "LLM输出解析算子。"
                "将LLM生成的包含题目和答案ID的响应文本,"
                "转换为结构化的QA列表,并复制相关图片到输出目录。"
            )
        else:
            return (
                "LLM output parsing operator. "
                "Converts LLM-generated response text containing question and answer IDs "
                "into a structured QA list and copies related images to the output directory."
            )

    def _id_to_text(self, input_ids, input_json, image_prefix="images"):
        # Resolve a comma-separated list of block ids into text / image markdown.
        texts = []
        id_list = input_ids.replace(' ', '').split(',')
        for id in id_list:
            try:
                int(id)
            except ValueError:
                continue
            if int(id) < len(input_json):
                try:
                    item = input_json[int(id)]
                except Exception:
                    continue
                if 'text' in item:
                    texts.append(item['text'])
                elif 'img_path' in item:
                    try:
                        img_path = item.get('img_path', '')
                        img_name = os.path.basename(img_path)
                        new_path = f"{image_prefix}/{img_name}"
                        texts.append(f"![{' '.join(item.get('image_caption', ['image']))}]({new_path})")
                    except Exception:
                        pass
                elif item.get('type', '') == 'list':
                    if item['sub_type'] == 'text':
                        try:
                            texts.append(input_json[int(id)]['list_items'].pop(0))
                        except Exception:
                            pass
        return '\n'.join(texts)

    def _convert_response(self, input_response, input_json_path, image_prefix="images"):
        qa_list = []
        with open(input_json_path, 'r') as infile:
            input_json = list(json.load(infile))
        # Extract the title of each chapter block
        for chapter_block in re.findall(r'<chapter>(.*?)</chapter>', input_response, flags=re.DOTALL):
            title = re.search(r'<title>(.*?)</title>', chapter_block, flags=re.DOTALL)
            if title:
                chapter_title = self._id_to_text(title.group(1).strip(), input_json, image_prefix)
            else:
                chapter_title = ""
            # Find all qa_pair blocks
            for pair in re.findall(r'<qa_pair>(.*?)</qa_pair>', chapter_block, flags=re.DOTALL):
                # Extract the question part
                q_match = re.search(r'<question>(.*?)</question>', pair, flags=re.DOTALL)
                # Extract the answer part
                a_match = re.search(r'<answer>(.*?)</answer>', pair, flags=re.DOTALL)
                # Extract the solution part
                s_match = re.search(r'<solution>(.*?)</solution>', pair, flags=re.DOTALL)
                # Extract the label
                label_match = re.search(r'<label>(.*?)</label>', pair, flags=re.DOTALL)
                if not ((q_match and label_match) or (a_match and label_match) or (s_match and label_match)):
                    continue
                label = label_match.group(1).strip()
                qa_list.append({
                    'question': self._id_to_text(q_match.group(1).strip(), input_json, image_prefix) if q_match else "",
                    'answer': a_match.group(1).strip() if a_match else "",
                    'solution': self._id_to_text(s_match.group(1).strip(), input_json, image_prefix) if s_match else "",
                    'label': label,
                    'chapter_title': chapter_title
                })
        return qa_list

    def run(self, storage: DataFlowStorage,
            input_response_path_key,
            input_converted_layout_path_key,
            input_name_key,
            output_qalist_path_key,
            ):
        dataframe = storage.read("dataframe")

        # Convert each LLM response into a QA list
        for idx, row in dataframe.iterrows():
            converted_json_path = row[input_converted_layout_path_key]
            response = Path(row[input_response_path_key]).read_text(encoding='utf-8')
            name = row[input_name_key]

            image_prefix = os.path.join(name, f"{self.mode}_images")
            qa_list = self._convert_response(response, converted_json_path, image_prefix)
            output_qalist_path = os.path.join(self.output_dir, name, f"extracted_{self.mode}s.jsonl")
            os.makedirs(os.path.dirname(output_qalist_path), exist_ok=True)
            with open(output_qalist_path, 'w') as outfile:
                for qa in qa_list:
                    json.dump(qa, outfile, ensure_ascii=False)
                    outfile.write('\n')

            # Copy the extracted images
            src_dir = os.path.join(self.intermediate_dir, 'mineru', Path(converted_json_path).stem).replace('_content_list_converted', '')
            src_images = os.path.join(src_dir, 'vlm', 'images')
            dst_images = os.path.join(self.output_dir, image_prefix)

            try:
                if os.path.exists(src_images):
                    shutil.copytree(src_images, dst_images)
                else:
                    self.logger.warning(f"Source images dir does not exist: {src_images}")
            except Exception as e:
                self.logger.warning(f"Failed to copy images from {src_images} to {dst_images}: {e}")

            dataframe.loc[idx, output_qalist_path_key] = output_qalist_path

        storage.write(dataframe)
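
For reference, here is a made-up sketch of the data this parser works on: a converted content list produced upstream, an LLM response in the tag format that _convert_response matches (<chapter>, <title>, <qa_pair>, <question>, <answer>, <solution>, <label>), and roughly the record that ends up in the JSONL. All concrete values are invented for illustration, and the math1/question_images image prefix assumes name='math1' and mode='question'.

# Hypothetical converted content list (flat, re-indexed blocks):
converted_content_list = [
    {"type": "text", "text": "Chapter 1: Linear Equations", "id": 0},
    {"type": "text", "text": "1. Solve 2x + 3 = 11.", "id": 1},
    {"type": "image", "img_path": "images/fig_1.jpg", "image_caption": ["Figure 1"], "id": 2},
    {"type": "text", "text": "x = 4, obtained by subtracting 3 and dividing by 2.", "id": 3},
]

# Hypothetical LLM response; <title>, <question> and <solution> hold block ids,
# while <answer> and <label> are taken verbatim:
llm_response = """
<chapter>
<title>0</title>
<qa_pair>
<question>1, 2</question>
<answer>x = 4</answer>
<solution>3</solution>
<label>1</label>
</qa_pair>
</chapter>
"""

# _convert_response resolves the ids through _id_to_text, so the JSONL row
# would look roughly like this (image paths rewritten under the image prefix):
expected_qa = {
    "question": "1. Solve 2x + 3 = 11.\n![Figure 1](math1/question_images/fig_1.jpg)",
    "answer": "x = 4",
    "solution": "x = 4, obtained by subtracting 3 and dividing by 2.",
    "label": "1",
    "chapter_title": "Chapter 1: Linear Equations",
}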
dataflow/operators/pdf2vqa/generate/mineru_to_llm_input_operator.py

Lines changed: 69 additions & 0 deletions
@@ -0,0 +1,69 @@
import json
from dataflow.core import OperatorABC
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow.utils.storage import DataFlowStorage

@OPERATOR_REGISTRY.register()
class MinerU2LLMInputOperator(OperatorABC):
    def __init__(self):
        pass

    @staticmethod
    def get_desc(lang: str = "zh") -> str:
        if lang == 'zh':
            return (
                "MinerU格式转换为LLM输入格式算子。"
                "将MinerU生成的内容列表JSON文件转换为适合LLM处理的格式,"
                "包括展平列表项并重新编号。"
            )
        else:
            return (
                "Convert MinerU format to LLM input format operator. "
                "Transforms the content list JSON file generated by MinerU into a format suitable for LLM processing, "
                "including flattening list items and re-indexing."
            )

    def _convert_json(self, input_file, output_file):
        with open(input_file, 'r') as infile:
            data = list(json.load(infile))

        new_data = []
        id = 0
        for item in data:
            item['id'] = id
            item.pop('bbox', None)
            item.pop('page_idx', None)
            if item.get('type', '') == 'list':
                if item['sub_type'] == 'text':
                    # Flatten a text-type list block into one text block per list item
                    for idx, list_item in enumerate(item.get('list_items', [])):
                        new_item = {
                            'type': 'text',
                            'text': list_item,
                            'id': id + idx,
                        }
                        new_data.append(new_item)
                    id += len(item.get('list_items', []))
            else:
                new_data.append(item)
                id += 1

        with open(output_file, 'w') as outfile:
            json.dump(new_data, outfile, ensure_ascii=False)

    def run(self, storage: DataFlowStorage,
            input_markdown_path_key,
            output_converted_layout_key,
            ):
        dataframe = storage.read("dataframe")

        for index, row in dataframe.iterrows():
            input_json_path = row[input_markdown_path_key].replace('.md', '_content_list.json')
            converted_path = input_json_path.replace('.json', '_converted.json')
            self._convert_json(input_json_path, converted_path)
            dataframe.at[index, output_converted_layout_key] = converted_path

            with open(converted_path, 'r') as infile:
                data = json.load(infile)
            assert isinstance(data, list), f"Expected list, got {type(data)} for {input_json_path}"

        storage.write(dataframe)
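
A made-up before/after sketch of the flattening and re-indexing that _convert_json performs; the blocks are invented, and only the keys mirror what the operator reads and drops ('type', 'sub_type', 'list_items', 'bbox', 'page_idx').

# Hypothetical MinerU content list (input):
mineru_content_list = [
    {"type": "text", "text": "Exercise set A", "bbox": [50, 40, 500, 60], "page_idx": 0},
    {"type": "list", "sub_type": "text",
     "list_items": ["1. Define a prime number.", "2. Is 91 prime?"],
     "bbox": [50, 80, 500, 160], "page_idx": 0},
    {"type": "image", "img_path": "images/diagram.jpg", "image_caption": ["Diagram"],
     "bbox": [50, 180, 500, 400], "page_idx": 0},
]

# After _convert_json, the *_content_list_converted.json would hold:
# 'bbox' and 'page_idx' are dropped, the text list is flattened into one block
# per item, and every surviving block carries a sequential 'id'.
converted_content_list = [
    {"type": "text", "text": "Exercise set A", "id": 0},
    {"type": "text", "text": "1. Define a prime number.", "id": 1},
    {"type": "text", "text": "2. Is 91 prime?", "id": 2},
    {"type": "image", "img_path": "images/diagram.jpg", "image_caption": ["Diagram"], "id": 3},
]

These sequential ids are what LLMOutputParser later resolves back into text and image markdown when it parses the model's <question> and <solution> tags.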
