4 changes: 2 additions & 2 deletions dataflow/example/PDF2VQAPipeline/vqa_extract_test.jsonl
@@ -1,2 +1,2 @@
{"question_pdf_path": "./example_data/PDF2VQAPipeline/questionextract_test.pdf", "answer_pdf_path": "./example_data/PDF2VQAPipeline/questionextract_test.pdf", "subject": "math", "output_dir": "../vqa_output_test/math1"}
{"question_pdf_path": "./example_data/PDF2VQAPipeline/math_question.pdf", "answer_pdf_path": "./example_data/PDF2VQAPipeline/math_answer.pdf", "subject": "math", "output_dir": "../vqa_output_test/math2"}
{"question_pdf_path": "./example_data/PDF2VQAPipeline/questionextract_test.pdf", "answer_pdf_path": "./example_data/PDF2VQAPipeline/questionextract_test.pdf", "name": "math1"}
{"question_pdf_path": "./example_data/PDF2VQAPipeline/math_question.pdf", "answer_pdf_path": "./example_data/PDF2VQAPipeline/math_answer.pdf", "name": "math2"}
1 change: 1 addition & 0 deletions dataflow/operators/core_text/__init__.py
@@ -2,6 +2,7 @@

if TYPE_CHECKING:
from .generate.prompted_generator import PromptedGenerator
from .generate.chunked_prompted_generator import ChunkedPromptedGenerator
from .generate.format_str_prompted_generator import FormatStrPromptedGenerator
from .generate.random_domain_knowledge_row_generator import RandomDomainKnowledgeRowGenerator
from .generate.text2qa_generator import Text2QAGenerator
147 changes: 147 additions & 0 deletions dataflow/operators/core_text/generate/chunked_prompted_generator.py
@@ -0,0 +1,147 @@
import pandas as pd
import tiktoken
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow import get_logger
from pathlib import Path

from dataflow.utils.storage import DataFlowStorage
from dataflow.core import OperatorABC
from dataflow.core import LLMServingABC

@OPERATOR_REGISTRY.register()
class ChunkedPromptedGenerator(OperatorABC):
"""
Prompt-based generation operator that automatically chunks its input.
- Counts tokens with tiktoken or a HuggingFace AutoTokenizer;
- If the input exceeds max_chunk_len, it is split by recursive bisection;
- Reads content from the given input file path and saves the generated result to the given output file path;
- The generated result is a single string joined with separator.
"""

def __init__(
self,
llm_serving: LLMServingABC,
system_prompt: str = "You are a helpful agent.",
json_schema: dict = None,
max_chunk_len: int = 128000,
enc = tiktoken.get_encoding("cl100k_base"), # any tokenizer that supports len(enc.encode(text)) works, e.g. tiktoken or a HuggingFace AutoTokenizer
separator: str = "\n",
):
self.logger = get_logger()
self.llm_serving = llm_serving
self.system_prompt = system_prompt
self.json_schema = json_schema
self.max_chunk_len = max_chunk_len
self.enc = enc
self.separator = separator

@staticmethod
def get_desc(lang: str = "zh"):
if lang == "zh":
return (
"基于提示词的生成算子,支持长文本自动分chunk。"
"采用递归二分方式进行chunk切分,确保每段不超过max_chunk_len tokens。"
"从给定的输入文件路径读取内容,生成结果保存至指定输出文件路径。"
"多个生成结果以separator拼接成最终输出字符串。"
"输入参数:\n"
"- llm_serving:LLM服务对象,需实现LLMServingABC接口\n"
"- system_prompt:系统提示词,定义模型行为,默认为'You are a helpful agent.'\n"
"- max_chunk_len:单个chunk的最大token长度,默认为128000\n"
"- input_path_key:输入文件路径字段名,默认为'input_path'\n"
"- output_path_key:输出文件路径字段名,默认为'output_path'\n"
"- json_schema:可选,生成结果的JSON Schema约束\n"
"- enc:用于token计算的编码器,需要实现encode方法,默认为tiktoken的cl100k_base编码器,也可以使用HuggingFace 的 AutoTokenizer\n"
"- separator:chunk结果拼接分隔符,默认为换行符\n"
)
else:
return (
"Prompt-based generator with recursive chunk splitting."
"Splits long text inputs into chunks using recursive bisection to ensure each chunk does not exceed max_chunk_len tokens."
"Reads content from specified input file paths and saves generated results to designated output file paths."
"Multiple generated results are joined as a string using the specified separator."
"Input Parameters:\n"
"- llm_serving: LLM serving object implementing LLMServingABC interface\n"
"- system_prompt: System prompt to define model behavior, default is 'You are a helpful agent.'\n"
"- max_chunk_len: Maximum token length per chunk, default is 128000\n"
"- input_path_key: Field name for input file path, default is 'input_path'\n"
"- output_path_key: Field name for output file path, default is 'output_path'\n"
"- json_schema: Optional JSON Schema constraint for generated results\n"
"- enc: Encoder for token counting, default is tiktoken's cl100k_base encoder; can also use HuggingFace's AutoTokenizer\n"
"- separator: Separator for chunk results, default is newline character\n"
)

# === Token counting ===
def _count_tokens(self, text: str) -> int:
return len(self.enc.encode(text))

# === Recursive bisection chunking ===
def _split_recursive(self, text: str) -> list[str]:
"""Recursively split the text into chunks of at most max_chunk_len tokens."""
token_len = self._count_tokens(text)
if token_len <= self.max_chunk_len:
return [text]
else:
mid = len(text) // 2
left, right = text[:mid], text[mid:]
return self._split_recursive(left) + self._split_recursive(right)

def run(
self,
storage: DataFlowStorage,
input_path_key,
output_path_key,
):
self.logger.info("Running ChunkedPromptedGenerator...")
dataframe = storage.read("dataframe")
self.logger.info(f"Loaded DataFrame with {len(dataframe)} rows.")

all_generated_results = []

all_llm_inputs = []
row_chunk_map = [] # number of chunks produced for each row

# === First collect all chunks across rows ===
for i, row in dataframe.iterrows():
raw_content = Path(row[input_path_key]).read_text(encoding='utf-8')

chunks = self._split_recursive(raw_content)
self.logger.info(f"Row {i}: split into {len(chunks)} chunks")

system_prompt = self.system_prompt + "\n"
llm_inputs = [system_prompt + chunk for chunk in chunks]
all_llm_inputs.extend(llm_inputs)
row_chunk_map.append(len(chunks))

# === Single batched LLM call for all chunks ===
self.logger.info(f"Total {len(all_llm_inputs)} chunks to generate")

try:
if self.json_schema:
all_responses = self.llm_serving.generate_from_input(
all_llm_inputs, json_schema=self.json_schema
)
else:
all_responses = self.llm_serving.generate_from_input(all_llm_inputs)
except Exception as e:
self.logger.error(f"Global generation failed: {e}")
all_generated_results = [[] for _ in range(len(dataframe))]
else:
# === Re-partition responses back to their rows ===
all_generated_results = []
idx = 0
for num_chunks in row_chunk_map:
if num_chunks == 0:
all_generated_results.append([])
else:
all_generated_results.append(all_responses[idx:idx + num_chunks])
idx += num_chunks

for (i, row), gen_results in zip(dataframe.iterrows(), all_generated_results):
output_path = str(Path(row[input_path_key]).with_suffix('')) + '_llm_output.txt'  # with_suffix('') strips only the file extension, so relative paths like './data/x.pdf' are handled correctly
with open(output_path, 'w', encoding='utf-8') as f:
f.write(self.separator.join(gen_results))
dataframe.at[i, output_path_key] = output_path

output_file = storage.write(dataframe)
self.logger.info(f"Generation complete. Output saved to {output_file}")
return output_path_key
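For reference, here is a minimal standalone sketch of the recursive-bisection splitting this operator relies on, run directly against tiktoken's cl100k_base encoding outside of DataFlow; the tiny max_chunk_len and the sample string are only for illustration:

import tiktoken

# A text is returned whole if it fits within max_chunk_len tokens;
# otherwise it is cut at the character midpoint and both halves are
# split again until every piece fits.
enc = tiktoken.get_encoding("cl100k_base")

def split_recursive(text: str, max_chunk_len: int = 20) -> list[str]:
    if len(enc.encode(text)) <= max_chunk_len:
        return [text]
    mid = len(text) // 2
    return split_recursive(text[:mid], max_chunk_len) + split_recursive(text[mid:], max_chunk_len)

sample = "DataFlow operators read a dataframe from storage, call the LLM serving backend, and write the results back. " * 8
chunks = split_recursive(sample)
print(len(chunks), [len(enc.encode(c)) for c in chunks])  # every count is <= 20

Bisection at the character midpoint can cut a word or sentence in half; the trade-off is a very simple guarantee that every chunk stays within the token budget.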
@@ -1,6 +1,7 @@
import pandas as pd
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow import get_logger
from pathlib import Path

from dataflow.utils.storage import DataFlowStorage
from dataflow.core import OperatorABC
4 changes: 3 additions & 1 deletion dataflow/operators/pdf2vqa/__init__.py
@@ -1,7 +1,9 @@
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from .generate.vqa_extractor import VQAExtractor
from .generate.mineru_to_llm_input_operator import MinerU2LLMInputOperator
from .generate.llm_output_parser import LLMOutputParser
from .generate.qa_merger import QA_Merger


else:
141 changes: 141 additions & 0 deletions dataflow/operators/pdf2vqa/generate/llm_output_parser.py
Contributor comment: This operator is doing format conversion. If it is kept as an operator, it should also follow our operator naming convention, e.g. name the file mineru_to_llm_formatter and give the class the same name in CamelCase.

@@ -0,0 +1,141 @@
import os
import json
import re
import shutil
from pathlib import Path
from typing import Literal
from dataflow.core import OperatorABC
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow.utils.storage import DataFlowStorage
from dataflow import get_logger

@OPERATOR_REGISTRY.register()
class LLMOutputParser(OperatorABC):
def __init__(self,
mode: Literal['question', 'answer'],
output_dir,
intermediate_dir: str = "intermediate",
):
self.logger = get_logger()
self.mode = mode
self.output_dir = output_dir
self.intermediate_dir = intermediate_dir

@staticmethod
def get_desc(lang: str = "zh") -> str:
if lang == 'zh':
return (
"LLM输出解析算子。"
"将LLM生成的包含题目和答案ID的响应文本,"
"转换为结构化的QA列表,并复制相关图片到输出目录。"
)
else:
return (
"LLM output parsing operator."
"Converts LLM-generated response text containing question and answer IDs"
"into a structured QA list and copies related images to the output directory."
)

def _id_to_text(self, input_ids, input_json, image_prefix="images"):
texts = []
id_list = input_ids.replace(' ', '').split(',')
for id in id_list:
try:
int(id)
except:
continue
if int(id) < len(input_json):
try:
item = input_json[int(id)]
except:
continue
if 'text' in item:
texts.append(item['text'])
elif 'img_path' in item:
try:
img_path = item.get('img_path', '')
img_name = os.path.basename(img_path)
new_path = f"{image_prefix}/{img_name}"
texts.append(f"![{' '.join(item.get('image_caption','image'))}]({new_path})")
except:
pass
elif item.get('type','') == 'list':
if item['sub_type'] == 'text':
try:
texts.append(input_json[int(id)]['list_items'].pop(0))
except:
pass
return '\n'.join(texts)

def _convert_response(self, input_response, input_json_path, image_prefix="images"):
qa_list = []
with open(input_json_path, 'r') as infile:
input_json = list(json.load(infile))
# Extract the chapter title
for chapter_block in re.findall(r'<chapter>(.*?)</chapter>', input_response, flags=re.DOTALL):
title = re.search(r'<title>(.*?)</title>', chapter_block, flags=re.DOTALL)
if title:
chapter_title = self._id_to_text(title.group(1).strip(), input_json, image_prefix)
else:
chapter_title = ""
# Find all qa_pair blocks
for pair in re.findall(r'<qa_pair>(.*?)</qa_pair>', chapter_block, flags=re.DOTALL):
# Extract the question part
q_match = re.search(r'<question>(.*?)</question>', pair, flags=re.DOTALL)
# Extract the answer part
a_match = re.search(r'<answer>(.*?)</answer>', pair, flags=re.DOTALL)
# Extract the solution part
s_match = re.search(r'<solution>(.*?)</solution>', pair, flags=re.DOTALL)
# Extract the label
label_match = re.search(r'<label>(.*?)</label>', pair, flags=re.DOTALL)
if not ((q_match and label_match) or (a_match and label_match) or (s_match and label_match)):
continue
label = label_match.group(1).strip()
qa_list.append({
'question': self._id_to_text(q_match.group(1).strip(), input_json, image_prefix) if q_match else "",
'answer': a_match.group(1).strip() if a_match else "",
'solution': self._id_to_text(s_match.group(1).strip(), input_json, image_prefix) if s_match else "",
'label': label,
'chapter_title': chapter_title
})
return qa_list

def run(self, storage: DataFlowStorage,
input_response_path_key,
input_converted_layout_path_key,
input_name_key,
output_qalist_path_key,
):
dataframe = storage.read("dataframe")

# Convert each LLM response
for idx, row in dataframe.iterrows():
converted_json_path = row[input_converted_layout_path_key]
response = Path(row[input_response_path_key]).read_text(encoding='utf-8')
name = row[input_name_key]

image_prefix = os.path.join(name, f"{self.mode}_images")
qa_list = self._convert_response(response, converted_json_path, image_prefix)
output_qalist_path = os.path.join(self.output_dir, name, f"extracted_{self.mode}s.jsonl")
os.makedirs(os.path.dirname(output_qalist_path), exist_ok=True)
with open(output_qalist_path, 'w', encoding='utf-8') as outfile:
for qa in qa_list:
json.dump(qa, outfile, ensure_ascii=False)
outfile.write('\n')

# Copy the associated images
src_dir = os.path.join(self.intermediate_dir, 'mineru', Path(converted_json_path).stem).replace('_content_list_converted','')
src_images = os.path.join(src_dir, 'vlm', 'images')
dst_images = os.path.join(self.output_dir, image_prefix)

try:
if os.path.exists(src_images):
shutil.copytree(src_images, dst_images)
else:
self.logger.warning(f"Source images dir does not exist: {src_images}")
except Exception as e:
self.logger.warning(f"Failed to copy images from {src_images} to {dst_images}: {e}")

dataframe.loc[idx, output_qalist_path_key] = output_qalist_path

storage.write(dataframe)
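For orientation, a small self-contained sketch of the tag layout _convert_response expects and the regex extraction it performs; the sample response and its numeric IDs are made up, and in the real operator _id_to_text resolves those IDs against the converted MinerU content list:

import re

# Chapters wrap qa_pairs; each qa_pair carries question/answer/solution
# references plus a label. Here the references are left as raw strings
# instead of being resolved to text.
response = """
<chapter><title>3</title>
<qa_pair><question>4,5</question><answer>B</answer><solution>6</solution><label>single-choice</label></qa_pair>
</chapter>
"""

for chapter in re.findall(r'<chapter>(.*?)</chapter>', response, flags=re.DOTALL):
    title = re.search(r'<title>(.*?)</title>', chapter, flags=re.DOTALL)
    for pair in re.findall(r'<qa_pair>(.*?)</qa_pair>', chapter, flags=re.DOTALL):
        q = re.search(r'<question>(.*?)</question>', pair, flags=re.DOTALL)
        a = re.search(r'<answer>(.*?)</answer>', pair, flags=re.DOTALL)
        label = re.search(r'<label>(.*?)</label>', pair, flags=re.DOTALL)
        print({
            "title_ids": title.group(1).strip() if title else "",
            "question_ids": q.group(1).strip() if q else "",
            "answer": a.group(1).strip() if a else "",
            "label": label.group(1).strip() if label else "",
        })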
@@ -0,0 +1,69 @@
import json
from dataflow.core import OperatorABC
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow.utils.storage import DataFlowStorage

@OPERATOR_REGISTRY.register()
class MinerU2LLMInputOperator(OperatorABC):
def __init__(self):
pass

@staticmethod
def get_desc(lang: str = "zh") -> str:
if lang == 'zh':
return (
"MinerU格式转换为LLM输入格式算子。"
"将MinerU生成的内容列表JSON文件转换为适合LLM处理的格式,"
"包括展平列表项并重新编号。"
)
else:
return (
"Convert MinerU format to LLM input format operator."
"Transforms the content list JSON file generated by MinerU into a format suitable for LLM processing,"
"including flattening list items and re-indexing."
)

def _convert_json(self, input_file, output_file):
with open(input_file, 'r') as infile:
data = list(json.load(infile))

new_data = []
id = 0
for item in data:
item['id'] = id
item.pop('bbox', None)
item.pop('page_idx', None)
if item.get('type','') == 'list':
if item['sub_type'] == 'text':
for idx, list_item in enumerate(item.get('list_items', [])):
new_item = {
'type': 'text',
'text': list_item,
'id': id + idx,
}
new_data.append(new_item)
id += len(item.get('list_items', []))
else:
new_data.append(item)
id += 1

with open(output_file, 'w', encoding='utf-8') as outfile:
json.dump(new_data, outfile, ensure_ascii=False)

def run(self, storage: DataFlowStorage,
input_markdown_path_key,
output_converted_layout_key,
):
dataframe = storage.read("dataframe")

for index, row in dataframe.iterrows():
input_json_path = row[input_markdown_path_key].replace('.md', '_content_list.json')
converted_path = input_json_path.replace('.json', '_converted.json')
self._convert_json(input_json_path, converted_path)
dataframe.at[index, output_converted_layout_key] = converted_path

with open(converted_path, 'r') as infile:
data = json.load(infile)
assert isinstance(data, list), f"Expected list, got {type(data)} for {input_json_path}"

storage.write(dataframe)
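To make the conversion concrete, a small in-memory sketch of the flattening and re-indexing that _convert_json performs; the sample content-list items are invented, but the field names (type, sub_type, list_items, bbox, page_idx, id) follow the code above:

import json

# A MinerU "list" block with sub_type "text" is flattened into one text
# item per list entry, every item gets a sequential "id", and the
# bbox/page_idx layout fields are dropped.
content_list = [
    {"type": "text", "text": "Chapter 1", "bbox": [0, 0, 1, 1], "page_idx": 0},
    {"type": "list", "sub_type": "text", "list_items": ["1. First question", "2. Second question"]},
    {"type": "image", "img_path": "images/fig1.png", "image_caption": ["Figure 1"]},
]

converted = []
next_id = 0
for item in content_list:
    item.pop("bbox", None)
    item.pop("page_idx", None)
    if item.get("type") == "list" and item.get("sub_type") == "text":
        for offset, entry in enumerate(item.get("list_items", [])):
            converted.append({"type": "text", "text": entry, "id": next_id + offset})
        next_id += len(item.get("list_items", []))
    else:
        item["id"] = next_id
        converted.append(item)
        next_id += 1

print(json.dumps(converted, ensure_ascii=False, indent=2))  # ids 0..3, list entries promoted to text items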