
Commit 8427fbd

[pdf2vqa] Put chunked_prompted_generator into its own file. Added some comments.
1 parent 13f6814 commit 8427fbd

File tree

4 files changed: +154 -138 lines changed


dataflow/operators/core_text/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -1,7 +1,8 @@
 from typing import TYPE_CHECKING
 
 if TYPE_CHECKING:
-    from .generate.prompted_generator import PromptedGenerator, ChunkedPromptedGenerator
+    from .generate.prompted_generator import PromptedGenerator
+    from .generate.chunked_prompted_generator import ChunkedPromptedGenerator
     from .generate.format_str_prompted_generator import FormatStrPromptedGenerator
     from .generate.random_domain_knowledge_row_generator import RandomDomainKnowledgeRowGenerator
     from .generate.text2qa_generator import Text2QAGenerator

dataflow/operators/core_text/generate/chunked_prompted_generator.py

Lines changed: 147 additions & 0 deletions

@@ -0,0 +1,147 @@
+import pandas as pd
+import tiktoken
+from dataflow.utils.registry import OPERATOR_REGISTRY
+from dataflow import get_logger
+from pathlib import Path
+
+from dataflow.utils.storage import DataFlowStorage
+from dataflow.core import OperatorABC
+from dataflow.core import LLMServingABC
+
+@OPERATOR_REGISTRY.register()
+class ChunkedPromptedGenerator(OperatorABC):
+    """
+    Prompt-based generation operator that automatically chunks its input.
+    - Counts tokens with tiktoken or a HuggingFace AutoTokenizer;
+    - Inputs longer than max_chunk_len are split by recursive bisection;
+    - Content is read from the given input file path and the generated result is saved to the given output file path;
+    - The generated result is a single string joined with separator.
+    """
+
+    def __init__(
+        self,
+        llm_serving: LLMServingABC,
+        system_prompt: str = "You are a helpful agent.",
+        json_schema: dict = None,
+        max_chunk_len: int = 128000,
+        enc = tiktoken.get_encoding("cl100k_base"),  # any tokenizer that supports len(enc.encode(text)) works, e.g. tiktoken or a HuggingFace AutoTokenizer
+        seperator: str = "\n",
+    ):
+        self.logger = get_logger()
+        self.llm_serving = llm_serving
+        self.system_prompt = system_prompt
+        self.json_schema = json_schema
+        self.max_chunk_len = max_chunk_len
+        self.enc = enc
+        self.separator = seperator
+
+    @staticmethod
+    def get_desc(lang: str = "zh"):
+        if lang == "zh":
+            return (
+                "基于提示词的生成算子,支持长文本自动分chunk。"
+                "采用递归二分方式进行chunk切分,确保每段不超过max_chunk_len tokens。"
+                "从给定的输入文件路径读取内容,生成结果保存至指定输出文件路径。"
+                "多个生成结果以separator拼接成最终输出字符串。"
+                "输入参数:\n"
+                "- llm_serving:LLM服务对象,需实现LLMServingABC接口\n"
+                "- system_prompt:系统提示词,定义模型行为,默认为'You are a helpful agent.'\n"
+                "- max_chunk_len:单个chunk的最大token长度,默认为128000\n"
+                "- input_path_key:输入文件路径字段名,默认为'input_path'\n"
+                "- output_path_key:输出文件路径字段名,默认为'output_path'\n"
+                "- json_schema:可选,生成结果的JSON Schema约束\n"
+                "- enc:用于token计算的编码器,需要实现encode方法,默认为tiktoken的cl100k_base编码器,也可以使用HuggingFace 的 AutoTokenizer\n"
+                "- separator:chunk结果拼接分隔符,默认为换行符\n"
+            )
+        else:
+            return (
+                "Prompt-based generator with recursive chunk splitting. "
+                "Splits long text inputs into chunks using recursive bisection to ensure each chunk does not exceed max_chunk_len tokens. "
+                "Reads content from specified input file paths and saves generated results to designated output file paths. "
+                "Multiple generated results are joined as a string using the specified separator.\n"
+                "Input Parameters:\n"
+                "- llm_serving: LLM serving object implementing LLMServingABC interface\n"
+                "- system_prompt: System prompt to define model behavior, default is 'You are a helpful agent.'\n"
+                "- max_chunk_len: Maximum token length per chunk, default is 128000\n"
+                "- input_path_key: Field name for input file path, default is 'input_path'\n"
+                "- output_path_key: Field name for output file path, default is 'output_path'\n"
+                "- json_schema: Optional JSON Schema constraint for generated results\n"
+                "- enc: Encoder for token counting, default is tiktoken's cl100k_base encoder; can also use HuggingFace's AutoTokenizer\n"
+                "- separator: Separator for chunk results, default is newline character\n"
+            )
+
+    # === Token counting ===
+    def _count_tokens(self, text: str) -> int:
+        return len(self.enc.encode(text))
+
+    # === Recursive bisection chunking ===
+    def _split_recursive(self, text: str) -> list[str]:
+        """Recursively split the text into chunks of at most max_chunk_len tokens."""
+        token_len = self._count_tokens(text)
+        if token_len <= self.max_chunk_len:
+            return [text]
+        else:
+            mid = len(text) // 2
+            left, right = text[:mid], text[mid:]
+            return self._split_recursive(left) + self._split_recursive(right)
+
+    def run(
+        self,
+        storage: DataFlowStorage,
+        input_path_key,
+        output_path_key,
+    ):
+        self.logger.info("Running ChunkedPromptedGenerator...")
+        dataframe = storage.read("dataframe")
+        self.logger.info(f"Loaded DataFrame with {len(dataframe)} rows.")
+
+        all_generated_results = []
+
+        all_llm_inputs = []
+        row_chunk_map = []  # number of chunks produced for each row
+
+        # === Collect the chunks for all rows first ===
+        for i, row in dataframe.iterrows():
+            raw_content = Path(row[input_path_key]).read_text(encoding='utf-8')
+
+            chunks = self._split_recursive(raw_content)
+            self.logger.info(f"Row {i}: split into {len(chunks)} chunks")
+
+            system_prompt = self.system_prompt + "\n"
+            llm_inputs = [system_prompt + chunk for chunk in chunks]
+            all_llm_inputs.extend(llm_inputs)
+            row_chunk_map.append(len(chunks))
+
+        # === One concurrent call for all chunks ===
+        self.logger.info(f"Total {len(all_llm_inputs)} chunks to generate")
+
+        try:
+            if self.json_schema:
+                all_responses = self.llm_serving.generate_from_input(
+                    all_llm_inputs, json_schema=self.json_schema
+                )
+            else:
+                all_responses = self.llm_serving.generate_from_input(all_llm_inputs)
+        except Exception as e:
+            self.logger.error(f"Global generation failed: {e}")
+            all_generated_results = [[] for _ in range(len(dataframe))]
+        else:
+            # === Regroup the responses by row ===
+            all_generated_results = []
+            idx = 0
+            for num_chunks in row_chunk_map:
+                if num_chunks == 0:
+                    all_generated_results.append([])
+                else:
+                    all_generated_results.append(all_responses[idx:idx + num_chunks])
+                idx += num_chunks
+
+        for (i, row), gen_results in zip(dataframe.iterrows(), all_generated_results):
+            output_path = row[input_path_key].split('.')[0] + '_llm_output.txt'
+            with open(output_path, 'w', encoding='utf-8') as f:
+                f.write(self.separator.join(gen_results))
+            dataframe.at[i, output_path_key] = output_path
+
+        output_file = storage.write(dataframe)
+        self.logger.info(f"Generation complete. Output saved to {output_file}")
+        return output_path_key
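
For context, below is a minimal standalone sketch of the recursive bisection strategy that ChunkedPromptedGenerator._split_recursive implements above. It is not part of the commit; it assumes only that tiktoken is installed, and the sample text, token budget, and function names are illustrative.

import tiktoken

enc = tiktoken.get_encoding("cl100k_base")

def count_tokens(text: str) -> int:
    return len(enc.encode(text))

def split_recursive(text: str, max_chunk_len: int) -> list[str]:
    # Same idea as _split_recursive: if the text is over budget,
    # cut it at the character midpoint and recurse on both halves.
    if count_tokens(text) <= max_chunk_len:
        return [text]
    mid = len(text) // 2
    return split_recursive(text[:mid], max_chunk_len) + split_recursive(text[mid:], max_chunk_len)

# Illustrative check: the chunks cover the original text and each one fits the budget.
sample = "DataFlow pdf2vqa " * 20000            # synthetic long input
chunks = split_recursive(sample, max_chunk_len=8000)
assert "".join(chunks) == sample                 # bisection never drops characters
assert all(count_tokens(c) <= 8000 for c in chunks)
print(len(chunks), [count_tokens(c) for c in chunks])

Note that the midpoint cut is by characters, not tokens, so a chunk boundary can land mid-word or mid-sentence; each chunk is still sent with the full system_prompt prepended, as run() does above.
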
dataflow/operators/core_text/generate/prompted_generator.py

Lines changed: 0 additions & 137 deletions

@@ -1,5 +1,4 @@
 import pandas as pd
-import tiktoken
 from dataflow.utils.registry import OPERATOR_REGISTRY
 from dataflow import get_logger
 from pathlib import Path
@@ -89,139 +88,3 @@ def run(self, storage: DataFlowStorage, input_key: str = "raw_content", output_k
         # Save the updated dataframe to the output file
         output_file = storage.write(dataframe)
         return output_key
-class ChunkedPromptedGenerator(OperatorABC):
-    """
-    Prompt-based generation operator that automatically chunks its input.
-    - Counts tokens precisely with tiktoken;
-    - Inputs longer than max_chunk_len are split by recursive bisection;
-    - Content is read from the given input file path and the generated result is saved to the given output file path;
-    - The generated result is a single string joined with separator.
-    """
-
-    def __init__(
-        self,
-        llm_serving: LLMServingABC,
-        system_prompt: str = "You are a helpful agent.",
-        json_schema: dict = None,
-        max_chunk_len: int = 128000,
-        enc = tiktoken.get_encoding("cl100k_base"),
-        seperator: str = "\n",
-    ):
-        self.logger = get_logger()
-        self.llm_serving = llm_serving
-        self.system_prompt = system_prompt
-        self.json_schema = json_schema
-        self.max_chunk_len = max_chunk_len
-        self.enc = enc
-        self.separator = seperator
-
-    @staticmethod
-    def get_desc(lang: str = "zh"):
-        if lang == "zh":
-            return (
-                "基于提示词的生成算子,支持长文本自动分chunk。"
-                "采用递归二分方式进行chunk切分,确保每段不超过max_chunk_len tokens。"
-                "从给定的输入文件路径读取内容,生成结果保存至指定输出文件路径。"
-                "多个生成结果以separator拼接成最终输出字符串。"
-                "输入参数:\n"
-                "- llm_serving:LLM服务对象,需实现LLMServingABC接口\n"
-                "- system_prompt:系统提示词,定义模型行为,默认为'You are a helpful agent.'\n"
-                "- max_chunk_len:单个chunk的最大token长度,默认为128000\n"
-                "- input_path_key:输入文件路径字段名,默认为'input_path'\n"
-                "- output_path_key:输出文件路径字段名,默认为'output_path'\n"
-                "- json_schema:可选,生成结果的JSON Schema约束\n"
-                "- enc:用于token计算的编码器,需要实现encode方法,默认为tiktoken的cl100k_base编码器,也可以使用HuggingFace 的 AutoTokenizer\n"
-                "- separator:chunk结果拼接分隔符,默认为换行符\n"
-            )
-        else:
-            return (
-                "Prompt-based generator with recursive chunk splitting. "
-                "Splits long text inputs into chunks using recursive bisection to ensure each chunk does not exceed max_chunk_len tokens. "
-                "Reads content from specified input file paths and saves generated results to designated output file paths. "
-                "Multiple generated results are joined as a string using the specified separator.\n"
-                "Input Parameters:\n"
-                "- llm_serving: LLM serving object implementing LLMServingABC interface\n"
-                "- system_prompt: System prompt to define model behavior, default is 'You are a helpful agent.'\n"
-                "- max_chunk_len: Maximum token length per chunk, default is 128000\n"
-                "- input_path_key: Field name for input file path, default is 'input_path'\n"
-                "- output_path_key: Field name for output file path, default is 'output_path'\n"
-                "- json_schema: Optional JSON Schema constraint for generated results\n"
-                "- enc: Encoder for token counting, default is tiktoken's cl100k_base encoder; can also use HuggingFace's AutoTokenizer\n"
-                "- separator: Separator for chunk results, default is newline character\n"
-            )
-
-    # === Token counting ===
-    def _count_tokens(self, text: str) -> int:
-        return len(self.enc.encode(text))
-
-    # === Recursive bisection chunking ===
-    def _split_recursive(self, text: str) -> list[str]:
-        """Recursively split the text into chunks of at most max_chunk_len tokens."""
-        token_len = self._count_tokens(text)
-        if token_len <= self.max_chunk_len:
-            return [text]
-        else:
-            mid = len(text) // 2
-            left, right = text[:mid], text[mid:]
-            return self._split_recursive(left) + self._split_recursive(right)
-
-    def run(
-        self,
-        storage: DataFlowStorage,
-        input_path_key,
-        output_path_key,
-    ):
-        self.logger.info("Running ChunkedPromptedGenerator...")
-        dataframe = storage.read("dataframe")
-        self.logger.info(f"Loaded DataFrame with {len(dataframe)} rows.")
-
-        all_generated_results = []
-
-        all_llm_inputs = []
-        row_chunk_map = []  # number of chunks produced for each row
-
-        # === Collect the chunks for all rows first ===
-        for i, row in dataframe.iterrows():
-            raw_content = Path(row[input_path_key]).read_text(encoding='utf-8')
-
-            chunks = self._split_recursive(raw_content)
-            self.logger.info(f"Row {i}: split into {len(chunks)} chunks")
-
-            system_prompt = self.system_prompt + "\n"
-            llm_inputs = [system_prompt + chunk for chunk in chunks]
-            all_llm_inputs.extend(llm_inputs)
-            row_chunk_map.append(len(chunks))
-
-        # === One concurrent call for all chunks ===
-        self.logger.info(f"Total {len(all_llm_inputs)} chunks to generate")
-
-        try:
-            if self.json_schema:
-                all_responses = self.llm_serving.generate_from_input(
-                    all_llm_inputs, json_schema=self.json_schema
-                )
-            else:
-                all_responses = self.llm_serving.generate_from_input(all_llm_inputs)
-        except Exception as e:
-            self.logger.error(f"Global generation failed: {e}")
-            all_generated_results = [[] for _ in range(len(dataframe))]
-        else:
-            # === Regroup the responses by row ===
-            all_generated_results = []
-            idx = 0
-            for num_chunks in row_chunk_map:
-                if num_chunks == 0:
-                    all_generated_results.append([])
-                else:
-                    all_generated_results.append(all_responses[idx:idx + num_chunks])
-                idx += num_chunks
-
-        for (i, row), gen_results in zip(dataframe.iterrows(), all_generated_results):
-            output_path = row[input_path_key].split('.')[0] + '_llm_output.txt'
-            with open(output_path, 'w', encoding='utf-8') as f:
-                f.write(self.separator.join(gen_results))
-            dataframe.at[i, output_path_key] = output_path
-
-        output_file = storage.write(dataframe)
-        self.logger.info(f"Generation complete. Output saved to {output_file}")
-        return output_path_key

dataflow/statics/pipelines/api_pipelines/pdf_vqa_extract_pipeline.py

Lines changed: 5 additions & 0 deletions
@@ -38,6 +38,11 @@ def __init__(self):
         self.llm_output_answer_parser = LLMOutputParser(mode="answer", output_dir="./cache", intermediate_dir="intermediate")
         self.qa_merger = QA_Merger(output_dir="./cache", strict_title_match=False)
     def forward(self):
+        # Current flow: MinerU processes the question PDFs -> MinerU processes the answer PDFs -> format question text -> format answer text -> question text to LLM -> answer text to LLM -> parse question output -> parse answer output -> merge QA pairs.
+        # A QA pair may come from the same PDF or from different PDFs, and dataflow does not yet support branching, so the question and answer PDFs each go through a full pass here,
+        # which means even a shared PDF is processed twice before the QA pairs are merged at the end.
+        # We will revisit how to optimize this flow to avoid reprocessing the same PDF and improve performance.
+
         self.mineru_executor.run(
             storage=self.storage.step(),
             input_key="question_pdf_path",