
Commit 13f6814

[PDF2VQA] Major refactor: reuse existing operators
1 parent 31f5841 commit 13f6814

File tree

9 files changed: +487 −564 lines changed
Lines changed: 2 additions & 2 deletions

@@ -1,2 +1,2 @@
-{"question_pdf_path": "./example_data/PDF2VQAPipeline/questionextract_test.pdf", "answer_pdf_path": "./example_data/PDF2VQAPipeline/questionextract_test.pdf", "subject": "math", "output_dir": "../vqa_output_test/math1"}
-{"question_pdf_path": "./example_data/PDF2VQAPipeline/math_question.pdf", "answer_pdf_path": "./example_data/PDF2VQAPipeline/math_answer.pdf", "subject": "math", "output_dir": "../vqa_output_test/math2"}
+{"question_pdf_path": "./example_data/PDF2VQAPipeline/questionextract_test.pdf", "answer_pdf_path": "./example_data/PDF2VQAPipeline/questionextract_test.pdf", "name": "math1"}
+{"question_pdf_path": "./example_data/PDF2VQAPipeline/math_question.pdf", "answer_pdf_path": "./example_data/PDF2VQAPipeline/math_answer.pdf", "name": "math2"}

dataflow/operators/core_text/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 from typing import TYPE_CHECKING
 
 if TYPE_CHECKING:
-    from .generate.prompted_generator import PromptedGenerator
+    from .generate.prompted_generator import PromptedGenerator, ChunkedPromptedGenerator
     from .generate.format_str_prompted_generator import FormatStrPromptedGenerator
     from .generate.random_domain_knowledge_row_generator import RandomDomainKnowledgeRowGenerator
     from .generate.text2qa_generator import Text2QAGenerator

dataflow/operators/core_text/generate/prompted_generator.py

Lines changed: 138 additions & 0 deletions
@@ -1,6 +1,8 @@
 import pandas as pd
+import tiktoken
 from dataflow.utils.registry import OPERATOR_REGISTRY
 from dataflow import get_logger
+from pathlib import Path
 
 from dataflow.utils.storage import DataFlowStorage
 from dataflow.core import OperatorABC
@@ -87,3 +89,139 @@ def run(self, storage: DataFlowStorage, input_key: str = "raw_content", output_k
         # Save the updated dataframe to the output file
         output_file = storage.write(dataframe)
         return output_key
+
+
+class ChunkedPromptedGenerator(OperatorABC):
+    """
+    Prompt-based generation operator with automatic input chunking.
+    - Uses tiktoken to count tokens exactly;
+    - Inputs exceeding max_chunk_len are split by recursive bisection;
+    - Reads content from the input file path given per row and saves the
+      generated result to the corresponding output file path;
+    - The result is the per-chunk outputs joined with `separator` into one string.
+    """
+
+    def __init__(
+        self,
+        llm_serving: LLMServingABC,
+        system_prompt: str = "You are a helpful agent.",
+        json_schema: dict = None,
+        max_chunk_len: int = 128000,
+        enc = tiktoken.get_encoding("cl100k_base"),
+        separator: str = "\n",
+    ):
+        self.logger = get_logger()
+        self.llm_serving = llm_serving
+        self.system_prompt = system_prompt
+        self.json_schema = json_schema
+        self.max_chunk_len = max_chunk_len
+        self.enc = enc
+        self.separator = separator
+
+    @staticmethod
+    def get_desc(lang: str = "zh"):
+        if lang == "zh":
+            return (
+                "基于提示词的生成算子,支持长文本自动分chunk。"
+                "采用递归二分方式进行chunk切分,确保每段不超过max_chunk_len tokens。"
+                "从给定的输入文件路径读取内容,生成结果保存至指定输出文件路径。"
+                "多个生成结果以separator拼接成最终输出字符串。"
+                "输入参数:\n"
+                "- llm_serving:LLM服务对象,需实现LLMServingABC接口\n"
+                "- system_prompt:系统提示词,定义模型行为,默认为'You are a helpful agent.'\n"
+                "- max_chunk_len:单个chunk的最大token长度,默认为128000\n"
+                "- input_path_key:输入文件路径字段名,默认为'input_path'\n"
+                "- output_path_key:输出文件路径字段名,默认为'output_path'\n"
+                "- json_schema:可选,生成结果的JSON Schema约束\n"
+                "- enc:用于token计算的编码器,需要实现encode方法,默认为tiktoken的cl100k_base编码器,也可以使用HuggingFace 的 AutoTokenizer\n"
+                "- separator:chunk结果拼接分隔符,默认为换行符\n"
+            )
+        else:
+            return (
+                "Prompt-based generator with recursive chunk splitting. "
+                "Splits long text inputs into chunks using recursive bisection to ensure each chunk does not exceed max_chunk_len tokens. "
+                "Reads content from specified input file paths and saves generated results to designated output file paths. "
+                "Multiple generated results are joined as a string using the specified separator.\n"
+                "Input Parameters:\n"
+                "- llm_serving: LLM serving object implementing LLMServingABC interface\n"
+                "- system_prompt: System prompt to define model behavior, default is 'You are a helpful agent.'\n"
+                "- max_chunk_len: Maximum token length per chunk, default is 128000\n"
+                "- input_path_key: Field name for input file path, default is 'input_path'\n"
+                "- output_path_key: Field name for output file path, default is 'output_path'\n"
+                "- json_schema: Optional JSON Schema constraint for generated results\n"
+                "- enc: Encoder for token counting; must implement an encode method. Default is tiktoken's cl100k_base encoder; HuggingFace's AutoTokenizer also works\n"
+                "- separator: Separator for joining chunk results, default is newline character\n"
+            )
+
+    # === Token counting ===
+    def _count_tokens(self, text: str) -> int:
+        return len(self.enc.encode(text))
+
+    # === Recursive bisection into chunks ===
+    def _split_recursive(self, text: str) -> list[str]:
+        """Recursively split text into chunks of at most max_chunk_len tokens."""
+        token_len = self._count_tokens(text)
+        if token_len <= self.max_chunk_len:
+            return [text]
+        else:
+            mid = len(text) // 2
+            left, right = text[:mid], text[mid:]
+            return self._split_recursive(left) + self._split_recursive(right)
+
+    def run(
+        self,
+        storage: DataFlowStorage,
+        input_path_key: str = "input_path",
+        output_path_key: str = "output_path",
+    ):
+        self.logger.info("Running ChunkedPromptedGenerator...")
+        dataframe = storage.read("dataframe")
+        self.logger.info(f"Loaded DataFrame with {len(dataframe)} rows.")
+
+        all_generated_results = []
+
+        all_llm_inputs = []
+        row_chunk_map = []  # number of chunks produced for each row
+
+        # === Collect chunks from every row first ===
+        for i, row in dataframe.iterrows():
+            raw_content = Path(row[input_path_key]).read_text(encoding='utf-8')
+
+            chunks = self._split_recursive(raw_content)
+            self.logger.info(f"Row {i}: split into {len(chunks)} chunks")
+
+            system_prompt = self.system_prompt + "\n"
+            llm_inputs = [system_prompt + chunk for chunk in chunks]
+            all_llm_inputs.extend(llm_inputs)
+            row_chunk_map.append(len(chunks))
+
+        # === One concurrent batch call for all chunks ===
+        self.logger.info(f"Total {len(all_llm_inputs)} chunks to generate")
+
+        try:
+            if self.json_schema:
+                all_responses = self.llm_serving.generate_from_input(
+                    all_llm_inputs, json_schema=self.json_schema
+                )
+            else:
+                all_responses = self.llm_serving.generate_from_input(all_llm_inputs)
+        except Exception as e:
+            self.logger.error(f"Global generation failed: {e}")
+            all_generated_results = [[] for _ in range(len(dataframe))]
+        else:
+            # === Re-partition responses back to their rows ===
+            all_generated_results = []
+            idx = 0
+            for num_chunks in row_chunk_map:
+                if num_chunks == 0:
+                    all_generated_results.append([])
+                else:
+                    all_generated_results.append(all_responses[idx:idx + num_chunks])
+                idx += num_chunks
+
+        for (i, row), gen_results in zip(dataframe.iterrows(), all_generated_results):
+            # with_suffix avoids truncating relative paths such as './a/b.md'
+            # at the first dot, which a naive str.split('.') would do.
+            output_path = str(Path(row[input_path_key]).with_suffix('')) + '_llm_output.txt'
+            with open(output_path, 'w', encoding='utf-8') as f:
+                f.write(self.separator.join(gen_results))
+            dataframe.at[i, output_path_key] = output_path
+
+        output_file = storage.write(dataframe)
+        self.logger.info(f"Generation complete. Output saved to {output_file}")
+        return output_path_key
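
For reference, a minimal standalone sketch (not part of the commit) of the two mechanisms the operator combines: recursive bisection until every piece fits the token budget, and re-partitioning the flat batch of responses back to rows via row_chunk_map. The toy MAX_CHUNK_LEN, sample inputs, and stand-in responses are illustrative assumptions:

import tiktoken

enc = tiktoken.get_encoding("cl100k_base")
MAX_CHUNK_LEN = 8  # tiny budget so the toy text actually splits

def split_recursive(text):
    """Bisect on character offsets until each piece fits the token budget."""
    if len(enc.encode(text)) <= MAX_CHUNK_LEN:
        return [text]
    mid = len(text) // 2
    return split_recursive(text[:mid]) + split_recursive(text[mid:])

docs = ["short input", "a much longer input that will certainly not fit into one chunk"]
chunks_per_doc = [split_recursive(d) for d in docs]
row_chunk_map = [len(c) for c in chunks_per_doc]

# Flatten for a single batched LLM call, then slice the responses back
# per row exactly as the operator's else-branch does.
flat_inputs = [c for chunks in chunks_per_doc for c in chunks]
responses = [f"<resp {i}>" for i, _ in enumerate(flat_inputs)]  # stand-in for the LLM

regrouped, idx = [], 0
for n in row_chunk_map:
    regrouped.append(responses[idx:idx + n])
    idx += n

assert [len(g) for g in regrouped] == row_chunk_map

Note the trade-off: bisecting on character offsets is cheap and guarantees termination, but it can cut mid-word or mid-sentence; chunk boundaries carry no overlap or semantic awareness.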

dataflow/operators/pdf2vqa/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 from typing import TYPE_CHECKING
 
 if TYPE_CHECKING:
-    from .generate.vqa_extractor import VQAExtractor
+    from .generate.pdf2vqa_formatter import MinerU2LLMInputOperator, LLMOutputParser, QA_Merger
 
 
 else:
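
Both touched __init__.py files keep the same lazy-import idiom: names are imported eagerly only for type checkers and resolved on demand at runtime. DataFlow's actual else-branch goes through its operator registry; the generic shape of the pattern, with a hypothetical PEP 562 module __getattr__ standing in for that machinery, looks like this:

from typing import TYPE_CHECKING
import importlib

if TYPE_CHECKING:
    from .generate.prompted_generator import PromptedGenerator, ChunkedPromptedGenerator
else:
    # Hypothetical stand-in for DataFlow's registry-based lazy loading.
    _LAZY = {
        "PromptedGenerator": ".generate.prompted_generator",
        "ChunkedPromptedGenerator": ".generate.prompted_generator",
    }

    def __getattr__(name):
        # Import the defining module only on first attribute access.
        if name in _LAZY:
            module = importlib.import_module(_LAZY[name], __package__)
            return getattr(module, name)
        raise AttributeError(name)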
