diff --git a/dataflow/example/PDF2VQAPipeline/vqa_extract_test.jsonl b/dataflow/example/PDF2VQAPipeline/vqa_extract_test.jsonl
index 40866a8d..d535d8ae 100644
--- a/dataflow/example/PDF2VQAPipeline/vqa_extract_test.jsonl
+++ b/dataflow/example/PDF2VQAPipeline/vqa_extract_test.jsonl
@@ -1,2 +1,2 @@
-{"question_pdf_path": "./example_data/PDF2VQAPipeline/questionextract_test.pdf", "answer_pdf_path": "./example_data/PDF2VQAPipeline/questionextract_test.pdf", "subject": "math", "output_dir": "../vqa_output_test/math1"}
-{"question_pdf_path": "./example_data/PDF2VQAPipeline/math_question.pdf", "answer_pdf_path": "./example_data/PDF2VQAPipeline/math_answer.pdf", "subject": "math", "output_dir": "../vqa_output_test/math2"}
+{"question_pdf_path": "./example_data/PDF2VQAPipeline/questionextract_test.pdf", "answer_pdf_path": "./example_data/PDF2VQAPipeline/questionextract_test.pdf", "name": "math1"}
+{"question_pdf_path": "./example_data/PDF2VQAPipeline/math_question.pdf", "answer_pdf_path": "./example_data/PDF2VQAPipeline/math_answer.pdf", "name": "math2"}
diff --git a/dataflow/operators/core_text/__init__.py b/dataflow/operators/core_text/__init__.py
index 592469c9..c9429092 100644
--- a/dataflow/operators/core_text/__init__.py
+++ b/dataflow/operators/core_text/__init__.py
@@ -2,6 +2,7 @@
 
 if TYPE_CHECKING:
     from .generate.prompted_generator import PromptedGenerator
+    from .generate.chunked_prompted_generator import ChunkedPromptedGenerator
     from .generate.format_str_prompted_generator import FormatStrPromptedGenerator
     from .generate.random_domain_knowledge_row_generator import RandomDomainKnowledgeRowGenerator
     from .generate.text2qa_generator import Text2QAGenerator
diff --git a/dataflow/operators/core_text/generate/chunked_prompted_generator.py b/dataflow/operators/core_text/generate/chunked_prompted_generator.py
new file mode 100644
index 00000000..d54c5058
--- /dev/null
+++ b/dataflow/operators/core_text/generate/chunked_prompted_generator.py
@@ -0,0 +1,147 @@
+import pandas as pd
+import tiktoken
+from dataflow.utils.registry import OPERATOR_REGISTRY
+from dataflow import get_logger
+from pathlib import Path
+
+from dataflow.utils.storage import DataFlowStorage
+from dataflow.core import OperatorABC
+from dataflow.core import LLMServingABC
+
+@OPERATOR_REGISTRY.register()
+class ChunkedPromptedGenerator(OperatorABC):
+    """
+    Prompt-based generation operator with automatic chunking of the input.
+    - Token counts are computed with tiktoken or a HuggingFace AutoTokenizer;
+    - Inputs longer than max_chunk_len are split by recursive bisection;
+    - Content is read from the given input file path and results are written to the given output file path;
+    - The final result is the chunk outputs joined with `separator`.
+    """
+
+    def __init__(
+        self,
+        llm_serving: LLMServingABC,
+        system_prompt: str = "You are a helpful agent.",
+        json_schema: dict = None,
+        max_chunk_len: int = 128000,
+        enc = tiktoken.get_encoding("cl100k_base"),  # any tokenizer supporting len(enc.encode(text)) works, e.g. tiktoken or a HuggingFace AutoTokenizer
+        separator: str = "\n",
+    ):
+        self.logger = get_logger()
+        self.llm_serving = llm_serving
+        self.system_prompt = system_prompt
+        self.json_schema = json_schema
+        self.max_chunk_len = max_chunk_len
+        self.enc = enc
+        self.separator = separator
+
+    @staticmethod
+    def get_desc(lang: str = "zh"):
+        if lang == "zh":
+            return (
+                "基于提示词的生成算子,支持长文本自动分chunk。"
+                "采用递归二分方式进行chunk切分,确保每段不超过max_chunk_len tokens。"
+                "从给定的输入文件路径读取内容,生成结果保存至指定输出文件路径。"
+                "多个生成结果以separator拼接成最终输出字符串。"
+                "输入参数:\n"
+                "- llm_serving:LLM服务对象,需实现LLMServingABC接口\n"
+                "- system_prompt:系统提示词,定义模型行为,默认为'You are a helpful agent.'\n"
+                "- max_chunk_len:单个chunk的最大token长度,默认为128000\n"
+                "- input_path_key:输入文件路径字段名,默认为'input_path'\n"
+                "- output_path_key:输出文件路径字段名,默认为'output_path'\n"
+                "- json_schema:可选,生成结果的JSON Schema约束\n"
+                "- enc:用于token计算的编码器,需要实现encode方法,默认为tiktoken的cl100k_base编码器,也可以使用HuggingFace 的 AutoTokenizer\n"
+                "- separator:chunk结果拼接分隔符,默认为换行符\n"
+            )
+        else:
+            return (
+                "Prompt-based generator with recursive chunk splitting."
+                "Splits long text inputs into chunks using recursive bisection to ensure each chunk does not exceed max_chunk_len tokens."
+                "Reads content from specified input file paths and saves generated results to designated output file paths."
+                "Multiple generated results are joined as a string using the specified separator."
+                "Input Parameters:\n"
+                "- llm_serving: LLM serving object implementing LLMServingABC interface\n"
+                "- system_prompt: System prompt to define model behavior, default is 'You are a helpful agent.'\n"
+                "- max_chunk_len: Maximum token length per chunk, default is 128000\n"
+                "- input_path_key: Field name for input file path, default is 'input_path'\n"
+                "- output_path_key: Field name for output file path, default is 'output_path'\n"
+                "- json_schema: Optional JSON Schema constraint for generated results\n"
+                "- enc: Encoder for token counting, default is tiktoken's cl100k_base encoder; can also use HuggingFace's AutoTokenizer\n"
+                "- separator: Separator for chunk results, default is newline character\n"
+            )
+
+    # === token counting ===
+    def _count_tokens(self, text: str) -> int:
+        return len(self.enc.encode(text))
+
+    # === recursive bisection chunking ===
+    def _split_recursive(self, text: str) -> list[str]:
+        """Recursively split the text into chunks of at most max_chunk_len tokens."""
+        token_len = self._count_tokens(text)
+        if token_len <= self.max_chunk_len:
+            return [text]
+        else:
+            mid = len(text) // 2
+            left, right = text[:mid], text[mid:]
+            return self._split_recursive(left) + self._split_recursive(right)
+
+    def run(
+        self,
+        storage: DataFlowStorage,
+        input_path_key: str = "input_path",
+        output_path_key: str = "output_path",
+    ):
+        self.logger.info("Running ChunkedPromptedGenerator...")
+        dataframe = storage.read("dataframe")
+        self.logger.info(f"Loaded DataFrame with {len(dataframe)} rows.")
+
+        all_generated_results = []
+
+        all_llm_inputs = []
+        row_chunk_map = []  # number of chunks produced for each row
+
+        # === collect all chunks first ===
+        for i, row in dataframe.iterrows():
+            raw_content = Path(row[input_path_key]).read_text(encoding='utf-8')
+
+            chunks = self._split_recursive(raw_content)
+            self.logger.info(f"Row {i}: split into {len(chunks)} chunks")
+
+            system_prompt = self.system_prompt + "\n"
+            llm_inputs = [system_prompt + chunk for chunk in chunks]
+            all_llm_inputs.extend(llm_inputs)
+            row_chunk_map.append(len(chunks))
+
+        # === one concurrent batch call for everything ===
+        self.logger.info(f"Total {len(all_llm_inputs)} chunks to generate")
+
+        try:
+            if self.json_schema:
+                all_responses = self.llm_serving.generate_from_input(
+                    all_llm_inputs, json_schema=self.json_schema
+                )
+            else:
+                all_responses = self.llm_serving.generate_from_input(all_llm_inputs)
+        except Exception as e:
+            self.logger.error(f"Global generation failed: {e}")
+            all_generated_results = [[] for _ in range(len(dataframe))]
+        else:
+            # === regroup responses by row ===
+            all_generated_results = []
+            idx = 0
+            for num_chunks in row_chunk_map:
+                if num_chunks == 0:
+                    all_generated_results.append([])
+                else:
+                    all_generated_results.append(all_responses[idx:idx + num_chunks])
+                idx += num_chunks
+
+        for (i, row), gen_results in zip(dataframe.iterrows(), all_generated_results):
+            output_path = str(Path(row[input_path_key]).with_suffix('')) + '_llm_output.txt'
+            with open(output_path, 'w', encoding='utf-8') as f:
+                f.write(self.separator.join(gen_results))
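+            # record where this row's joined output was written so downstream operators can locate it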
+            dataframe.at[i, output_path_key] = output_path
+
+        output_file = storage.write(dataframe)
+        self.logger.info(f"Generation complete. Output saved to {output_file}")
+        return output_path_key
\ No newline at end of file
diff --git a/dataflow/operators/core_text/generate/prompted_generator.py b/dataflow/operators/core_text/generate/prompted_generator.py
index cc5ece8f..9cf26a80 100644
--- a/dataflow/operators/core_text/generate/prompted_generator.py
+++ b/dataflow/operators/core_text/generate/prompted_generator.py
@@ -1,6 +1,7 @@
 import pandas as pd
 from dataflow.utils.registry import OPERATOR_REGISTRY
 from dataflow import get_logger
+from pathlib import Path
 
 from dataflow.utils.storage import DataFlowStorage
 from dataflow.core import OperatorABC
diff --git a/dataflow/operators/pdf2vqa/__init__.py b/dataflow/operators/pdf2vqa/__init__.py
index 8236e494..a385917b 100644
--- a/dataflow/operators/pdf2vqa/__init__.py
+++ b/dataflow/operators/pdf2vqa/__init__.py
@@ -1,7 +1,9 @@
 from typing import TYPE_CHECKING
 
 if TYPE_CHECKING:
-    from .generate.vqa_extractor import VQAExtractor
+    from .generate.mineru_to_llm_input_operator import MinerU2LLMInputOperator
+    from .generate.llm_output_parser import LLMOutputParser
+    from .generate.qa_merger import QA_Merger
 
 
 else:
diff --git a/dataflow/operators/pdf2vqa/generate/llm_output_parser.py b/dataflow/operators/pdf2vqa/generate/llm_output_parser.py
new file mode 100644
index 00000000..7565db78
--- /dev/null
+++ b/dataflow/operators/pdf2vqa/generate/llm_output_parser.py
@@ -0,0 +1,141 @@
+import os
+import json
+import re
+import shutil
+from pathlib import Path
+from typing import Literal
+from dataflow.core import OperatorABC
+from dataflow.utils.registry import OPERATOR_REGISTRY
+from dataflow.utils.storage import DataFlowStorage
+from dataflow import get_logger
+
+@OPERATOR_REGISTRY.register()
+class LLMOutputParser(OperatorABC):
+    def __init__(self,
+                 mode: Literal['question', 'answer'],
+                 output_dir,
+                 intermediate_dir: str = "intermediate",
+                 ):
+        self.logger = get_logger()
+        self.mode = mode
+        self.output_dir = output_dir
+        self.intermediate_dir = intermediate_dir
+
+    @staticmethod
+    def get_desc(lang: str = "zh") -> str:
+        if lang == 'zh':
+            return (
+                "LLM输出解析算子。"
+                "将LLM生成的包含题目和答案ID的响应文本,"
+                "转换为结构化的QA列表,并复制相关图片到输出目录。"
+            )
+        else:
+            return (
+                "LLM output parsing operator."
+                "Converts LLM-generated response text containing question and answer IDs"
+                "into a structured QA list and copies related images to the output directory."
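+                "Run parameters: input_response_path_key, input_converted_layout_path_key, input_name_key, output_qalist_path_key."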
+            )
+
+    def _id_to_text(self, input_ids, input_json, image_prefix="images"):
+        texts = []
+        id_list = input_ids.replace(' ', '').split(',')
+        for id_str in id_list:
+            try:
+                item_id = int(id_str)
+            except ValueError:
+                continue
+            if item_id < len(input_json):
+                item = input_json[item_id]
+                if 'text' in item:
+                    texts.append(item['text'])
+                elif 'img_path' in item:
+                    try:
+                        img_path = item.get('img_path', '')
+                        img_name = os.path.basename(img_path)
+                        new_path = f"{image_prefix}/{img_name}"
+                        texts.append(f"![{' '.join(item.get('image_caption', ['image']))}]({new_path})")
+                    except Exception:
+                        pass
+                elif item.get('type', '') == 'list':
+                    if item['sub_type'] == 'text':
+                        try:
+                            texts.append(item['list_items'].pop(0))
+                        except Exception:
+                            pass
+        return '\n'.join(texts)
+
+    def _convert_response(self, input_response, input_json_path, image_prefix="images"):
+        qa_list = []
+        with open(input_json_path, 'r', encoding='utf-8') as infile:
+            input_json = list(json.load(infile))
+        # extract the title
+        for chapter_block in re.findall(r'<chapter>(.*?)</chapter>', input_response, flags=re.DOTALL):
+            title = re.search(r'<title>(.*?)</title>', chapter_block, flags=re.DOTALL)
+            if title:
+                chapter_title = self._id_to_text(title.group(1).strip(), input_json, image_prefix)
+            else:
+                chapter_title = ""
+            # find all qa_pair blocks
+            for pair in re.findall(r'<qa_pair>(.*?)</qa_pair>', chapter_block, flags=re.DOTALL):
+                # extract the question part
+                q_match = re.search(r'<question>(.*?)</question>', pair, flags=re.DOTALL)
+                # extract the answer part
+                a_match = re.search(r'<answer>(.*?)</answer>', pair, flags=re.DOTALL)
+                # extract the solution part
+                s_match = re.search(r'<solution>(.*?)</solution>', pair, flags=re.DOTALL)
+                # extract the label
+                label_match = re.search(r'<label>(.*?)</label>', pair, flags=re.DOTALL)
+                if not ((q_match and label_match) or (a_match and label_match) or (s_match and label_match)):
+                    continue
+                label = label_match.group(1).strip()
+                qa_list.append({
+                    'question': self._id_to_text(q_match.group(1).strip(), input_json, image_prefix) if q_match else "",
+                    'answer': a_match.group(1).strip() if a_match else "",
+                    'solution': self._id_to_text(s_match.group(1).strip(), input_json, image_prefix) if s_match else "",
+                    'label': label,
+                    'chapter_title': chapter_title
+                })
+        return qa_list
+
+    def run(self, storage: DataFlowStorage,
+            input_response_path_key,
+            input_converted_layout_path_key,
+            input_name_key,
+            output_qalist_path_key,
+            ):
+        dataframe = storage.read("dataframe")
+
+        # convert each response
+        for idx, row in dataframe.iterrows():
+            converted_json_path = row[input_converted_layout_path_key]
+            response = Path(row[input_response_path_key]).read_text(encoding='utf-8')
+            name = row[input_name_key]
+
+            image_prefix = os.path.join(name, f"{self.mode}_images")
+            qa_list = self._convert_response(response, converted_json_path, image_prefix)
+            output_qalist_path = os.path.join(self.output_dir, name, f"extracted_{self.mode}s.jsonl")
+            os.makedirs(os.path.dirname(output_qalist_path), exist_ok=True)
+            with open(output_qalist_path, 'w', encoding='utf-8') as outfile:
+                for qa in qa_list:
+                    json.dump(qa, outfile, ensure_ascii=False)
+                    outfile.write('\n')
+
+            # copy the images
+            src_dir = os.path.join(self.intermediate_dir, 'mineru', Path(converted_json_path).stem).replace('_content_list_converted', '')
+            src_images = os.path.join(src_dir, 'vlm', 'images')
+            dst_images = os.path.join(self.output_dir, image_prefix)
+
+            try:
+                if os.path.exists(src_images):
+                    shutil.copytree(src_images, dst_images, dirs_exist_ok=True)
+                else:
+                    self.logger.warning(f"Source images dir does not exist: {src_images}")
+            except Exception as e:
+                self.logger.warning(f"Failed to copy images from {src_images} to {dst_images}: {e}")
+
+            dataframe.loc[idx, output_qalist_path_key] = output_qalist_path
+
+        storage.write(dataframe)
\ No newline at end of file
diff --git a/dataflow/operators/pdf2vqa/generate/mineru_to_llm_input_operator.py b/dataflow/operators/pdf2vqa/generate/mineru_to_llm_input_operator.py
new file mode 100644
index 00000000..2a51d928
--- /dev/null
+++ b/dataflow/operators/pdf2vqa/generate/mineru_to_llm_input_operator.py
@@ -0,0 +1,69 @@
+import json
+from dataflow.core import OperatorABC
+from dataflow.utils.registry import OPERATOR_REGISTRY
+from dataflow.utils.storage import DataFlowStorage
+
+@OPERATOR_REGISTRY.register()
+class MinerU2LLMInputOperator(OperatorABC):
+    def __init__(self):
+        pass
+
+    @staticmethod
+    def get_desc(lang: str = "zh") -> str:
+        if lang == 'zh':
+            return (
+                "MinerU格式转换为LLM输入格式算子。"
+                "将MinerU生成的内容列表JSON文件转换为适合LLM处理的格式,"
+                "包括展平列表项并重新编号。"
+            )
+        else:
+            return (
+                "Convert MinerU format to LLM input format operator."
+                "Transforms the content list JSON file generated by MinerU into a format suitable for LLM processing,"
+                "including flattening list items and re-indexing."
+            )
+
+    def _convert_json(self, input_file, output_file):
+        with open(input_file, 'r', encoding='utf-8') as infile:
+            data = list(json.load(infile))
+
+        new_data = []
+        item_id = 0
+        for item in data:
+            item['id'] = item_id
+            item.pop('bbox', None)
+            item.pop('page_idx', None)
+            if item.get('type', '') == 'list':
+                if item['sub_type'] == 'text':
+                    for idx, list_item in enumerate(item.get('list_items', [])):
+                        new_item = {
+                            'type': 'text',
+                            'text': list_item,
+                            'id': item_id + idx,
+                        }
+                        new_data.append(new_item)
+                    item_id += len(item.get('list_items', []))
+            else:
+                new_data.append(item)
+                item_id += 1
+
+        with open(output_file, 'w', encoding='utf-8') as outfile:
+            json.dump(new_data, outfile, ensure_ascii=False)
+
+    def run(self, storage: DataFlowStorage,
+            input_markdown_path_key,
+            output_converted_layout_key,
+            ):
+        dataframe = storage.read("dataframe")
+
+        for index, row in dataframe.iterrows():
+            input_json_path = row[input_markdown_path_key].replace('.md', '_content_list.json')
+            converted_path = input_json_path.replace('.json', '_converted.json')
+            self._convert_json(input_json_path, converted_path)
+            dataframe.at[index, output_converted_layout_key] = converted_path
+
+            with open(converted_path, 'r', encoding='utf-8') as infile:
+                data = json.load(infile)
+                assert isinstance(data, list), f"Expected list, got {type(data)} for {input_json_path}"
+
+        storage.write(dataframe)
\ No newline at end of file
diff --git a/dataflow/operators/pdf2vqa/generate/qa_merger.py b/dataflow/operators/pdf2vqa/generate/qa_merger.py
new file mode 100644
index 00000000..e321dabf
--- /dev/null
+++ b/dataflow/operators/pdf2vqa/generate/qa_merger.py
@@ -0,0 +1,66 @@
+import os
+import json
+from dataflow.core import OperatorABC
+from dataflow.utils.registry import OPERATOR_REGISTRY
+from dataflow.utils.storage import DataFlowStorage
+from dataflow.utils.pdf2vqa.format_utils import merge_qa_pair, jsonl_to_md
+
+@OPERATOR_REGISTRY.register()
+class QA_Merger(OperatorABC):
+    def __init__(self, output_dir, strict_title_match=False):
+        self.output_dir = output_dir
+        self.strict_title_match = strict_title_match
+
+    @staticmethod
+    def get_desc(lang: str = "zh") -> str:
+        if lang == 'zh':
+            return (
+                "QA对合并算子。"
+                "将问题和答案的QA列表进行合并,生成最终的QA对文件,"
+                "并转换为Markdown格式。"
+            )
+        else:
+            return (
+                "QA pair merging operator."
+                "Merges question and answer QA lists to generate final QA pair files,"
+                "and converts them to Markdown format."
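+                "Each merged QA pair is also exploded into its own row of the output DataFrame."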
+            )
+
+    def run(self, storage: DataFlowStorage,
+            input_question_qalist_path_key,
+            input_answer_qalist_path_key,
+            input_name_key,
+            output_merged_qalist_path_key,
+            output_merged_md_path_key,
+            output_qa_item_key="qa_item"  # new: column name for the exploded QA content
+            ):
+        dataframe = storage.read("dataframe")
+
+        # initialize the column as object dtype so it can hold list values
+        dataframe[output_qa_item_key] = None
+        dataframe[output_qa_item_key] = dataframe[output_qa_item_key].astype(object)
+
+        for idx, row in dataframe.iterrows():
+            question_qalist_path = row[input_question_qalist_path_key]
+            answer_qalist_path = row[input_answer_qalist_path_key]
+            name = row[input_name_key]
+
+            output_merged_qalist_path = os.path.join(self.output_dir, name, "merged_qa_pairs.jsonl")
+            merge_qa_pair(question_qalist_path, answer_qalist_path, output_merged_qalist_path, strict_title_match=self.strict_title_match)
+
+            output_merged_md_path = os.path.join(self.output_dir, name, "merged_qa_pairs.md")
+            jsonl_to_md(output_merged_qalist_path, output_merged_md_path)
+
+            qa_pairs = []
+            if os.path.exists(output_merged_qalist_path):
+                with open(output_merged_qalist_path, 'r', encoding='utf-8') as f:
+                    qa_pairs = [json.loads(line) for line in f]
+
+            dataframe.at[idx, output_qa_item_key] = qa_pairs
+
+            dataframe.loc[idx, output_merged_qalist_path_key] = output_merged_qalist_path
+            dataframe.loc[idx, output_merged_md_path_key] = output_merged_md_path
+
+        dataframe = dataframe.explode(output_qa_item_key).reset_index(drop=True)
+
+        storage.write(dataframe)
\ No newline at end of file
diff --git a/dataflow/operators/pdf2vqa/generate/vqa_extractor.py b/dataflow/operators/pdf2vqa/generate/vqa_extractor.py
deleted file mode 100644
index fcfd8fdf..00000000
--- a/dataflow/operators/pdf2vqa/generate/vqa_extractor.py
+++ /dev/null
@@ -1,530 +0,0 @@
-import os
-import json
-import re
-import pandas as pd
-import tiktoken
-import shutil
-import torch
-from pathlib import Path
-from typing import Literal
-from dataflow.core import OperatorABC
-from dataflow.utils.registry import OPERATOR_REGISTRY
-from dataflow.utils.storage import DataFlowStorage
-from dataflow import get_logger
-from dataflow.core import LLMServingABC
-from dataflow.prompts.pdf2vqa import QAExtractPrompt
-from dataflow.core.prompt import prompt_restrict
-from dataflow.utils.pdf2vqa.format_utils import merge_qa_pair, jsonl_to_md
-
-@prompt_restrict(QAExtractPrompt)
-@OPERATOR_REGISTRY.register()
-class VQAExtractor(OperatorABC):
-    def __init__(self,
-                 llm_serving: LLMServingABC = None,
-                 mineru_backend: Literal["vlm-transformers","vlm-vllm-engine"] = "vlm-transformers",
-                 max_chunk_len: int = 128000,):
-        self.logger = get_logger()
-        self.llm_serving = llm_serving
-        self.prompt_template = QAExtractPrompt()
-        self.mineru_backend = mineru_backend
-        self.max_chunk_len = max_chunk_len
-
-    @staticmethod
-    def get_desc(lang: str = "zh"):
-        if lang == "zh":
-            return (
-                "该算子用于从试题或图文PDF文档中自动提取问答(VQA)结构化数据。\n\n"
-                "功能说明:\n"
-                "- 自动调用 MinerU 模型提取 PDF 文档的版面与内容布局信息。\n"
-                "- 支持题目与答案的分离提取或交错(interleaved)模式处理。\n"
-                "- 基于 LLM 生成章节结构化问答( 标签格式)。\n"
-                "- 自动进行内容清洗、图片路径替换与问答重建。\n"
-                "- 支持结果过滤、合并及 Markdown 文档转换。\n\n"
-                "输入要求:\n"
-                "- DataFrame 中需包含 PDF 文件路径列,可为 question_pdf_path/answer_pdf_path 或 pdf_path。\n\n"
-                "初始化参数:\n"
-                "- llm_serving: LLM 推理服务实例,用于生成问答。\n"
-                "- mineru_backend: MinerU 后端类型(可选值:\"vlm-transformers\" 或 \"vlm-vllm-engine\")。\n"
-                "- max_chunk_len: 单批次最大token数量(默认128000)。\n\n"
-                "运行参数(run):\n"
-                "- input_question_pdf_path_key: 题目PDF路径列名(默认:\"question_pdf_path\")。\n"
-                "- input_answer_pdf_path_key: 
答案PDF路径列名(默认:\"answer_pdf_path\")。\n" - "- input_pdf_path_key: 交错模式下的PDF路径列名(默认:\"pdf_path\")。\n" - "- input_subject_key: 学科类别列名(默认:\"subject\")。\n" - "- output_dir_key: 输出目录列名(默认:\"output_dir\")。\n" - "- output_jsonl_key: 输出JSONL路径列名(默认:\"output_jsonl_path\")。\n" - "- output_default_dir: 默认输出目录(默认:\"../vqa_output\")。\n\n" - "输出:\n" - "- 在 DataFrame 中新增一列,记录生成的VQA结构化问答JSONL文件路径。\n" - "- 同时生成过滤后的Markdown文档和对应图片资源文件夹。" - ) - elif lang == "en": - return ( - "This operator extracts structured Visual Question Answering (VQA) data from exam or multimodal PDF documents.\n\n" - "Functionality:\n" - "- Automatically uses MinerU models to extract PDF layout and textual content.\n" - "- Supports both separate (question/answer) and interleaved PDF processing modes.\n" - "- Generates structured chapter-based QA pairs using an LLM (, tags).\n" - "- Cleans and reconstructs QA content with proper image references.\n" - "- Filters, merges, and converts the output into Markdown format.\n\n" - "Input Requirements:\n" - "- The input DataFrame must contain PDF path columns: either question_pdf_path/answer_pdf_path or pdf_path.\n\n" - "Initialization Parameters:\n" - "- llm_serving: Instance of LLM inference service used for QA generation.\n" - "- mineru_backend: Backend type for MinerU ('vlm-transformers' or 'vlm-vllm-engine').\n" - "- max_chunk_len: Maximum number of tokens per batch (default: 128000).\n\n" - "Run Parameters:\n" - "- input_question_pdf_path_key: Column name for question PDF path (default: 'question_pdf_path').\n" - "- input_answer_pdf_path_key: Column name for answer PDF path (default: 'answer_pdf_path').\n" - "- input_pdf_path_key: Column name for interleaved PDF path (default: 'pdf_path').\n" - "- input_subject_key: Column name for subject type (default: 'subject').\n" - "- output_dir_key: Column name for output directory (default: 'output_dir').\n" - "- output_jsonl_key: Column name for output JSONL file path (default: 'output_jsonl_path').\n" - "- output_default_dir: Default output directory (default: '../vqa_output').\n\n" - "Output:\n" - "- Adds a new column to the DataFrame containing paths to generated structured VQA JSONL files.\n" - "- Also produces filtered Markdown documents and associated image folders." - ) - else: - return "VQAExtractor extracts structured VQA data from PDF documents and outputs filtered JSONL and Markdown files." 
- - - def _convert_json(self, input_file, output_file): - with open(input_file, 'r') as infile: - data = list(json.load(infile)) - - new_data = [] - id = 0 - for item in data: - item['id'] = id - item.pop('bbox', None) - item.pop('page_idx', None) - if item.get('type','') == 'list': - if item['sub_type'] == 'text': - for idx, list_item in enumerate(item.get('list_items', [])): - new_item = { - 'type': 'text', - 'text': list_item, - 'id': id + idx, - } - new_data.append(new_item) - id += len(item.get('list_items', [])) - else: - new_data.append(item) - id += 1 - - with open(output_file, 'w') as outfile: - json.dump(new_data, outfile, ensure_ascii=False) - - def _count_tokens(self, text: str) -> int: - enc = tiktoken.get_encoding("cl100k_base") - return len(enc.encode(text)) - - def _id_to_text(self, input_ids, input_json, image_prefix="images"): - texts = [] - id_list = input_ids.replace(' ', '').split(',') - for id in id_list: - try: - int(id) - except: - continue - if int(id) < len(input_json): - try: - item = input_json[int(id)] - except: - continue - if 'text' in item: - texts.append(item['text']) - elif 'img_path' in item: - try: - img_path = item.get('img_path', '') - img_name = os.path.basename(img_path) - new_path = f"{image_prefix}/{img_name}" - texts.append(f"![{' '.join(item.get('image_caption','image'))}]({new_path})") - except: - pass - elif item.get('type','') == 'list': - if item['sub_type'] == 'text': - try: - texts.append(input_json[int(id)]['list_items'].pop(0)) - except: - pass - return '\n'.join(texts) - - def _extract_doc_layout(self, input_pdf_file_path: str, output_folder: str, mineru_backend: Literal["vlm-transformers","vlm-vllm-engine"] = "vlm-transformers"): - """提取 PDF 的布局信息(合并自 VQAExtractDocLayoutMinerU)""" - try: - import mineru - from mineru.cli.client import main as mineru_main - except ImportError: - raise Exception( - """ - MinerU is not installed in this environment yet. - Please refer to https://github.com/opendatalab/mineru to install. - Or you can just execute 'pip install mineru[pipeline]' and 'mineru-models-download' to fix this error. - Please make sure you have GPU on your machine. - """ - ) - try: - from pypdf import PdfReader, PdfWriter, PageObject - except ImportError: - raise Exception( - """ - pypdf is not installed in this environment yet. - Please use pip install pypdf. - """ - ) - try: - from reportlab.pdfgen import canvas - except ImportError: - raise Exception( - """ - reportlab is not installed in this environment yet. - Please use pip install reportlab. - """ - ) - - os.environ['MINERU_MODEL_SOURCE'] = "local" - - MinerU_Version = {"pipeline": "auto", "vlm-transformers": "vlm", "vlm-vllm-engine": "vlm"} - - if mineru_backend == "pipeline": - raise ValueError("The 'pipeline' backend is not supported due to its incompatible output format. Please use 'vlm-transformers' or 'vlm-vllm-engine' instead.") - - raw_file = Path(input_pdf_file_path) - pdf_name = raw_file.stem - intermediate_dir = output_folder - args = [ - "-p", str(raw_file), - "-o", str(intermediate_dir), - "-b", mineru_backend, - "--source", "local" - ] - if mineru_backend == "vlm-vllm-engine": - assert torch.cuda.is_available(), "MinerU vlm-vllm-engine backend requires GPU support." 
- args += ["--tensor-parallel-size", "2" if torch.cuda.device_count() >= 2 else "1"] - - try: - mineru_main(args) - except SystemExit as e: - if e.code != 0: - raise RuntimeError(f"MinerU execution failed with exit code: {e.code}") - - output_json_file = os.path.join(intermediate_dir, pdf_name, MinerU_Version[mineru_backend], f"{pdf_name}_content_list.json") - output_layout_file = os.path.join(intermediate_dir, pdf_name, MinerU_Version[mineru_backend], f"{pdf_name}_layout.pdf") - return output_json_file, output_layout_file - - def _convert_response(self, input_response, input_json_path, image_prefix="images"): - qa_list = [] - with open(input_json_path, 'r') as infile: - input_json = list(json.load(infile)) - # 提取title - for chapter_block in re.findall(r'(.*?)', input_response, flags=re.DOTALL): - title = re.search(r'(.*?)', chapter_block, flags=re.DOTALL) - if title: - chapter_title = self._id_to_text(title.group(1).strip(), input_json, image_prefix) - else: - chapter_title = "" - # 找出所有 qa_pair 块 - for pair in re.findall(r'(.*?)', chapter_block, flags=re.DOTALL): - # 提取 question 部分 - q_match = re.search(r'(.*?)', pair, flags=re.DOTALL) - # 提取 answer 部分 - a_match = re.search(r'(.*?)', pair, flags=re.DOTALL) - # 提取solution部分 - s_match = re.search(r'(.*?)', pair, flags=re.DOTALL) - # 提取label - label_match = re.search(r'', pair, flags=re.DOTALL) - if not ((q_match and label_match) or (a_match and label_match) or (s_match and label_match)): - continue - label = label_match.group(1).strip() - qa_list.append({ - 'question': self._id_to_text(q_match.group(1).strip(), input_json, image_prefix) if q_match else "", - 'answer': a_match.group(1).strip() if a_match else "", - 'solution': self._id_to_text(s_match.group(1).strip(), input_json, image_prefix) if s_match else "", - 'label': label, - 'chapter_title': chapter_title - }) - return qa_list - - def run(self, storage: DataFlowStorage, - input_question_pdf_path_key: str = "question_pdf_path", - input_answer_pdf_path_key: str = "answer_pdf_path", - input_pdf_path_key: str = "pdf_path", # 支持 interleaved 模式的单一 pdf_path - input_subject_key: str = "subject", - output_dir_key: str = "output_dir", - output_jsonl_key: str = "output_jsonl_path", - output_default_dir: str = "../vqa_output") -> list: - dataframe = storage.read("dataframe") - - # 支持两种输入格式:question_pdf_path/answer_pdf_path 或 pdf_path - if input_question_pdf_path_key not in dataframe.columns and input_pdf_path_key not in dataframe.columns: - raise ValueError(f"Column '{input_question_pdf_path_key}' or '{input_pdf_path_key}' not found in dataframe") - - # ========== Stage 1: 预处理(任务扩展 + Layout 提取) ========== - expanded_rows = [] - - for idx, row in dataframe.iterrows(): - # 优先使用 question_pdf_path,如果没有则使用 pdf_path(interleaved 模式) - if input_question_pdf_path_key in dataframe.columns: - question_pdf_path = row[input_question_pdf_path_key] - answer_pdf_path = row.get(input_answer_pdf_path_key, question_pdf_path) - else: - # interleaved 模式:使用同一个 pdf_path - question_pdf_path = row[input_pdf_path_key] - answer_pdf_path = question_pdf_path - - subject = row.get(input_subject_key, "math") - output_root = row.get(output_dir_key, output_default_dir) - interleaved = (question_pdf_path == answer_pdf_path) - - os.makedirs(output_root, exist_ok=True) - - # Question task - q_outdir = os.path.join(output_root, "question") - os.makedirs(q_outdir, exist_ok=True) - - # Layout 提取 - q_json_path, _ = self._extract_doc_layout( - input_pdf_file_path=question_pdf_path, - output_folder=q_outdir, - 
mineru_backend=self.mineru_backend - ) - - expanded_rows.append({ - "pdf_path": question_pdf_path, - "mode": "question", - "interleaved": interleaved, - "subject": subject, - "output_dir": q_outdir, - "output_root": output_root, - "json_path": q_json_path - }) - - # Answer task (if not interleaved) - if not interleaved: - a_outdir = os.path.join(output_root, "answer") - os.makedirs(a_outdir, exist_ok=True) - - # Layout 提取 - a_json_path, _ = self._extract_doc_layout( - input_pdf_file_path=answer_pdf_path, - output_folder=a_outdir, - mineru_backend=self.mineru_backend - ) - - expanded_rows.append({ - "pdf_path": answer_pdf_path, - "mode": "answer", - "interleaved": interleaved, - "subject": subject, - "output_dir": a_outdir, - "output_root": output_root, - "json_path": a_json_path - }) - - # ========== Stage 2: QA 提取 ========== - json_paths = [row["json_path"] for row in expanded_rows] - subjects = [row["subject"] for row in expanded_rows] - - user_inputs = [] - split_metadata = [] - - for idx, input_json_path in enumerate(json_paths): - subject = subjects[idx] if idx < len(subjects) else subjects[0] if subjects else "math" - system_prompt = self.prompt_template.build_prompt(subject) - system_prompt_len = self._count_tokens(system_prompt) - - converted_path = input_json_path.replace('.json', '_converted.json') - self._convert_json(input_json_path, converted_path) - - with open(converted_path, 'r') as infile: - data = json.load(infile) - assert isinstance(data, list), f"Expected list, got {type(data)} for {input_json_path}" - - # 分段处理 - current_chunk, current_len = [], system_prompt_len - chunks = [] - - for item in data: - text = json.dumps(item, ensure_ascii=False) - item_len = self._count_tokens(text) - if current_len + item_len > self.max_chunk_len and current_chunk: - chunks.append(current_chunk) - current_chunk, current_len = [], system_prompt_len - current_chunk.append(item) - current_len += item_len - - if current_chunk: - chunks.append(current_chunk) - - split_metadata.append(len(chunks)) - - for chunk in chunks: - user_inputs.append({ - 'user_input': json.dumps(chunk, ensure_ascii=False), - 'system_prompt': system_prompt - }) - - # 批量生成 - responses = [None] * len(user_inputs) - current_batch = [] - current_batch_indices = [] - current_system_prompt = None - - for idx, item in enumerate(user_inputs): - user_input = item['user_input'] - system_prompt = item['system_prompt'] - - if current_system_prompt is None: - current_system_prompt = system_prompt - current_batch = [user_input] - current_batch_indices = [idx] - elif system_prompt == current_system_prompt: - current_batch.append(user_input) - current_batch_indices.append(idx) - else: - # 处理当前批次 - batch_responses = self.llm_serving.generate_from_input(user_inputs=current_batch, system_prompt=current_system_prompt) - for batch_idx, resp in zip(current_batch_indices, batch_responses): - responses[batch_idx] = resp - # 开始新批次 - current_system_prompt = system_prompt - current_batch = [user_input] - current_batch_indices = [idx] - - # 处理最后一批 - if current_batch: - batch_responses = self.llm_serving.generate_from_input(user_inputs=current_batch, system_prompt=current_system_prompt) - for batch_idx, resp in zip(current_batch_indices, batch_responses): - responses[batch_idx] = resp - - # 按 split_metadata 还原 - recombined_responses = [] - idx = 0 - for num_chunks in split_metadata: - merged_text = "\n".join(responses[idx: idx + num_chunks]) - recombined_responses.append(merged_text) - idx += num_chunks - - # ========== Stage 3: 后处理(Response 转换 + 
合并和过滤) ========== - # Response 转换 - qa_lists = [] - for idx, (response, row) in enumerate(zip(recombined_responses, expanded_rows)): - json_path = row["json_path"] - output_dir = row["output_dir"] - mode = row["mode"] - output_root = row["output_root"] - - image_prefix = f"{mode}_images" - converted_json_path = json_path.replace('.json', '_converted.json') - qa_list = self._convert_response(response, converted_json_path, image_prefix) - - # 复制图片 - src_dir = os.path.join(output_dir, Path(json_path).stem).replace('_content_list','') - src_images = os.path.join(src_dir, 'vlm', 'images') - dst_images = os.path.join(output_root, image_prefix) - - try: - if os.path.exists(src_images): - if os.path.exists(dst_images): - shutil.rmtree(dst_images) - shutil.copytree(src_images, dst_images) - else: - self.logger.warning(f"Source images dir does not exist: {src_images}") - except Exception as e: - self.logger.warning(f"Failed to copy images from {src_images} to {dst_images}: {e}") - - qa_lists.append(qa_list) - - # 按 output_root 分组处理合并和过滤 - output_groups = {} - for idx, (qa_list, row) in enumerate(zip(qa_lists, expanded_rows)): - output_root = row["output_root"] - mode = row["mode"] - interleaved = row["interleaved"] - output_dir = row["output_dir"] - - if output_root not in output_groups: - output_groups[output_root] = { - "question": None, - "answer": None, - "interleaved": interleaved - } - - if mode == "question": - output_groups[output_root]["question"] = (qa_list, output_dir) - elif mode == "answer": - output_groups[output_root]["answer"] = (qa_list, output_dir) - - # 处理每个 output_root - result_paths_dict = {} - for output_root, group_info in output_groups.items(): - q_qa_list, q_output_dir = group_info["question"] if group_info["question"] else (None, None) - a_qa_list, a_output_dir = group_info["answer"] if group_info["answer"] else (None, None) - interleaved = group_info["interleaved"] - - # 写入 question jsonl - q_jsonl_path = os.path.join(output_root, "vqa_extracted_questions.jsonl") - if q_qa_list: - with open(q_jsonl_path, 'w', encoding='utf-8') as f: - for item in q_qa_list: - f.write(json.dumps(item, ensure_ascii=False) + '\n') - - # 写入 answer jsonl(如果不是 interleaved) - a_jsonl_path = None - if not interleaved and a_qa_list: - a_jsonl_path = os.path.join(output_root, "vqa_extracted_answers.jsonl") - with open(a_jsonl_path, 'w', encoding='utf-8') as f: - for item in a_qa_list: - f.write(json.dumps(item, ensure_ascii=False) + '\n') - - # 合并 - merged_jsonl = os.path.join(output_root, "vqa_merged_qa_pairs.jsonl") - if not interleaved and a_jsonl_path: - merge_qa_pair(q_jsonl_path, a_jsonl_path, merged_jsonl) - else: - os.system(f"cp {q_jsonl_path} {merged_jsonl}") - - # 过滤 - filtered_items = [] - total_count = 0 - with open(merged_jsonl, 'r', encoding='utf-8') as f: - for line in f: - total_count += 1 - item = json.loads(line) - if item.get('question','').strip() and (item.get('answer','').strip() or item.get('solution','').strip()): - filtered_items.append(item) - - self.logger.info(f"Before filter: {total_count}, After filter: {len(filtered_items)}") - - filtered_jsonl = os.path.join(output_root, "vqa_filtered_qa_pairs.jsonl") - with open(filtered_jsonl, 'w', encoding='utf-8') as f: - for item in filtered_items: - f.write(json.dumps(item, ensure_ascii=False) + '\n') - - # 转换为 markdown - md_output = os.path.join(output_root, "vqa_filtered_qa_pairs.md") - jsonl_to_md(filtered_jsonl, md_output) - - result_paths_dict[output_root] = filtered_jsonl - - # 为原始 dataframe 的每一行分配结果路径 - result_paths = [] 
-        for idx, row in dataframe.iterrows():
-            if input_question_pdf_path_key in dataframe.columns:
-                question_pdf_path = row[input_question_pdf_path_key]
-                answer_pdf_path = row.get(input_answer_pdf_path_key, question_pdf_path)
-            else:
-                question_pdf_path = row[input_pdf_path_key]
-                answer_pdf_path = question_pdf_path
-
-            output_root = row.get(output_dir_key, output_default_dir)
-            result_paths.append(result_paths_dict.get(output_root))
-
-        dataframe[output_jsonl_key] = result_paths
-        output_file = storage.write(dataframe)
-        self.logger.info(f"VQA extraction complete. Results saved to {output_file}")
-
-        return [output_jsonl_key,]
-
diff --git a/dataflow/prompts/pdf2vqa.py b/dataflow/prompts/pdf2vqa.py
index 42a6c3a3..3609a029 100644
--- a/dataflow/prompts/pdf2vqa.py
+++ b/dataflow/prompts/pdf2vqa.py
@@ -92,9 +92,9 @@ class QAExtractPrompt(PromptABC):
     def __init__(self):
         pass
 
-    def build_prompt(self, subject: str = "math") -> str:
+    def build_prompt(self) -> str:
         PROMPT = f"""
-        You are an expert in {subject}. You are given a json file. Your task is to segment the content, insert images tags, and extract labels:
+        You are an expert in answering college-level questions. You are given a json file. Your task is to segment the content, insert image tags, and extract labels:
         1. Every json item has an "id" field. Your main task is to output this field.
        2. You need to segment the content into multiple `<qa_pair>`…`</qa_pair>` blocks, each containing a question and its corresponding answer with solution.
         3. If the problem or answer/solution is not complete, omit them. An answer/solution should be considered complete as long as either the answer or solution exists.
@@ -120,7 +120,7 @@ def build_prompt(self, subject: str = "math") -> str:
         - Always enclose qa pairs in a `<chapter><title>MAIN_TITLE_ID</title>`…`</chapter>` block, where MAIN_TITLE_ID is the id of the chapter title or section title.
         - Normally, chapter/section titles appear before the questions/answers in an independent json item.
         - There could be multiple `<chapter>`…`</chapter>` blocks if multiple chapters/sections exist.
-        - **Any title followed by a question/answer whose label/number is not 1, or title with a score, should NOT be extracted.**
+        - **Any title followed by a question/answer whose label/number is not 1, or title with a score such as "一、选择题(每题1分,共10分)", should NOT be extracted.**
         - Do not use nested titles.
         - Leave the title blank if there is no chapter title.
         ** About figures/diagrams **
diff --git a/dataflow/statics/pipelines/api_pipelines/pdf_vqa_extract_pipeline.py b/dataflow/statics/pipelines/api_pipelines/pdf_vqa_extract_pipeline.py
index e186a2e8..929f3bbc 100644
--- a/dataflow/statics/pipelines/api_pipelines/pdf_vqa_extract_pipeline.py
+++ b/dataflow/statics/pipelines/api_pipelines/pdf_vqa_extract_pipeline.py
@@ -1,48 +1,107 @@
-import os
-import sys
+from dataflow.operators.knowledge_cleaning import FileOrURLToMarkdownConverterBatch
 from dataflow.serving import APILLMServing_request
 from dataflow.utils.storage import FileStorage
-from dataflow.operators.pdf2vqa import VQAExtractor
+from dataflow.operators.pdf2vqa import MinerU2LLMInputOperator, LLMOutputParser, QA_Merger
+from dataflow.operators.core_text import ChunkedPromptedGenerator
 
-class VQA_extract_optimized_pipeline:
+from dataflow.pipeline import PipelineABC
+from dataflow.prompts.pdf2vqa import QAExtractPrompt
+
+class PDF_VQA_extract_optimized_pipeline(PipelineABC):
     def __init__(self):
+        super().__init__()
        self.storage = FileStorage(
-            first_entry_file_name="../example_data/PDF2VQAPipeline/vqa_extract_test.jsonl",
+            first_entry_file_name="./example_data/PDF2VQAPipeline/vqa_extract_test.jsonl",
             cache_path="./cache",
             file_name_prefix="vqa",
             cache_type="jsonl",
         )
         self.llm_serving = APILLMServing_request(
-            api_url="https://generativelanguage.googleapis.com/v1beta/openai/chat/completions",
+            api_url="http://123.129.219.111:3000/v1/chat/completions",
             key_name_of_api_key="DF_API_KEY",
             model_name="gemini-2.5-pro",
             max_workers=100,
         )
-        self.vqa_extractor = VQAExtractor(
+        self.vqa_extract_prompt = QAExtractPrompt()
+
+        self.mineru_executor = FileOrURLToMarkdownConverterBatch(intermediate_dir="intermediate", mineru_backend="vlm-vllm-engine")
+        self.input_formatter = MinerU2LLMInputOperator()
+        self.vqa_extractor = ChunkedPromptedGenerator(
             llm_serving=self.llm_serving,
-            mineru_backend='vlm-vllm-engine',
-            max_chunk_len=128000
+            system_prompt=self.vqa_extract_prompt.build_prompt(),
+            max_chunk_len=128000,
         )
-
+        self.llm_output_question_parser = LLMOutputParser(mode="question", output_dir="./cache", intermediate_dir="intermediate")
+        self.llm_output_answer_parser = LLMOutputParser(mode="answer", output_dir="./cache", intermediate_dir="intermediate")
+        self.qa_merger = QA_Merger(output_dir="./cache", strict_title_match=False)
 
     def forward(self):
-        # 单一算子:包含预处理、QA提取、后处理的所有功能
+        # Current flow: MinerU on the question PDF -> MinerU on the answer PDF -> format question text -> format answer text -> question text to LLM -> answer text to LLM -> parse question output -> parse answer output -> merge QA pairs.
+        # QA pairs may come from the same PDF or from different PDFs, and dataflow does not support branching yet, so the question and answer PDFs are each run through the whole chain,
+        # which means a shared PDF is processed twice before the QA pairs are merged at the end.
+        # TODO: rethink this flow in the future to avoid processing the same PDF twice and improve performance.
+
+        self.mineru_executor.run(
+            storage=self.storage.step(),
+            input_key="question_pdf_path",
+            output_key="question_markdown_path",
+        )
+        self.mineru_executor.run(
+            storage=self.storage.step(),
+            input_key="answer_pdf_path",
+            output_key="answer_markdown_path",
+        )
+        self.input_formatter.run(
+            storage=self.storage.step(),
+            input_markdown_path_key="question_markdown_path",
+            output_converted_layout_key="converted_question_layout_path",
+        )
+        self.input_formatter.run(
+            storage=self.storage.step(),
+            input_markdown_path_key="answer_markdown_path",
+            output_converted_layout_key="converted_answer_layout_path",
+        )
         self.vqa_extractor.run(
             storage=self.storage.step(),
-            input_question_pdf_path_key="question_pdf_path",
-            input_answer_pdf_path_key="answer_pdf_path",
-            input_pdf_path_key="pdf_path", # 支持 interleaved 模式
-            input_subject_key="subject",
-            output_dir_key="output_dir",
-            output_jsonl_key="output_jsonl_path",
+            input_path_key="converted_question_layout_path",
+            output_path_key="vqa_extracted_questions_path",
+        )
+        self.vqa_extractor.run(
+            storage=self.storage.step(),
+            input_path_key="converted_answer_layout_path",
+            output_path_key="vqa_extracted_answers_path",
+        )
+        self.llm_output_question_parser.run(
+            storage=self.storage.step(),
+            input_response_path_key="vqa_extracted_questions_path",
+            input_converted_layout_path_key="converted_question_layout_path",
+            input_name_key="name",
+            output_qalist_path_key="extracted_questions_path",
+        )
+        self.llm_output_answer_parser.run(
+            storage=self.storage.step(),
+            input_response_path_key="vqa_extracted_answers_path",
+            input_converted_layout_path_key="converted_answer_layout_path",
+            input_name_key="name",
+            output_qalist_path_key="extracted_answers_path",
+        )
+        self.qa_merger.run(
+            storage=self.storage.step(),
+            input_question_qalist_path_key="extracted_questions_path",
+            input_answer_qalist_path_key="extracted_answers_path",
+            input_name_key="name",
+            output_merged_qalist_path_key="output_merged_qalist_path",
+            output_merged_md_path_key="output_merged_md_path",
+            output_qa_item_key="qa_pair",
         )
 
 if __name__ == "__main__":
-    # jsonl中每一行包含question_pdf_path, answer_pdf_path, subject (math, physics, chemistry, ...), output_dir
+    # Each jsonl line contains question_pdf_path, answer_pdf_path, and name (math1, math2, physics1, chemistry1, ...)
     # 如果question和answer在同一份pdf中,请将question_pdf_path和answer_pdf_path设置为相同的路径,会自动切换为interleaved模式
-    pipeline = VQA_extract_optimized_pipeline()
-    pipeline.forward()
+    pipeline = PDF_VQA_extract_optimized_pipeline()
+    pipeline.compile()
+    pipeline.forward()
\ No newline at end of file
diff --git a/dataflow/utils/pdf2vqa/format_utils.py b/dataflow/utils/pdf2vqa/format_utils.py
index 94294d15..c1b8138e 100644
--- a/dataflow/utils/pdf2vqa/format_utils.py
+++ b/dataflow/utils/pdf2vqa/format_utils.py
@@ -1,11 +1,29 @@
 import json
 import re
 
-def merge_qa_pair(question_jsonl, answer_jsonl, output_jsonl):
+def refine_title(title: str, strict_title_match=False):
+    # TODO: more sophisticated title-cleaning logic may be needed here
+    # strip whitespace and newlines from the title
+    title = re.sub(r'\s+', '', title)
+    if not strict_title_match:
+        try:
+            # prefer Arabic chapter numbers (e.g. 1.1, 2)
+            new_title = re.search(r"\d+\.\d+|\d+", title).group()
+        except AttributeError:
+            try:
+                # fall back to Chinese numeral chapter numbers (e.g. 六, 二十四)
+                new_title = re.search(r'[一二三四五六七八九零十百]+', title).group()
+            except AttributeError:
+                new_title = title
+        title = new_title
+    return title
+
+def merge_qa_pair(question_jsonl, answer_jsonl, output_jsonl, strict_title_match=False):
+    already_complete_count = 0
     with open(question_jsonl, 'r', encoding='utf-8') as q_file, open(answer_jsonl, 'r', encoding='utf-8') as a_file, open(output_jsonl, 'w', encoding='utf-8') as out_file:
         chapter_id = 0
         chapter_title = ""
-        label = 1000000
+        label = float('inf')
         questions = {}
         answers = {}
         for line in q_file:
@@ -29,13 +47,11 @@ def merge_qa_pair(question_jsonl, answer_jsonl, output_jsonl):
             # 如果题号增加,章节标题却发生变化,说明可能错误提取了子标题。因此继续使用之前的章节标题。
                 data["chapter_title"] = chapter_title
             label = data["label"]
-            data["chapter_id"] = chapter_id
-            # TODO : 这里可能需要更复杂的title清洗逻辑
-            # 删除title中的空格与换行符
-            data["chapter_title"] = re.sub(r'\s+', '', data["chapter_title"])
-            if data['label'] > 0 and data["chapter_title"]:
+            data["chapter_title"] = refine_title(data["chapter_title"], strict_title_match)
+            if data['label'] > 0:
                 # 已经完整的题目直接写入out_file
                 if data["answer"] or data["solution"]:
+                    already_complete_count += 1
                     qa_pair = {
                         "question_chapter_title": data["chapter_title"],
                         "answer_chapter_title": data["chapter_title"],
@@ -51,7 +67,7 @@ def merge_qa_pair(question_jsonl, answer_jsonl, output_jsonl):
 
         chapter_id = 0
         chapter_title = ""
-        label = 1000000
+        label = float('inf')
         for line in a_file:
             data = json.loads(line)
             label_match = re.search(r'\d+', data["label"])
@@ -73,10 +89,7 @@ def merge_qa_pair(question_jsonl, answer_jsonl, output_jsonl):
             # 如果题号增加,章节标题却发生变化,说明可能错误提取了子标题。因此继续使用之前的章节标题。
                 data["chapter_title"] = chapter_title
             label = data["label"]
-            data["chapter_id"] = chapter_id
-            # TODO : 这里可能需要更复杂的title清洗逻辑
-            # 删除title中的空格与换行符
-            data["chapter_title"] = re.sub(r'\s+', '', data["chapter_title"])
+            data["chapter_title"] = refine_title(data["chapter_title"], strict_title_match)
             # 动态更新,防止错误的重复label覆盖掉之前的solution或answer
             if data['label'] > 0:
                 if not answers.get((data["chapter_title"], data['label'])):
@@ -99,7 +112,7 @@ def merge_qa_pair(question_jsonl, answer_jsonl, output_jsonl):
             }
             out_file.write(json.dumps(qa_pair, ensure_ascii=False) + '\n')
 
-    print(f"Merged QA pairs: {len(questions.keys() & answers.keys())}")
+    print(f"Merged QA pairs: {len(questions.keys() & answers.keys()) + already_complete_count}")
 
 def jsonl_to_md(jsonl_file, md_file):
     with open(jsonl_file, 'r', encoding='utf-8') as in_file, open(md_file, 'w', encoding='utf-8') as out_file: