diff --git a/dataflow/example/PDF2VQAPipeline/vqa_extract_test.jsonl b/dataflow/example/PDF2VQAPipeline/vqa_extract_test.jsonl
index 40866a8d..d535d8ae 100644
--- a/dataflow/example/PDF2VQAPipeline/vqa_extract_test.jsonl
+++ b/dataflow/example/PDF2VQAPipeline/vqa_extract_test.jsonl
@@ -1,2 +1,2 @@
-{"question_pdf_path": "./example_data/PDF2VQAPipeline/questionextract_test.pdf", "answer_pdf_path": "./example_data/PDF2VQAPipeline/questionextract_test.pdf", "subject": "math", "output_dir": "../vqa_output_test/math1"}
-{"question_pdf_path": "./example_data/PDF2VQAPipeline/math_question.pdf", "answer_pdf_path": "./example_data/PDF2VQAPipeline/math_answer.pdf", "subject": "math", "output_dir": "../vqa_output_test/math2"}
+{"question_pdf_path": "./example_data/PDF2VQAPipeline/questionextract_test.pdf", "answer_pdf_path": "./example_data/PDF2VQAPipeline/questionextract_test.pdf", "name": "math1"}
+{"question_pdf_path": "./example_data/PDF2VQAPipeline/math_question.pdf", "answer_pdf_path": "./example_data/PDF2VQAPipeline/math_answer.pdf", "name": "math2"}
diff --git a/dataflow/operators/core_text/__init__.py b/dataflow/operators/core_text/__init__.py
index 592469c9..c9429092 100644
--- a/dataflow/operators/core_text/__init__.py
+++ b/dataflow/operators/core_text/__init__.py
@@ -2,6 +2,7 @@
if TYPE_CHECKING:
from .generate.prompted_generator import PromptedGenerator
+ from .generate.chunked_prompted_generator import ChunkedPromptedGenerator
from .generate.format_str_prompted_generator import FormatStrPromptedGenerator
from .generate.random_domain_knowledge_row_generator import RandomDomainKnowledgeRowGenerator
from .generate.text2qa_generator import Text2QAGenerator
diff --git a/dataflow/operators/core_text/generate/chunked_prompted_generator.py b/dataflow/operators/core_text/generate/chunked_prompted_generator.py
new file mode 100644
index 00000000..d54c5058
--- /dev/null
+++ b/dataflow/operators/core_text/generate/chunked_prompted_generator.py
@@ -0,0 +1,147 @@
+import pandas as pd
+import tiktoken
+from dataflow.utils.registry import OPERATOR_REGISTRY
+from dataflow import get_logger
+from pathlib import Path
+
+from dataflow.utils.storage import DataFlowStorage
+from dataflow.core import OperatorABC
+from dataflow.core import LLMServingABC
+
+@OPERATOR_REGISTRY.register()
+class ChunkedPromptedGenerator(OperatorABC):
+ """
+ 基于Prompt的生成算子,支持自动chunk输入。
+ - 使用tiktoken或HuggingFace的AutoTokenizer计算token数量;
+ - 若输入超过max_chunk_len,采用递归二分法切分;
+ - 从指定输入文件路径读取内容,生成结果保存至指定输出文件路径;
+ - 生成结果是以separator拼接的字符串。
+ """
+
+ def __init__(
+ self,
+ llm_serving: LLMServingABC,
+ system_prompt: str = "You are a helpful agent.",
+ json_schema: dict = None,
+ max_chunk_len: int = 128000,
+ enc = tiktoken.get_encoding("cl100k_base"), # 支持len(enc.encode(text))的tokenizer都可以,比如tiktoken或HuggingFace的AutoTokenizer
+ seperator: str = "\n",
+ ):
+ self.logger = get_logger()
+ self.llm_serving = llm_serving
+ self.system_prompt = system_prompt
+ self.json_schema = json_schema
+ self.max_chunk_len = max_chunk_len
+ self.enc = enc
+ self.separator = seperator
+
+ @staticmethod
+ def get_desc(lang: str = "zh"):
+ if lang == "zh":
+ return (
+ "基于提示词的生成算子,支持长文本自动分chunk。"
+ "采用递归二分方式进行chunk切分,确保每段不超过max_chunk_len tokens。"
+ "从给定的输入文件路径读取内容,生成结果保存至指定输出文件路径。"
+ "多个生成结果以separator拼接成最终输出字符串。"
+ "输入参数:\n"
+ "- llm_serving:LLM服务对象,需实现LLMServingABC接口\n"
+ "- system_prompt:系统提示词,定义模型行为,默认为'You are a helpful agent.'\n"
+ "- max_chunk_len:单个chunk的最大token长度,默认为128000\n"
+ "- input_path_key:输入文件路径字段名,默认为'input_path'\n"
+ "- output_path_key:输出文件路径字段名,默认为'output_path'\n"
+ "- json_schema:可选,生成结果的JSON Schema约束\n"
+ "- enc:用于token计算的编码器,需要实现encode方法,默认为tiktoken的cl100k_base编码器,也可以使用HuggingFace 的 AutoTokenizer\n"
+ "- separator:chunk结果拼接分隔符,默认为换行符\n"
+ )
+ else:
+ return (
+ "Prompt-based generator with recursive chunk splitting."
+ "Splits long text inputs into chunks using recursive bisection to ensure each chunk does not exceed max_chunk_len tokens."
+ "Reads content from specified input file paths and saves generated results to designated output file paths."
+ "Multiple generated results are joined as a string using the specified separator."
+ "Input Parameters:\n"
+ "- llm_serving: LLM serving object implementing LLMServingABC interface\n"
+ "- system_prompt: System prompt to define model behavior, default is 'You are a helpful agent.'\n"
+ "- max_chunk_len: Maximum token length per chunk, default is 128000\n"
+ "- input_path_key: Field name for input file path, default is 'input_path'\n"
+ "- output_path_key: Field name for output file path, default is 'output_path'\n"
+ "- json_schema: Optional JSON Schema constraint for generated results\n"
+ "- enc: Encoder for token counting, default is tiktoken's cl100k_base encoder; can also use HuggingFace's AutoTokenizer\n"
+ "- separator: Separator for chunk results, default is newline character\n"
+ )
+
+ # === token计算 ===
+ def _count_tokens(self, text: str) -> int:
+ return len(self.enc.encode(text))
+
+ # === 递归二分分chunk ===
+ def _split_recursive(self, text: str) -> list[str]:
+ """递归地将文本拆分为不超过max_chunk_len的多个chunk"""
+ token_len = self._count_tokens(text)
+ if token_len <= self.max_chunk_len:
+ return [text]
+ else:
+ mid = len(text) // 2
+ left, right = text[:mid], text[mid:]
+ return self._split_recursive(left) + self._split_recursive(right)
+
+ def run(
+ self,
+ storage: DataFlowStorage,
+ input_path_key,
+ output_path_key,
+ ):
+ self.logger.info("Running ChunkedPromptedGenerator...")
+ dataframe = storage.read("dataframe")
+ self.logger.info(f"Loaded DataFrame with {len(dataframe)} rows.")
+
+ all_generated_results = []
+
+ all_llm_inputs = []
+ row_chunk_map = [] # 记录每个row对应的chunk数量
+
+ # === 先收集所有chunk ===
+ for i, row in dataframe.iterrows():
+ raw_content = Path(row[input_path_key]).read_text(encoding='utf-8')
+
+ chunks = self._split_recursive(raw_content)
+ self.logger.info(f"Row {i}: split into {len(chunks)} chunks")
+
+ system_prompt = self.system_prompt + "\n"
+ llm_inputs = [system_prompt + chunk for chunk in chunks]
+ all_llm_inputs.extend(llm_inputs)
+ row_chunk_map.append(len(chunks))
+
+ # === 一次性并发调用 ===
+ self.logger.info(f"Total {len(all_llm_inputs)} chunks to generate")
+
+ try:
+ if self.json_schema:
+ all_responses = self.llm_serving.generate_from_input(
+ all_llm_inputs, json_schema=self.json_schema
+ )
+ else:
+ all_responses = self.llm_serving.generate_from_input(all_llm_inputs)
+ except Exception as e:
+ self.logger.error(f"Global generation failed: {e}")
+ all_generated_results = [[] for _ in range(len(dataframe))]
+ else:
+ # === 按row重新划分responses ===
+ all_generated_results = []
+ idx = 0
+ for num_chunks in row_chunk_map:
+ if num_chunks == 0:
+ all_generated_results.append([])
+ else:
+ all_generated_results.append(all_responses[idx:idx + num_chunks])
+ idx += num_chunks
+
+ for (i, row), gen_results in zip(dataframe.iterrows(), all_generated_results):
+ output_path = row[input_path_key].split('.')[0] + '_llm_output.txt'
+ with open(output_path, 'w', encoding='utf-8') as f:
+ f.write(self.separator.join(gen_results))
+ dataframe.at[i, output_path_key] = output_path
+
+ output_file = storage.write(dataframe)
+ self.logger.info(f"Generation complete. Output saved to {output_file}")
+ return output_path_key
\ No newline at end of file
diff --git a/dataflow/operators/core_text/generate/prompted_generator.py b/dataflow/operators/core_text/generate/prompted_generator.py
index cc5ece8f..9cf26a80 100644
--- a/dataflow/operators/core_text/generate/prompted_generator.py
+++ b/dataflow/operators/core_text/generate/prompted_generator.py
@@ -1,6 +1,7 @@
import pandas as pd
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow import get_logger
+from pathlib import Path
from dataflow.utils.storage import DataFlowStorage
from dataflow.core import OperatorABC
diff --git a/dataflow/operators/pdf2vqa/__init__.py b/dataflow/operators/pdf2vqa/__init__.py
index 8236e494..a385917b 100644
--- a/dataflow/operators/pdf2vqa/__init__.py
+++ b/dataflow/operators/pdf2vqa/__init__.py
@@ -1,7 +1,9 @@
from typing import TYPE_CHECKING
if TYPE_CHECKING:
- from .generate.vqa_extractor import VQAExtractor
+ from .generate.mineru_to_llm_input_operator import MinerU2LLMInputOperator
+ from .generate.llm_output_parser import LLMOutputParser
+ from .generate.qa_merger import QA_Merger
else:
diff --git a/dataflow/operators/pdf2vqa/generate/llm_output_parser.py b/dataflow/operators/pdf2vqa/generate/llm_output_parser.py
new file mode 100644
index 00000000..7565db78
--- /dev/null
+++ b/dataflow/operators/pdf2vqa/generate/llm_output_parser.py
@@ -0,0 +1,141 @@
+import os
+import json
+import re
+import shutil
+from pathlib import Path
+from typing import Literal
+from dataflow.core import OperatorABC
+from dataflow.utils.registry import OPERATOR_REGISTRY
+from dataflow.utils.storage import DataFlowStorage
+from dataflow import get_logger
+
+@OPERATOR_REGISTRY.register()
+class LLMOutputParser(OperatorABC):
+ def __init__(self,
+ mode: Literal['question', 'answer'],
+ output_dir,
+ intermediate_dir: str = "intermediate",
+ ):
+ self.logger = get_logger()
+ self.mode = mode
+ self.output_dir = output_dir
+ self.intermediate_dir = intermediate_dir
+
+ @staticmethod
+ def get_desc(lang: str = "zh") -> str:
+ if lang == 'zh':
+ return (
+ "LLM输出解析算子。"
+ "将LLM生成的包含题目和答案ID的响应文本,"
+ "转换为结构化的QA列表,并复制相关图片到输出目录。"
+ )
+ else:
+ return (
+ "LLM output parsing operator."
+ "Converts LLM-generated response text containing question and answer IDs"
+ "into a structured QA list and copies related images to the output directory."
+ )
+
+ def _id_to_text(self, input_ids, input_json, image_prefix="images"):
+ texts = []
+ id_list = input_ids.replace(' ', '').split(',')
+ for id in id_list:
+ try:
+ int(id)
+ except:
+ continue
+ if int(id) < len(input_json):
+ try:
+ item = input_json[int(id)]
+ except:
+ continue
+ if 'text' in item:
+ texts.append(item['text'])
+ elif 'img_path' in item:
+ try:
+ img_path = item.get('img_path', '')
+ img_name = os.path.basename(img_path)
+ new_path = f"{image_prefix}/{img_name}"
+ texts.append(f"")
+ except:
+ pass
+ elif item.get('type','') == 'list':
+ if item['sub_type'] == 'text':
+ try:
+ texts.append(input_json[int(id)]['list_items'].pop(0))
+ except:
+ pass
+ return '\n'.join(texts)
+
+ def _convert_response(self, input_response, input_json_path, image_prefix="images"):
+ qa_list = []
+ with open(input_json_path, 'r') as infile:
+ input_json = list(json.load(infile))
+ # 提取title
+ for chapter_block in re.findall(r'(.*?)', input_response, flags=re.DOTALL):
+ title = re.search(r'
(.*?)', chapter_block, flags=re.DOTALL)
+ if title:
+ chapter_title = self._id_to_text(title.group(1).strip(), input_json, image_prefix)
+ else:
+ chapter_title = ""
+ # 找出所有 qa_pair 块
+ for pair in re.findall(r'(.*?)', chapter_block, flags=re.DOTALL):
+ # 提取 question 部分
+ q_match = re.search(r'(.*?)', pair, flags=re.DOTALL)
+ # 提取 answer 部分
+ a_match = re.search(r'(.*?)', pair, flags=re.DOTALL)
+ # 提取solution部分
+ s_match = re.search(r'(.*?)', pair, flags=re.DOTALL)
+ # 提取label
+ label_match = re.search(r'', pair, flags=re.DOTALL)
+ if not ((q_match and label_match) or (a_match and label_match) or (s_match and label_match)):
+ continue
+ label = label_match.group(1).strip()
+ qa_list.append({
+ 'question': self._id_to_text(q_match.group(1).strip(), input_json, image_prefix) if q_match else "",
+ 'answer': a_match.group(1).strip() if a_match else "",
+ 'solution': self._id_to_text(s_match.group(1).strip(), input_json, image_prefix) if s_match else "",
+ 'label': label,
+ 'chapter_title': chapter_title
+ })
+ return qa_list
+
+ def run(self, storage: DataFlowStorage,
+ input_response_path_key,
+ input_converted_layout_path_key,
+ input_name_key,
+ output_qalist_path_key,
+ ):
+ dataframe = storage.read("dataframe")
+
+ # Response 转换
+ for idx, row in dataframe.iterrows():
+ converted_json_path = row[input_converted_layout_path_key]
+ response = Path(row[input_response_path_key]).read_text(encoding='utf-8')
+ name = row[input_name_key]
+
+ image_prefix = os.path.join(name, f"{self.mode}_images")
+ qa_list = self._convert_response(response, converted_json_path, image_prefix)
+ output_qalist_path = os.path.join(self.output_dir, name, f"extracted_{self.mode}s.jsonl")
+ os.makedirs(os.path.dirname(output_qalist_path), exist_ok=True)
+ with open(output_qalist_path, 'w') as outfile:
+ for qa in qa_list:
+ json.dump(qa, outfile, ensure_ascii=False)
+ outfile.write('\n')
+
+ # 复制图片
+ src_dir = os.path.join(self.intermediate_dir, 'mineru', Path(converted_json_path).stem).replace('_content_list_converted','')
+ src_images = os.path.join(src_dir, 'vlm', 'images')
+ dst_images = os.path.join(self.output_dir, image_prefix)
+
+ try:
+ if os.path.exists(src_images):
+ shutil.copytree(src_images, dst_images)
+ else:
+ self.logger.warning(f"Source images dir does not exist: {src_images}")
+ except Exception as e:
+ self.logger.warning(f"Failed to copy images from {src_images} to {dst_images}: {e}")
+
+ dataframe.loc[idx, output_qalist_path_key] = output_qalist_path
+
+ storage.write(dataframe)
\ No newline at end of file
diff --git a/dataflow/operators/pdf2vqa/generate/mineru_to_llm_input_operator.py b/dataflow/operators/pdf2vqa/generate/mineru_to_llm_input_operator.py
new file mode 100644
index 00000000..2a51d928
--- /dev/null
+++ b/dataflow/operators/pdf2vqa/generate/mineru_to_llm_input_operator.py
@@ -0,0 +1,69 @@
+import json
+from dataflow.core import OperatorABC
+from dataflow.utils.registry import OPERATOR_REGISTRY
+from dataflow.utils.storage import DataFlowStorage
+
+@OPERATOR_REGISTRY.register()
+class MinerU2LLMInputOperator(OperatorABC):
+ def __init__(self):
+ pass
+
+ @staticmethod
+ def get_desc(lang: str = "zh") -> str:
+ if lang == 'zh':
+ return (
+ "MinerU格式转换为LLM输入格式算子。"
+ "将MinerU生成的内容列表JSON文件转换为适合LLM处理的格式,"
+ "包括展平列表项并重新编号。"
+ )
+ else:
+ return (
+ "Convert MinerU format to LLM input format operator."
+ "Transforms the content list JSON file generated by MinerU into a format suitable for LLM processing,"
+ "including flattening list items and re-indexing."
+ )
+
+ def _convert_json(self, input_file, output_file):
+ with open(input_file, 'r') as infile:
+ data = list(json.load(infile))
+
+ new_data = []
+ id = 0
+ for item in data:
+ item['id'] = id
+ item.pop('bbox', None)
+ item.pop('page_idx', None)
+ if item.get('type','') == 'list':
+ if item['sub_type'] == 'text':
+ for idx, list_item in enumerate(item.get('list_items', [])):
+ new_item = {
+ 'type': 'text',
+ 'text': list_item,
+ 'id': id + idx,
+ }
+ new_data.append(new_item)
+ id += len(item.get('list_items', []))
+ else:
+ new_data.append(item)
+ id += 1
+
+ with open(output_file, 'w') as outfile:
+ json.dump(new_data, outfile, ensure_ascii=False)
+
+ def run(self, storage: DataFlowStorage,
+ input_markdown_path_key,
+ output_converted_layout_key,
+ ):
+ dataframe = storage.read("dataframe")
+
+ for index, row in dataframe.iterrows():
+ input_json_path = row[input_markdown_path_key].replace('.md', '_content_list.json')
+ converted_path = input_json_path.replace('.json', '_converted.json')
+ self._convert_json(input_json_path, converted_path)
+ dataframe.at[index, output_converted_layout_key] = converted_path
+
+ with open(converted_path, 'r') as infile:
+ data = json.load(infile)
+ assert isinstance(data, list), f"Expected list, got {type(data)} for {input_json_path}"
+
+ storage.write(dataframe)
\ No newline at end of file
diff --git a/dataflow/operators/pdf2vqa/generate/qa_merger.py b/dataflow/operators/pdf2vqa/generate/qa_merger.py
new file mode 100644
index 00000000..e321dabf
--- /dev/null
+++ b/dataflow/operators/pdf2vqa/generate/qa_merger.py
@@ -0,0 +1,66 @@
+import os
+import json
+from dataflow.core import OperatorABC
+from dataflow.utils.registry import OPERATOR_REGISTRY
+from dataflow.utils.storage import DataFlowStorage
+from dataflow.utils.pdf2vqa.format_utils import merge_qa_pair, jsonl_to_md
+
+@OPERATOR_REGISTRY.register()
+class QA_Merger(OperatorABC):
+ def __init__(self, output_dir, strict_title_match=False):
+ self.output_dir = output_dir
+ self.strict_title_match = strict_title_match
+
+ @staticmethod
+ def get_desc(lang: str = "zh") -> str:
+ if lang == 'zh':
+ return (
+ "QA对合并算子。"
+ "将问题和答案的QA列表进行合并,生成最终的QA对文件,"
+ "并转换为Markdown格式。"
+ )
+ else:
+ return (
+ "QA pair merging operator."
+ "Merges question and answer QA lists to generate final QA pair files,"
+ "and converts them to Markdown format."
+ )
+
+ def run(self, storage: DataFlowStorage,
+ input_question_qalist_path_key,
+ input_answer_qalist_path_key,
+ input_name_key,
+ output_merged_qalist_path_key,
+ output_merged_md_path_key,
+ output_qa_item_key="qa_item" # 新增:展开后的 QA 内容列名
+ ):
+ dataframe = storage.read("dataframe")
+
+ # 为了能存储 list 对象,先初始化该列为 object 类型
+ dataframe[output_qa_item_key] = None
+ dataframe[output_qa_item_key] = dataframe[output_qa_item_key].astype(object)
+
+ for idx, row in dataframe.iterrows():
+ question_qalist_path = row[input_question_qalist_path_key]
+ answer_qalist_path = row[input_answer_qalist_path_key]
+ name = row[input_name_key]
+
+ output_merged_qalist_path = os.path.join(self.output_dir, name, "merged_qa_pairs.jsonl")
+ merge_qa_pair(question_qalist_path, answer_qalist_path, output_merged_qalist_path, strict_title_match=self.strict_title_match)
+
+ output_merged_md_path = os.path.join(self.output_dir, name, "merged_qa_pairs.md")
+ jsonl_to_md(output_merged_qalist_path, output_merged_md_path)
+
+ qa_pairs = []
+ if os.path.exists(output_merged_qalist_path):
+ with open(output_merged_qalist_path, 'r', encoding='utf-8') as f:
+ qa_pairs = [json.loads(line) for line in f]
+
+ dataframe.at[idx, output_qa_item_key] = qa_pairs
+
+ dataframe.loc[idx, output_merged_qalist_path_key] = output_merged_qalist_path
+ dataframe.loc[idx, output_merged_md_path_key] = output_merged_md_path
+
+ dataframe = dataframe.explode(output_qa_item_key).reset_index(drop=True)
+
+ storage.write(dataframe)
\ No newline at end of file
diff --git a/dataflow/operators/pdf2vqa/generate/vqa_extractor.py b/dataflow/operators/pdf2vqa/generate/vqa_extractor.py
deleted file mode 100644
index fcfd8fdf..00000000
--- a/dataflow/operators/pdf2vqa/generate/vqa_extractor.py
+++ /dev/null
@@ -1,530 +0,0 @@
-import os
-import json
-import re
-import pandas as pd
-import tiktoken
-import shutil
-import torch
-from pathlib import Path
-from typing import Literal
-from dataflow.core import OperatorABC
-from dataflow.utils.registry import OPERATOR_REGISTRY
-from dataflow.utils.storage import DataFlowStorage
-from dataflow import get_logger
-from dataflow.core import LLMServingABC
-from dataflow.prompts.pdf2vqa import QAExtractPrompt
-from dataflow.core.prompt import prompt_restrict
-from dataflow.utils.pdf2vqa.format_utils import merge_qa_pair, jsonl_to_md
-
-@prompt_restrict(QAExtractPrompt)
-@OPERATOR_REGISTRY.register()
-class VQAExtractor(OperatorABC):
- def __init__(self,
- llm_serving: LLMServingABC = None,
- mineru_backend: Literal["vlm-transformers","vlm-vllm-engine"] = "vlm-transformers",
- max_chunk_len: int = 128000,):
- self.logger = get_logger()
- self.llm_serving = llm_serving
- self.prompt_template = QAExtractPrompt()
- self.mineru_backend = mineru_backend
- self.max_chunk_len = max_chunk_len
-
- @staticmethod
- def get_desc(lang: str = "zh"):
- if lang == "zh":
- return (
- "该算子用于从试题或图文PDF文档中自动提取问答(VQA)结构化数据。\n\n"
- "功能说明:\n"
- "- 自动调用 MinerU 模型提取 PDF 文档的版面与内容布局信息。\n"
- "- 支持题目与答案的分离提取或交错(interleaved)模式处理。\n"
- "- 基于 LLM 生成章节结构化问答(、 标签格式)。\n"
- "- 自动进行内容清洗、图片路径替换与问答重建。\n"
- "- 支持结果过滤、合并及 Markdown 文档转换。\n\n"
- "输入要求:\n"
- "- DataFrame 中需包含 PDF 文件路径列,可为 question_pdf_path/answer_pdf_path 或 pdf_path。\n\n"
- "初始化参数:\n"
- "- llm_serving: LLM 推理服务实例,用于生成问答。\n"
- "- mineru_backend: MinerU 后端类型(可选值:\"vlm-transformers\" 或 \"vlm-vllm-engine\")。\n"
- "- max_chunk_len: 单批次最大token数量(默认128000)。\n\n"
- "运行参数(run):\n"
- "- input_question_pdf_path_key: 题目PDF路径列名(默认:\"question_pdf_path\")。\n"
- "- input_answer_pdf_path_key: 答案PDF路径列名(默认:\"answer_pdf_path\")。\n"
- "- input_pdf_path_key: 交错模式下的PDF路径列名(默认:\"pdf_path\")。\n"
- "- input_subject_key: 学科类别列名(默认:\"subject\")。\n"
- "- output_dir_key: 输出目录列名(默认:\"output_dir\")。\n"
- "- output_jsonl_key: 输出JSONL路径列名(默认:\"output_jsonl_path\")。\n"
- "- output_default_dir: 默认输出目录(默认:\"../vqa_output\")。\n\n"
- "输出:\n"
- "- 在 DataFrame 中新增一列,记录生成的VQA结构化问答JSONL文件路径。\n"
- "- 同时生成过滤后的Markdown文档和对应图片资源文件夹。"
- )
- elif lang == "en":
- return (
- "This operator extracts structured Visual Question Answering (VQA) data from exam or multimodal PDF documents.\n\n"
- "Functionality:\n"
- "- Automatically uses MinerU models to extract PDF layout and textual content.\n"
- "- Supports both separate (question/answer) and interleaved PDF processing modes.\n"
- "- Generates structured chapter-based QA pairs using an LLM (, tags).\n"
- "- Cleans and reconstructs QA content with proper image references.\n"
- "- Filters, merges, and converts the output into Markdown format.\n\n"
- "Input Requirements:\n"
- "- The input DataFrame must contain PDF path columns: either question_pdf_path/answer_pdf_path or pdf_path.\n\n"
- "Initialization Parameters:\n"
- "- llm_serving: Instance of LLM inference service used for QA generation.\n"
- "- mineru_backend: Backend type for MinerU ('vlm-transformers' or 'vlm-vllm-engine').\n"
- "- max_chunk_len: Maximum number of tokens per batch (default: 128000).\n\n"
- "Run Parameters:\n"
- "- input_question_pdf_path_key: Column name for question PDF path (default: 'question_pdf_path').\n"
- "- input_answer_pdf_path_key: Column name for answer PDF path (default: 'answer_pdf_path').\n"
- "- input_pdf_path_key: Column name for interleaved PDF path (default: 'pdf_path').\n"
- "- input_subject_key: Column name for subject type (default: 'subject').\n"
- "- output_dir_key: Column name for output directory (default: 'output_dir').\n"
- "- output_jsonl_key: Column name for output JSONL file path (default: 'output_jsonl_path').\n"
- "- output_default_dir: Default output directory (default: '../vqa_output').\n\n"
- "Output:\n"
- "- Adds a new column to the DataFrame containing paths to generated structured VQA JSONL files.\n"
- "- Also produces filtered Markdown documents and associated image folders."
- )
- else:
- return "VQAExtractor extracts structured VQA data from PDF documents and outputs filtered JSONL and Markdown files."
-
-
- def _convert_json(self, input_file, output_file):
- with open(input_file, 'r') as infile:
- data = list(json.load(infile))
-
- new_data = []
- id = 0
- for item in data:
- item['id'] = id
- item.pop('bbox', None)
- item.pop('page_idx', None)
- if item.get('type','') == 'list':
- if item['sub_type'] == 'text':
- for idx, list_item in enumerate(item.get('list_items', [])):
- new_item = {
- 'type': 'text',
- 'text': list_item,
- 'id': id + idx,
- }
- new_data.append(new_item)
- id += len(item.get('list_items', []))
- else:
- new_data.append(item)
- id += 1
-
- with open(output_file, 'w') as outfile:
- json.dump(new_data, outfile, ensure_ascii=False)
-
- def _count_tokens(self, text: str) -> int:
- enc = tiktoken.get_encoding("cl100k_base")
- return len(enc.encode(text))
-
- def _id_to_text(self, input_ids, input_json, image_prefix="images"):
- texts = []
- id_list = input_ids.replace(' ', '').split(',')
- for id in id_list:
- try:
- int(id)
- except:
- continue
- if int(id) < len(input_json):
- try:
- item = input_json[int(id)]
- except:
- continue
- if 'text' in item:
- texts.append(item['text'])
- elif 'img_path' in item:
- try:
- img_path = item.get('img_path', '')
- img_name = os.path.basename(img_path)
- new_path = f"{image_prefix}/{img_name}"
- texts.append(f"")
- except:
- pass
- elif item.get('type','') == 'list':
- if item['sub_type'] == 'text':
- try:
- texts.append(input_json[int(id)]['list_items'].pop(0))
- except:
- pass
- return '\n'.join(texts)
-
- def _extract_doc_layout(self, input_pdf_file_path: str, output_folder: str, mineru_backend: Literal["vlm-transformers","vlm-vllm-engine"] = "vlm-transformers"):
- """提取 PDF 的布局信息(合并自 VQAExtractDocLayoutMinerU)"""
- try:
- import mineru
- from mineru.cli.client import main as mineru_main
- except ImportError:
- raise Exception(
- """
- MinerU is not installed in this environment yet.
- Please refer to https://github.com/opendatalab/mineru to install.
- Or you can just execute 'pip install mineru[pipeline]' and 'mineru-models-download' to fix this error.
- Please make sure you have GPU on your machine.
- """
- )
- try:
- from pypdf import PdfReader, PdfWriter, PageObject
- except ImportError:
- raise Exception(
- """
- pypdf is not installed in this environment yet.
- Please use pip install pypdf.
- """
- )
- try:
- from reportlab.pdfgen import canvas
- except ImportError:
- raise Exception(
- """
- reportlab is not installed in this environment yet.
- Please use pip install reportlab.
- """
- )
-
- os.environ['MINERU_MODEL_SOURCE'] = "local"
-
- MinerU_Version = {"pipeline": "auto", "vlm-transformers": "vlm", "vlm-vllm-engine": "vlm"}
-
- if mineru_backend == "pipeline":
- raise ValueError("The 'pipeline' backend is not supported due to its incompatible output format. Please use 'vlm-transformers' or 'vlm-vllm-engine' instead.")
-
- raw_file = Path(input_pdf_file_path)
- pdf_name = raw_file.stem
- intermediate_dir = output_folder
- args = [
- "-p", str(raw_file),
- "-o", str(intermediate_dir),
- "-b", mineru_backend,
- "--source", "local"
- ]
- if mineru_backend == "vlm-vllm-engine":
- assert torch.cuda.is_available(), "MinerU vlm-vllm-engine backend requires GPU support."
- args += ["--tensor-parallel-size", "2" if torch.cuda.device_count() >= 2 else "1"]
-
- try:
- mineru_main(args)
- except SystemExit as e:
- if e.code != 0:
- raise RuntimeError(f"MinerU execution failed with exit code: {e.code}")
-
- output_json_file = os.path.join(intermediate_dir, pdf_name, MinerU_Version[mineru_backend], f"{pdf_name}_content_list.json")
- output_layout_file = os.path.join(intermediate_dir, pdf_name, MinerU_Version[mineru_backend], f"{pdf_name}_layout.pdf")
- return output_json_file, output_layout_file
-
- def _convert_response(self, input_response, input_json_path, image_prefix="images"):
- qa_list = []
- with open(input_json_path, 'r') as infile:
- input_json = list(json.load(infile))
- # 提取title
- for chapter_block in re.findall(r'(.*?)', input_response, flags=re.DOTALL):
- title = re.search(r'(.*?)', chapter_block, flags=re.DOTALL)
- if title:
- chapter_title = self._id_to_text(title.group(1).strip(), input_json, image_prefix)
- else:
- chapter_title = ""
- # 找出所有 qa_pair 块
- for pair in re.findall(r'(.*?)', chapter_block, flags=re.DOTALL):
- # 提取 question 部分
- q_match = re.search(r'(.*?)', pair, flags=re.DOTALL)
- # 提取 answer 部分
- a_match = re.search(r'(.*?)', pair, flags=re.DOTALL)
- # 提取solution部分
- s_match = re.search(r'(.*?)', pair, flags=re.DOTALL)
- # 提取label
- label_match = re.search(r'', pair, flags=re.DOTALL)
- if not ((q_match and label_match) or (a_match and label_match) or (s_match and label_match)):
- continue
- label = label_match.group(1).strip()
- qa_list.append({
- 'question': self._id_to_text(q_match.group(1).strip(), input_json, image_prefix) if q_match else "",
- 'answer': a_match.group(1).strip() if a_match else "",
- 'solution': self._id_to_text(s_match.group(1).strip(), input_json, image_prefix) if s_match else "",
- 'label': label,
- 'chapter_title': chapter_title
- })
- return qa_list
-
- def run(self, storage: DataFlowStorage,
- input_question_pdf_path_key: str = "question_pdf_path",
- input_answer_pdf_path_key: str = "answer_pdf_path",
- input_pdf_path_key: str = "pdf_path", # 支持 interleaved 模式的单一 pdf_path
- input_subject_key: str = "subject",
- output_dir_key: str = "output_dir",
- output_jsonl_key: str = "output_jsonl_path",
- output_default_dir: str = "../vqa_output") -> list:
- dataframe = storage.read("dataframe")
-
- # 支持两种输入格式:question_pdf_path/answer_pdf_path 或 pdf_path
- if input_question_pdf_path_key not in dataframe.columns and input_pdf_path_key not in dataframe.columns:
- raise ValueError(f"Column '{input_question_pdf_path_key}' or '{input_pdf_path_key}' not found in dataframe")
-
- # ========== Stage 1: 预处理(任务扩展 + Layout 提取) ==========
- expanded_rows = []
-
- for idx, row in dataframe.iterrows():
- # 优先使用 question_pdf_path,如果没有则使用 pdf_path(interleaved 模式)
- if input_question_pdf_path_key in dataframe.columns:
- question_pdf_path = row[input_question_pdf_path_key]
- answer_pdf_path = row.get(input_answer_pdf_path_key, question_pdf_path)
- else:
- # interleaved 模式:使用同一个 pdf_path
- question_pdf_path = row[input_pdf_path_key]
- answer_pdf_path = question_pdf_path
-
- subject = row.get(input_subject_key, "math")
- output_root = row.get(output_dir_key, output_default_dir)
- interleaved = (question_pdf_path == answer_pdf_path)
-
- os.makedirs(output_root, exist_ok=True)
-
- # Question task
- q_outdir = os.path.join(output_root, "question")
- os.makedirs(q_outdir, exist_ok=True)
-
- # Layout 提取
- q_json_path, _ = self._extract_doc_layout(
- input_pdf_file_path=question_pdf_path,
- output_folder=q_outdir,
- mineru_backend=self.mineru_backend
- )
-
- expanded_rows.append({
- "pdf_path": question_pdf_path,
- "mode": "question",
- "interleaved": interleaved,
- "subject": subject,
- "output_dir": q_outdir,
- "output_root": output_root,
- "json_path": q_json_path
- })
-
- # Answer task (if not interleaved)
- if not interleaved:
- a_outdir = os.path.join(output_root, "answer")
- os.makedirs(a_outdir, exist_ok=True)
-
- # Layout 提取
- a_json_path, _ = self._extract_doc_layout(
- input_pdf_file_path=answer_pdf_path,
- output_folder=a_outdir,
- mineru_backend=self.mineru_backend
- )
-
- expanded_rows.append({
- "pdf_path": answer_pdf_path,
- "mode": "answer",
- "interleaved": interleaved,
- "subject": subject,
- "output_dir": a_outdir,
- "output_root": output_root,
- "json_path": a_json_path
- })
-
- # ========== Stage 2: QA 提取 ==========
- json_paths = [row["json_path"] for row in expanded_rows]
- subjects = [row["subject"] for row in expanded_rows]
-
- user_inputs = []
- split_metadata = []
-
- for idx, input_json_path in enumerate(json_paths):
- subject = subjects[idx] if idx < len(subjects) else subjects[0] if subjects else "math"
- system_prompt = self.prompt_template.build_prompt(subject)
- system_prompt_len = self._count_tokens(system_prompt)
-
- converted_path = input_json_path.replace('.json', '_converted.json')
- self._convert_json(input_json_path, converted_path)
-
- with open(converted_path, 'r') as infile:
- data = json.load(infile)
- assert isinstance(data, list), f"Expected list, got {type(data)} for {input_json_path}"
-
- # 分段处理
- current_chunk, current_len = [], system_prompt_len
- chunks = []
-
- for item in data:
- text = json.dumps(item, ensure_ascii=False)
- item_len = self._count_tokens(text)
- if current_len + item_len > self.max_chunk_len and current_chunk:
- chunks.append(current_chunk)
- current_chunk, current_len = [], system_prompt_len
- current_chunk.append(item)
- current_len += item_len
-
- if current_chunk:
- chunks.append(current_chunk)
-
- split_metadata.append(len(chunks))
-
- for chunk in chunks:
- user_inputs.append({
- 'user_input': json.dumps(chunk, ensure_ascii=False),
- 'system_prompt': system_prompt
- })
-
- # 批量生成
- responses = [None] * len(user_inputs)
- current_batch = []
- current_batch_indices = []
- current_system_prompt = None
-
- for idx, item in enumerate(user_inputs):
- user_input = item['user_input']
- system_prompt = item['system_prompt']
-
- if current_system_prompt is None:
- current_system_prompt = system_prompt
- current_batch = [user_input]
- current_batch_indices = [idx]
- elif system_prompt == current_system_prompt:
- current_batch.append(user_input)
- current_batch_indices.append(idx)
- else:
- # 处理当前批次
- batch_responses = self.llm_serving.generate_from_input(user_inputs=current_batch, system_prompt=current_system_prompt)
- for batch_idx, resp in zip(current_batch_indices, batch_responses):
- responses[batch_idx] = resp
- # 开始新批次
- current_system_prompt = system_prompt
- current_batch = [user_input]
- current_batch_indices = [idx]
-
- # 处理最后一批
- if current_batch:
- batch_responses = self.llm_serving.generate_from_input(user_inputs=current_batch, system_prompt=current_system_prompt)
- for batch_idx, resp in zip(current_batch_indices, batch_responses):
- responses[batch_idx] = resp
-
- # 按 split_metadata 还原
- recombined_responses = []
- idx = 0
- for num_chunks in split_metadata:
- merged_text = "\n".join(responses[idx: idx + num_chunks])
- recombined_responses.append(merged_text)
- idx += num_chunks
-
- # ========== Stage 3: 后处理(Response 转换 + 合并和过滤) ==========
- # Response 转换
- qa_lists = []
- for idx, (response, row) in enumerate(zip(recombined_responses, expanded_rows)):
- json_path = row["json_path"]
- output_dir = row["output_dir"]
- mode = row["mode"]
- output_root = row["output_root"]
-
- image_prefix = f"{mode}_images"
- converted_json_path = json_path.replace('.json', '_converted.json')
- qa_list = self._convert_response(response, converted_json_path, image_prefix)
-
- # 复制图片
- src_dir = os.path.join(output_dir, Path(json_path).stem).replace('_content_list','')
- src_images = os.path.join(src_dir, 'vlm', 'images')
- dst_images = os.path.join(output_root, image_prefix)
-
- try:
- if os.path.exists(src_images):
- if os.path.exists(dst_images):
- shutil.rmtree(dst_images)
- shutil.copytree(src_images, dst_images)
- else:
- self.logger.warning(f"Source images dir does not exist: {src_images}")
- except Exception as e:
- self.logger.warning(f"Failed to copy images from {src_images} to {dst_images}: {e}")
-
- qa_lists.append(qa_list)
-
- # 按 output_root 分组处理合并和过滤
- output_groups = {}
- for idx, (qa_list, row) in enumerate(zip(qa_lists, expanded_rows)):
- output_root = row["output_root"]
- mode = row["mode"]
- interleaved = row["interleaved"]
- output_dir = row["output_dir"]
-
- if output_root not in output_groups:
- output_groups[output_root] = {
- "question": None,
- "answer": None,
- "interleaved": interleaved
- }
-
- if mode == "question":
- output_groups[output_root]["question"] = (qa_list, output_dir)
- elif mode == "answer":
- output_groups[output_root]["answer"] = (qa_list, output_dir)
-
- # 处理每个 output_root
- result_paths_dict = {}
- for output_root, group_info in output_groups.items():
- q_qa_list, q_output_dir = group_info["question"] if group_info["question"] else (None, None)
- a_qa_list, a_output_dir = group_info["answer"] if group_info["answer"] else (None, None)
- interleaved = group_info["interleaved"]
-
- # 写入 question jsonl
- q_jsonl_path = os.path.join(output_root, "vqa_extracted_questions.jsonl")
- if q_qa_list:
- with open(q_jsonl_path, 'w', encoding='utf-8') as f:
- for item in q_qa_list:
- f.write(json.dumps(item, ensure_ascii=False) + '\n')
-
- # 写入 answer jsonl(如果不是 interleaved)
- a_jsonl_path = None
- if not interleaved and a_qa_list:
- a_jsonl_path = os.path.join(output_root, "vqa_extracted_answers.jsonl")
- with open(a_jsonl_path, 'w', encoding='utf-8') as f:
- for item in a_qa_list:
- f.write(json.dumps(item, ensure_ascii=False) + '\n')
-
- # 合并
- merged_jsonl = os.path.join(output_root, "vqa_merged_qa_pairs.jsonl")
- if not interleaved and a_jsonl_path:
- merge_qa_pair(q_jsonl_path, a_jsonl_path, merged_jsonl)
- else:
- os.system(f"cp {q_jsonl_path} {merged_jsonl}")
-
- # 过滤
- filtered_items = []
- total_count = 0
- with open(merged_jsonl, 'r', encoding='utf-8') as f:
- for line in f:
- total_count += 1
- item = json.loads(line)
- if item.get('question','').strip() and (item.get('answer','').strip() or item.get('solution','').strip()):
- filtered_items.append(item)
-
- self.logger.info(f"Before filter: {total_count}, After filter: {len(filtered_items)}")
-
- filtered_jsonl = os.path.join(output_root, "vqa_filtered_qa_pairs.jsonl")
- with open(filtered_jsonl, 'w', encoding='utf-8') as f:
- for item in filtered_items:
- f.write(json.dumps(item, ensure_ascii=False) + '\n')
-
- # 转换为 markdown
- md_output = os.path.join(output_root, "vqa_filtered_qa_pairs.md")
- jsonl_to_md(filtered_jsonl, md_output)
-
- result_paths_dict[output_root] = filtered_jsonl
-
- # 为原始 dataframe 的每一行分配结果路径
- result_paths = []
- for idx, row in dataframe.iterrows():
- if input_question_pdf_path_key in dataframe.columns:
- question_pdf_path = row[input_question_pdf_path_key]
- answer_pdf_path = row.get(input_answer_pdf_path_key, question_pdf_path)
- else:
- question_pdf_path = row[input_pdf_path_key]
- answer_pdf_path = question_pdf_path
-
- output_root = row.get(output_dir_key, output_default_dir)
- result_paths.append(result_paths_dict.get(output_root))
-
- dataframe[output_jsonl_key] = result_paths
- output_file = storage.write(dataframe)
- self.logger.info(f"VQA extraction complete. Results saved to {output_file}")
-
- return [output_jsonl_key,]
-
diff --git a/dataflow/prompts/pdf2vqa.py b/dataflow/prompts/pdf2vqa.py
index 42a6c3a3..3609a029 100644
--- a/dataflow/prompts/pdf2vqa.py
+++ b/dataflow/prompts/pdf2vqa.py
@@ -92,9 +92,9 @@ class QAExtractPrompt(PromptABC):
def __init__(self):
pass
- def build_prompt(self, subject: str = "math") -> str:
+ def build_prompt(self) -> str:
PROMPT = f"""
- You are an expert in {subject}. You are given a json file. Your task is to segment the content, insert images tags, and extract labels:
+ You are an expert in answer college-level questions. You are given a json file. Your task is to segment the content, insert images tags, and extract labels:
1. Every json item has an "id" field. Your main task is to output this field.
2. You need to segment the content into multiple ``…`` blocks, each containing a question and its corresponding answer with solution.
3. If the problem or answer/solution is not complete, omit them. An answer/solution should be considered complete as long as either the answer or solution exists.
@@ -120,7 +120,7 @@ def build_prompt(self, subject: str = "math") -> str:
- Always enclose qa pairs in a ``…`` block, where MAIN_TITLE_ID is the id of the chapter title or section title.
- Normally, chapter/section titles appear before the questions/answers in an independent json item.
- There could be multiple ``…`` blocks if multiple chapters/sections exist.
-- **Any title followed by a question/answer whose label/number is not 1, or title with a score, should NOT be extracted.**
+- **Any title followed by a question/answer whose label/number is not 1, or title with a score such as "一、选择题(每题1分,共10分)", should NOT be extracted.**
- Do not use nested titles.
- Leave the title blank if there is no chapter title.
** About figures/diagrams **
diff --git a/dataflow/statics/pipelines/api_pipelines/pdf_vqa_extract_pipeline.py b/dataflow/statics/pipelines/api_pipelines/pdf_vqa_extract_pipeline.py
index e186a2e8..929f3bbc 100644
--- a/dataflow/statics/pipelines/api_pipelines/pdf_vqa_extract_pipeline.py
+++ b/dataflow/statics/pipelines/api_pipelines/pdf_vqa_extract_pipeline.py
@@ -1,48 +1,107 @@
-import os
-import sys
+from dataflow.operators.knowledge_cleaning import FileOrURLToMarkdownConverterBatch
from dataflow.serving import APILLMServing_request
from dataflow.utils.storage import FileStorage
-from dataflow.operators.pdf2vqa import VQAExtractor
+from dataflow.operators.pdf2vqa import MinerU2LLMInputOperator, LLMOutputParser, QA_Merger
+from dataflow.operators.core_text import ChunkedPromptedGenerator
-class VQA_extract_optimized_pipeline:
+from dataflow.pipeline import PipelineABC
+from dataflow.prompts.pdf2vqa import QAExtractPrompt
+
+class PDF_VQA_extract_optimized_pipeline(PipelineABC):
def __init__(self):
+ super().__init__()
self.storage = FileStorage(
- first_entry_file_name="../example_data/PDF2VQAPipeline/vqa_extract_test.jsonl",
+ first_entry_file_name="./example_data/PDF2VQAPipeline/vqa_extract_test.jsonl",
cache_path="./cache",
file_name_prefix="vqa",
cache_type="jsonl",
)
self.llm_serving = APILLMServing_request(
- api_url="https://generativelanguage.googleapis.com/v1beta/openai/chat/completions",
+ api_url="http://123.129.219.111:3000/v1/chat/completions",
key_name_of_api_key="DF_API_KEY",
model_name="gemini-2.5-pro",
max_workers=100,
)
- self.vqa_extractor = VQAExtractor(
+ self.vqa_extract_prompt = QAExtractPrompt()
+
+ self.mineru_executor = FileOrURLToMarkdownConverterBatch(intermediate_dir = "intermediate", mineru_backend="vlm-vllm-engine")
+ self.input_formatter = MinerU2LLMInputOperator()
+ self.vqa_extractor = ChunkedPromptedGenerator(
llm_serving=self.llm_serving,
- mineru_backend='vlm-vllm-engine',
- max_chunk_len=128000
+ system_prompt = self.vqa_extract_prompt.build_prompt(),
+ max_chunk_len=128000,
)
-
+ self.llm_output_question_parser = LLMOutputParser(mode="question", output_dir="./cache", intermediate_dir="intermediate")
+ self.llm_output_answer_parser = LLMOutputParser(mode="answer", output_dir="./cache", intermediate_dir="intermediate")
+ self.qa_merger = QA_Merger(output_dir="./cache", strict_title_match=False)
def forward(self):
- # 单一算子:包含预处理、QA提取、后处理的所有功能
+ # 目前的处理逻辑是:MinerU处理问题-MinerU处理答案-格式化问题文本-格式化答案文本-问题文本输入LLM-答案文本输入LLM-解析问题输出-解析答案输出-合并问答对
+ # 由于问答对可能来自同一份pdf,也有可能来自不同pdf,而dataflow目前不支持分支,因此这里只能将question和answer的pdf都进行一次处理,
+ # 即使是同一份pdf也会被处理两次,最后再合并问答对。
+ # 未来会再思考如何优化这个流程,避免重复处理同一份pdf,提升性能。
+
+ self.mineru_executor.run(
+ storage=self.storage.step(),
+ input_key="question_pdf_path",
+ output_key="question_markdown_path",
+ )
+ self.mineru_executor.run(
+ storage=self.storage.step(),
+ input_key="answer_pdf_path",
+ output_key="answer_markdown_path",
+ )
+ self.input_formatter.run(
+ storage=self.storage.step(),
+ input_markdown_path_key="question_markdown_path",
+ output_converted_layout_key="converted_question_layout_path",
+ )
+ self.input_formatter.run(
+ storage=self.storage.step(),
+ input_markdown_path_key="answer_markdown_path",
+ output_converted_layout_key="converted_answer_layout_path",
+ )
self.vqa_extractor.run(
storage=self.storage.step(),
- input_question_pdf_path_key="question_pdf_path",
- input_answer_pdf_path_key="answer_pdf_path",
- input_pdf_path_key="pdf_path", # 支持 interleaved 模式
- input_subject_key="subject",
- output_dir_key="output_dir",
- output_jsonl_key="output_jsonl_path",
+ input_path_key="converted_question_layout_path",
+ output_path_key="vqa_extracted_questions_path",
+ )
+ self.vqa_extractor.run(
+ storage=self.storage.step(),
+ input_path_key="converted_answer_layout_path",
+ output_path_key="vqa_extracted_answers_path",
+ )
+ self.llm_output_question_parser.run(
+ storage=self.storage.step(),
+ input_response_path_key="vqa_extracted_questions_path",
+ input_converted_layout_path_key="converted_question_layout_path",
+ input_name_key="name",
+ output_qalist_path_key="extracted_questions_path",
+ )
+ self.llm_output_answer_parser.run(
+ storage=self.storage.step(),
+ input_response_path_key="vqa_extracted_answers_path",
+ input_converted_layout_path_key="converted_answer_layout_path",
+ input_name_key="name",
+ output_qalist_path_key="extracted_answers_path",
+ )
+ self.qa_merger.run(
+ storage=self.storage.step(),
+ input_question_qalist_path_key="extracted_questions_path",
+ input_answer_qalist_path_key="extracted_answers_path",
+ input_name_key="name",
+ output_merged_qalist_path_key="output_merged_qalist_path",
+ output_merged_md_path_key="output_merged_md_path",
+ output_qa_item_key="qa_pair",
)
if __name__ == "__main__":
- # jsonl中每一行包含question_pdf_path, answer_pdf_path, subject (math, physics, chemistry, ...), output_dir
+ # jsonl中每一行包含question_pdf_path, answer_pdf_path, name (math1, math2, physics1, chemistry1, ...)
# 如果question和answer在同一份pdf中,请将question_pdf_path和answer_pdf_path设置为相同的路径,会自动切换为interleaved模式
- pipeline = VQA_extract_optimized_pipeline()
- pipeline.forward()
+ pipeline = PDF_VQA_extract_optimized_pipeline()
+ pipeline.compile()
+ pipeline.forward()
\ No newline at end of file
diff --git a/dataflow/utils/pdf2vqa/format_utils.py b/dataflow/utils/pdf2vqa/format_utils.py
index 94294d15..c1b8138e 100644
--- a/dataflow/utils/pdf2vqa/format_utils.py
+++ b/dataflow/utils/pdf2vqa/format_utils.py
@@ -1,11 +1,29 @@
import json
import re
-def merge_qa_pair(question_jsonl, answer_jsonl, output_jsonl):
+def refine_title(title: str, strict_title_match=False):
+ # TODO : 这里可能需要更复杂的title清洗逻辑
+ # 删除title中的空格与换行符
+ title = re.sub(r'\s+', '', title)
+ if not strict_title_match:
+ try:
+ # 优先提取阿拉伯数字章节编号(如1.1,2等)
+ new_title = re.search(r"\d+\.\d+|\d+", title).group()
+ except:
+ try:
+ # 其次提取中文数字章节编号(如六、二十四等)
+ new_title = re.search(r'[一二三四五六七八九零十百]+', title).group()
+ except:
+ new_title = title
+ title = new_title
+ return title
+
+def merge_qa_pair(question_jsonl, answer_jsonl, output_jsonl, strict_title_match=False):
+ already_complete_count = 0
with open(question_jsonl, 'r', encoding='utf-8') as q_file, open(answer_jsonl, 'r', encoding='utf-8') as a_file, open(output_jsonl, 'w', encoding='utf-8') as out_file:
chapter_id = 0
chapter_title = ""
- label = 1000000
+ label = float('inf')
questions = {}
answers = {}
for line in q_file:
@@ -29,13 +47,11 @@ def merge_qa_pair(question_jsonl, answer_jsonl, output_jsonl):
# 如果题号增加,章节标题却发生变化,说明可能错误提取了子标题。因此继续使用之前的章节标题。
data["chapter_title"] = chapter_title
label = data["label"]
- data["chapter_id"] = chapter_id
- # TODO : 这里可能需要更复杂的title清洗逻辑
- # 删除title中的空格与换行符
- data["chapter_title"] = re.sub(r'\s+', '', data["chapter_title"])
- if data['label'] > 0 and data["chapter_title"]:
+ data["chapter_title"] = refine_title(data["chapter_title"], strict_title_match)
+ if data['label'] > 0:
# 已经完整的题目直接写入out_file
if data["answer"] or data["solution"]:
+ already_complete_count += 1
qa_pair = {
"question_chapter_title": data["chapter_title"],
"answer_chapter_title": data["chapter_title"],
@@ -51,7 +67,7 @@ def merge_qa_pair(question_jsonl, answer_jsonl, output_jsonl):
chapter_id = 0
chapter_title = ""
- label = 1000000
+ label = float('inf')
for line in a_file:
data = json.loads(line)
label_match = re.search(r'\d+', data["label"])
@@ -73,10 +89,7 @@ def merge_qa_pair(question_jsonl, answer_jsonl, output_jsonl):
# 如果题号增加,章节标题却发生变化,说明可能错误提取了子标题。因此继续使用之前的章节标题。
data["chapter_title"] = chapter_title
label = data["label"]
- data["chapter_id"] = chapter_id
- # TODO : 这里可能需要更复杂的title清洗逻辑
- # 删除title中的空格与换行符
- data["chapter_title"] = re.sub(r'\s+', '', data["chapter_title"])
+ data["chapter_title"] = refine_title(data["chapter_title"], strict_title_match)
# 动态更新,防止错误的重复label覆盖掉之前的solution或answer
if data['label'] > 0:
if not answers.get((data["chapter_title"], data['label'])):
@@ -99,7 +112,7 @@ def merge_qa_pair(question_jsonl, answer_jsonl, output_jsonl):
}
out_file.write(json.dumps(qa_pair, ensure_ascii=False) + '\n')
- print(f"Merged QA pairs: {len(questions.keys() & answers.keys())}")
+ print(f"Merged QA pairs: {len(questions.keys() & answers.keys()) + already_complete_count}")
def jsonl_to_md(jsonl_file, md_file):
with open(jsonl_file, 'r', encoding='utf-8') as in_file, open(md_file, 'w', encoding='utf-8') as out_file: