PDF2VQA refactor #443
Merged
5 commits
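b8e5d21
[pdf2vqa] When no questions are recognized, an empty file is now written instead of raising an error. Also improved the chapter-matching logic for QA pairs
31f5841
[pdf2vqa] If the MinerU result already exists, that step can now be skipped and the LLM run directly. Fixed the example file paths
fatty-belly
13f6814
[PDF2VQA] Major refactor that reuses existing operators
fatty-belly
8427fbd
[pdf2vqa] Moved chunked_prompted_generator into its own file. Added some comments
fatty-belly
49b8a82
[pdf2vqa] One operator per file
fatty-belly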
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,2 +1,2 @@ | ||
| -{"question_pdf_path": "./example_data/PDF2VQAPipeline/questionextract_test.pdf", "answer_pdf_path": "./example_data/PDF2VQAPipeline/questionextract_test.pdf", "subject": "math", "output_dir": "../vqa_output_test/math1"} | ||
| -{"question_pdf_path": "./example_data/PDF2VQAPipeline/math_question.pdf", "answer_pdf_path": "./example_data/PDF2VQAPipeline/math_answer.pdf", "subject": "math", "output_dir": "../vqa_output_test/math2"} | ||
| +{"question_pdf_path": "./example_data/PDF2VQAPipeline/questionextract_test.pdf", "answer_pdf_path": "./example_data/PDF2VQAPipeline/questionextract_test.pdf", "name": "math1"} | ||
| +{"question_pdf_path": "./example_data/PDF2VQAPipeline/math_question.pdf", "answer_pdf_path": "./example_data/PDF2VQAPipeline/math_answer.pdf", "name": "math2"} |
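As a reading aid, here is a minimal sketch of how the new `name` field could select a per-record output location. It mirrors the `os.path.join(self.output_dir, name, f"extracted_{self.mode}s.jsonl")` call in `LLMOutputParser.run` from this PR; the `qa_output_path` helper and the `./output` directory are hypothetical names introduced only for illustration.

```python
import json
import os

# The two records from the updated example file.
JSONL_TEXT = """\
{"question_pdf_path": "./example_data/PDF2VQAPipeline/questionextract_test.pdf", "answer_pdf_path": "./example_data/PDF2VQAPipeline/questionextract_test.pdf", "name": "math1"}
{"question_pdf_path": "./example_data/PDF2VQAPipeline/math_question.pdf", "answer_pdf_path": "./example_data/PDF2VQAPipeline/math_answer.pdf", "name": "math2"}
"""


def qa_output_path(output_dir: str, name: str, mode: str) -> str:
    """Hypothetical helper: build the per-record output file from `name`."""
    return os.path.join(output_dir, name, f"extracted_{mode}s.jsonl")


# Parse one JSON object per non-empty line.
records = [json.loads(line) for line in JSONL_TEXT.splitlines() if line.strip()]

# "./output" is an assumed output directory, not a value taken from the PR.
paths = [qa_output_path("./output", r["name"], "question") for r in records]
for p in paths:
    print(p)
```

With this layout the output directory is configured once on the operator, and each input row only carries a short `name`.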
147 changes: 147 additions & 0 deletions
dataflow/operators/core_text/generate/chunked_prompted_generator.py
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,147 @@ | ||
| import pandas as pd | ||
| import tiktoken | ||
| from dataflow.utils.registry import OPERATOR_REGISTRY | ||
| from dataflow import get_logger | ||
| from pathlib import Path | ||
|
|
||
| from dataflow.utils.storage import DataFlowStorage | ||
| from dataflow.core import OperatorABC | ||
| from dataflow.core import LLMServingABC | ||
|
|
||
| @OPERATOR_REGISTRY.register() | ||
| class ChunkedPromptedGenerator(OperatorABC): | ||
| """ | ||
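| Prompt-based generation operator with automatic chunking of the input. | ||
| - Counts tokens with tiktoken or HuggingFace's AutoTokenizer; | ||
| - If the input exceeds max_chunk_len, splits it by recursive bisection; | ||
| - Reads content from the given input file path and saves the result to the given output file path; | ||
| - The result is a string joined with separator. | ||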
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| llm_serving: LLMServingABC, | ||
| system_prompt: str = "You are a helpful agent.", | ||
| json_schema: dict = None, | ||
| max_chunk_len: int = 128000, | ||
| enc = tiktoken.get_encoding("cl100k_base"), # Any tokenizer that supports len(enc.encode(text)) works, e.g. tiktoken or HuggingFace's AutoTokenizer | ||
| separator: str = "\n", | ||
| ): | ||
| self.logger = get_logger() | ||
| self.llm_serving = llm_serving | ||
| self.system_prompt = system_prompt | ||
| self.json_schema = json_schema | ||
| self.max_chunk_len = max_chunk_len | ||
| self.enc = enc | ||
| self.separator = separator | ||
|
|
||
| @staticmethod | ||
| def get_desc(lang: str = "zh"): | ||
| if lang == "zh": | ||
| return ( | ||
| "基于提示词的生成算子,支持长文本自动分chunk。" | ||
| "采用递归二分方式进行chunk切分,确保每段不超过max_chunk_len tokens。" | ||
| "从给定的输入文件路径读取内容,生成结果保存至指定输出文件路径。" | ||
| "多个生成结果以separator拼接成最终输出字符串。" | ||
| "输入参数:\n" | ||
| "- llm_serving:LLM服务对象,需实现LLMServingABC接口\n" | ||
| "- system_prompt:系统提示词,定义模型行为,默认为'You are a helpful agent.'\n" | ||
| "- max_chunk_len:单个chunk的最大token长度,默认为128000\n" | ||
| "- input_path_key:输入文件路径字段名,默认为'input_path'\n" | ||
| "- output_path_key:输出文件路径字段名,默认为'output_path'\n" | ||
| "- json_schema:可选,生成结果的JSON Schema约束\n" | ||
| "- enc:用于token计算的编码器,需要实现encode方法,默认为tiktoken的cl100k_base编码器,也可以使用HuggingFace 的 AutoTokenizer\n" | ||
| "- separator:chunk结果拼接分隔符,默认为换行符\n" | ||
| ) | ||
| else: | ||
| return ( | ||
| "Prompt-based generator with recursive chunk splitting. " | ||
| "Splits long text inputs into chunks using recursive bisection to ensure each chunk does not exceed max_chunk_len tokens. " | ||
| "Reads content from specified input file paths and saves generated results to designated output file paths. " | ||
| "Multiple generated results are joined as a string using the specified separator.\n" | ||
| "Input Parameters:\n" | ||
| "- llm_serving: LLM serving object implementing LLMServingABC interface\n" | ||
| "- system_prompt: System prompt to define model behavior, default is 'You are a helpful agent.'\n" | ||
| "- max_chunk_len: Maximum token length per chunk, default is 128000\n" | ||
| "- input_path_key: Field name for input file path, default is 'input_path'\n" | ||
| "- output_path_key: Field name for output file path, default is 'output_path'\n" | ||
| "- json_schema: Optional JSON Schema constraint for generated results\n" | ||
| "- enc: Encoder for token counting, default is tiktoken's cl100k_base encoder; can also use HuggingFace's AutoTokenizer\n" | ||
| "- separator: Separator for chunk results, default is newline character\n" | ||
| ) | ||
|
|
||
| # === Token counting === | ||
| def _count_tokens(self, text: str) -> int: | ||
| return len(self.enc.encode(text)) | ||
|
|
||
| # === Recursive bisection into chunks === | ||
| def _split_recursive(self, text: str) -> list[str]: | ||
| """Recursively split the text into chunks of at most max_chunk_len tokens.""" | ||
| token_len = self._count_tokens(text) | ||
| if token_len <= self.max_chunk_len or len(text) <= 1: | ||
| return [text] | ||
| else: | ||
| mid = len(text) // 2 | ||
| left, right = text[:mid], text[mid:] | ||
| return self._split_recursive(left) + self._split_recursive(right) | ||
|
|
||
| def run( | ||
| self, | ||
| storage: DataFlowStorage, | ||
| input_path_key: str = "input_path", | ||
| output_path_key: str = "output_path", | ||
| ): | ||
| self.logger.info("Running ChunkedPromptedGenerator...") | ||
| dataframe = storage.read("dataframe") | ||
| self.logger.info(f"Loaded DataFrame with {len(dataframe)} rows.") | ||
|
|
||
| all_generated_results = [] | ||
|
|
||
| all_llm_inputs = [] | ||
| row_chunk_map = [] # Number of chunks for each row | ||
|
|
||
| # === Collect all chunks first === | ||
| for i, row in dataframe.iterrows(): | ||
| raw_content = Path(row[input_path_key]).read_text(encoding='utf-8') | ||
|
|
||
| chunks = self._split_recursive(raw_content) | ||
| self.logger.info(f"Row {i}: split into {len(chunks)} chunks") | ||
|
|
||
| system_prompt = self.system_prompt + "\n" | ||
| llm_inputs = [system_prompt + chunk for chunk in chunks] | ||
| all_llm_inputs.extend(llm_inputs) | ||
| row_chunk_map.append(len(chunks)) | ||
|
|
||
| # === Single batched call for all chunks === | ||
| self.logger.info(f"Total {len(all_llm_inputs)} chunks to generate") | ||
|
|
||
| try: | ||
| if self.json_schema: | ||
| all_responses = self.llm_serving.generate_from_input( | ||
| all_llm_inputs, json_schema=self.json_schema | ||
| ) | ||
| else: | ||
| all_responses = self.llm_serving.generate_from_input(all_llm_inputs) | ||
| except Exception as e: | ||
| self.logger.error(f"Global generation failed: {e}") | ||
| all_generated_results = [[] for _ in range(len(dataframe))] | ||
| else: | ||
| # === Regroup responses by row === | ||
| all_generated_results = [] | ||
| idx = 0 | ||
| for num_chunks in row_chunk_map: | ||
| if num_chunks == 0: | ||
| all_generated_results.append([]) | ||
| else: | ||
| all_generated_results.append(all_responses[idx:idx + num_chunks]) | ||
| idx += num_chunks | ||
|
|
||
| for (i, row), gen_results in zip(dataframe.iterrows(), all_generated_results): | ||
| output_path = str(Path(row[input_path_key]).with_suffix('')) + '_llm_output.txt' | ||
| with open(output_path, 'w', encoding='utf-8') as f: | ||
| f.write(self.separator.join(gen_results)) | ||
| dataframe.at[i, output_path_key] = output_path | ||
|
|
||
| output_file = storage.write(dataframe) | ||
| self.logger.info(f"Generation complete. Output saved to {output_file}") | ||
| return output_path_key |
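The two core steps of `ChunkedPromptedGenerator` (recursive bisection, then regrouping the flat response list by row) can be sketched in isolation. This is a minimal illustration rather than the operator itself: it assumes one token per character in place of `len(enc.encode(text))`, and an upper-casing stand-in replaces the `generate_from_input` call. The function names `split_recursive` and `regroup` are hypothetical.

```python
from typing import Callable, List


def split_recursive(text: str, max_len: int, count: Callable[[str], int]) -> List[str]:
    """Bisect `text` by characters until every piece has at most `max_len` tokens."""
    # The len(text) <= 1 guard stops the recursion if a single character
    # still exceeds the limit.
    if count(text) <= max_len or len(text) <= 1:
        return [text]
    mid = len(text) // 2
    return split_recursive(text[:mid], max_len, count) + split_recursive(text[mid:], max_len, count)


def regroup(responses: List[str], chunks_per_row: List[int]) -> List[List[str]]:
    """Slice the flat response list back into one list per row."""
    grouped, start = [], 0
    for n in chunks_per_row:
        grouped.append(responses[start:start + n])
        start += n
    return grouped


# Assumption: one token per character, standing in for a real tokenizer.
count_chars = len

docs = ["abcdefgh", "xy"]
chunks_per_doc = [split_recursive(d, 3, count_chars) for d in docs]

# Flatten so that all chunks could be sent in one batched call.
flat = [c for chunks in chunks_per_doc for c in chunks]

# Stand-in for the LLM: upper-case each chunk.
responses = [c.upper() for c in flat]

grouped = regroup(responses, [len(c) for c in chunks_per_doc])
print(grouped)
```

Because the split point is the character midpoint, a chunk boundary can fall inside a word or a structured record; whether that matters depends on the prompt and the input format.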
141 changes: 141 additions & 0 deletions
dataflow/operators/pdf2vqa/generate/llm_output_parser.py
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,141 @@ | ||
| import os | ||
| import json | ||
| import re | ||
| import shutil | ||
| from pathlib import Path | ||
| from typing import Literal | ||
| from dataflow.core import OperatorABC | ||
| from dataflow.utils.registry import OPERATOR_REGISTRY | ||
| from dataflow.utils.storage import DataFlowStorage | ||
| from dataflow import get_logger | ||
|
|
||
| @OPERATOR_REGISTRY.register() | ||
| class LLMOutputParser(OperatorABC): | ||
| def __init__(self, | ||
| mode: Literal['question', 'answer'], | ||
| output_dir, | ||
| intermediate_dir: str = "intermediate", | ||
| ): | ||
| self.logger = get_logger() | ||
| self.mode = mode | ||
| self.output_dir = output_dir | ||
| self.intermediate_dir = intermediate_dir | ||
|
|
||
| @staticmethod | ||
| def get_desc(lang: str = "zh") -> str: | ||
| if lang == 'zh': | ||
| return ( | ||
| "LLM输出解析算子。" | ||
| "将LLM生成的包含题目和答案ID的响应文本," | ||
| "转换为结构化的QA列表,并复制相关图片到输出目录。" | ||
| ) | ||
| else: | ||
| return ( | ||
| "LLM output parsing operator. " | ||
| "Converts LLM-generated response text containing question and answer IDs " | ||
| "into a structured QA list and copies related images to the output directory." | ||
| ) | ||
|
|
||
| def _id_to_text(self, input_ids, input_json, image_prefix="images"): | ||
| texts = [] | ||
| id_list = input_ids.replace(' ', '').split(',') | ||
| for id in id_list: | ||
| try: | ||
| int(id) | ||
| except ValueError: | ||
| continue | ||
| if 0 <= int(id) < len(input_json): | ||
| try: | ||
| item = input_json[int(id)] | ||
| except: | ||
| continue | ||
| if 'text' in item: | ||
| texts.append(item['text']) | ||
| elif 'img_path' in item: | ||
| try: | ||
| img_path = item.get('img_path', '') | ||
| img_name = os.path.basename(img_path) | ||
| new_path = f"{image_prefix}/{img_name}" | ||
| texts.append(f"![]({new_path})") | ||
| except: | ||
| pass | ||
| elif item.get('type','') == 'list': | ||
| if item['sub_type'] == 'text': | ||
| try: | ||
| texts.append(input_json[int(id)]['list_items'].pop(0)) | ||
| except: | ||
| pass | ||
| return '\n'.join(texts) | ||
|
|
||
| def _convert_response(self, input_response, input_json_path, image_prefix="images"): | ||
| qa_list = [] | ||
| with open(input_json_path, 'r') as infile: | ||
| input_json = list(json.load(infile)) | ||
| # Extract the chapter title | ||
| for chapter_block in re.findall(r'<chapter>(.*?)</chapter>', input_response, flags=re.DOTALL): | ||
| title = re.search(r'<title>(.*?)</title>', chapter_block, flags=re.DOTALL) | ||
| if title: | ||
| chapter_title = self._id_to_text(title.group(1).strip(), input_json, image_prefix) | ||
| else: | ||
| chapter_title = "" | ||
| # Find all qa_pair blocks | ||
| for pair in re.findall(r'<qa_pair>(.*?)</qa_pair>', chapter_block, flags=re.DOTALL): | ||
| # Extract the question part | ||
| q_match = re.search(r'<question>(.*?)</question>', pair, flags=re.DOTALL) | ||
| # Extract the answer part | ||
| a_match = re.search(r'<answer>(.*?)</answer>', pair, flags=re.DOTALL) | ||
| # Extract the solution part | ||
| s_match = re.search(r'<solution>(.*?)</solution>', pair, flags=re.DOTALL) | ||
| # Extract the label | ||
| label_match = re.search(r'<label>(.*?)</label>', pair, flags=re.DOTALL) | ||
| if not (label_match and (q_match or a_match or s_match)): | ||
| continue | ||
| label = label_match.group(1).strip() | ||
| qa_list.append({ | ||
| 'question': self._id_to_text(q_match.group(1).strip(), input_json, image_prefix) if q_match else "", | ||
| 'answer': a_match.group(1).strip() if a_match else "", | ||
| 'solution': self._id_to_text(s_match.group(1).strip(), input_json, image_prefix) if s_match else "", | ||
| 'label': label, | ||
| 'chapter_title': chapter_title | ||
| }) | ||
| return qa_list | ||
|
|
||
| def run(self, storage: DataFlowStorage, | ||
| input_response_path_key, | ||
| input_converted_layout_path_key, | ||
| input_name_key, | ||
| output_qalist_path_key, | ||
| ): | ||
| dataframe = storage.read("dataframe") | ||
|
|
||
| # Convert each response | ||
| for idx, row in dataframe.iterrows(): | ||
| converted_json_path = row[input_converted_layout_path_key] | ||
| response = Path(row[input_response_path_key]).read_text(encoding='utf-8') | ||
| name = row[input_name_key] | ||
|
|
||
| image_prefix = os.path.join(name, f"{self.mode}_images") | ||
| qa_list = self._convert_response(response, converted_json_path, image_prefix) | ||
| output_qalist_path = os.path.join(self.output_dir, name, f"extracted_{self.mode}s.jsonl") | ||
| os.makedirs(os.path.dirname(output_qalist_path), exist_ok=True) | ||
| with open(output_qalist_path, 'w', encoding='utf-8') as outfile: | ||
| for qa in qa_list: | ||
| json.dump(qa, outfile, ensure_ascii=False) | ||
| outfile.write('\n') | ||
|
|
||
| # Copy images | ||
| src_dir = os.path.join(self.intermediate_dir, 'mineru', Path(converted_json_path).stem).replace('_content_list_converted','') | ||
| src_images = os.path.join(src_dir, 'vlm', 'images') | ||
| dst_images = os.path.join(self.output_dir, image_prefix) | ||
|
|
||
| try: | ||
| if os.path.exists(src_images): | ||
| shutil.copytree(src_images, dst_images, dirs_exist_ok=True) | ||
| else: | ||
| self.logger.warning(f"Source images dir does not exist: {src_images}") | ||
| except Exception as e: | ||
| self.logger.warning(f"Failed to copy images from {src_images} to {dst_images}: {e}") | ||
|
|
||
| dataframe.loc[idx, output_qalist_path_key] = output_qalist_path | ||
|
|
||
| storage.write(dataframe) |
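To make the expected response layout concrete, here is a minimal, self-contained sketch of the tag parsing done in `_convert_response`. The tag names come from the regular expressions in the diff; the sample response, the `blocks` list, and the `parse_response` name are made up for illustration, and image handling is left out.

```python
import re


def parse_response(response: str, blocks: list) -> list:
    """Turn a tagged LLM response into QA dicts, resolving block IDs to text."""

    def ids_to_text(ids: str) -> str:
        texts = []
        for raw in ids.replace(" ", "").split(","):
            try:
                idx = int(raw)
            except ValueError:
                continue  # skip anything that is not an integer ID
            if 0 <= idx < len(blocks):
                texts.append(blocks[idx].get("text", ""))
        return "\n".join(texts)

    def tag(name: str, text: str):
        return re.search(rf"<{name}>(.*?)</{name}>", text, flags=re.DOTALL)

    qa_list = []
    for chapter in re.findall(r"<chapter>(.*?)</chapter>", response, flags=re.DOTALL):
        title = tag("title", chapter)
        chapter_title = ids_to_text(title.group(1).strip()) if title else ""
        for pair in re.findall(r"<qa_pair>(.*?)</qa_pair>", chapter, flags=re.DOTALL):
            label, q, a, s = (tag(n, pair) for n in ("label", "question", "answer", "solution"))
            # A pair needs a label plus at least one of question/answer/solution.
            if not (label and (q or a or s)):
                continue
            qa_list.append({
                "question": ids_to_text(q.group(1).strip()) if q else "",
                "answer": a.group(1).strip() if a else "",
                "solution": ids_to_text(s.group(1).strip()) if s else "",
                "label": label.group(1).strip(),
                "chapter_title": chapter_title,
            })
    return qa_list


# Made-up layout blocks and response, for illustration only.
blocks = [{"text": "Chapter 1"}, {"text": "What is 1 + 1?"}, {"text": "Choose one."}]
response = (
    "<chapter><title>0</title>"
    "<qa_pair><label>1</label><question>1, 2</question><answer>2</answer></qa_pair>"
    "</chapter>"
)
qa = parse_response(response, blocks)
print(qa)
```

Note that `question`, `solution` and `title` carry block IDs that are resolved against the layout, while `answer` is taken verbatim from the response.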
69 changes: 69 additions & 0 deletions
dataflow/operators/pdf2vqa/generate/mineru_to_llm_input_operator.py
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,69 @@ | ||
| import json | ||
| from dataflow.core import OperatorABC | ||
| from dataflow.utils.registry import OPERATOR_REGISTRY | ||
| from dataflow.utils.storage import DataFlowStorage | ||
|
|
||
| @OPERATOR_REGISTRY.register() | ||
| class MinerU2LLMInputOperator(OperatorABC): | ||
| def __init__(self): | ||
| pass | ||
|
|
||
| @staticmethod | ||
| def get_desc(lang: str = "zh") -> str: | ||
| if lang == 'zh': | ||
| return ( | ||
| "MinerU格式转换为LLM输入格式算子。" | ||
| "将MinerU生成的内容列表JSON文件转换为适合LLM处理的格式," | ||
| "包括展平列表项并重新编号。" | ||
| ) | ||
| else: | ||
| return ( | ||
| "Operator that converts MinerU output into the LLM input format. " | ||
| "Transforms the content list JSON file generated by MinerU into a format suitable for LLM processing, " | ||
| "including flattening list items and re-indexing." | ||
| ) | ||
|
|
||
| def _convert_json(self, input_file, output_file): | ||
| with open(input_file, 'r') as infile: | ||
| data = list(json.load(infile)) | ||
|
|
||
| new_data = [] | ||
| id = 0 | ||
| for item in data: | ||
| item['id'] = id | ||
| item.pop('bbox', None) | ||
| item.pop('page_idx', None) | ||
| if item.get('type','') == 'list': | ||
| if item['sub_type'] == 'text': | ||
| for idx, list_item in enumerate(item.get('list_items', [])): | ||
| new_item = { | ||
| 'type': 'text', | ||
| 'text': list_item, | ||
| 'id': id + idx, | ||
| } | ||
| new_data.append(new_item) | ||
| id += len(item.get('list_items', [])) | ||
| else: | ||
| new_data.append(item) | ||
| id += 1 | ||
|
|
||
| with open(output_file, 'w') as outfile: | ||
| json.dump(new_data, outfile, ensure_ascii=False) | ||
|
|
||
| def run(self, storage: DataFlowStorage, | ||
| input_markdown_path_key, | ||
| output_converted_layout_key, | ||
| ): | ||
| dataframe = storage.read("dataframe") | ||
|
|
||
| for index, row in dataframe.iterrows(): | ||
| input_json_path = row[input_markdown_path_key].replace('.md', '_content_list.json') | ||
| converted_path = input_json_path.replace('.json', '_converted.json') | ||
| self._convert_json(input_json_path, converted_path) | ||
| dataframe.at[index, output_converted_layout_key] = converted_path | ||
|
|
||
| with open(converted_path, 'r') as infile: | ||
| data = json.load(infile) | ||
| assert isinstance(data, list), f"Expected list, got {type(data)} for {input_json_path}" | ||
|
|
||
| storage.write(dataframe) |
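The flatten-and-re-index step can be illustrated on a tiny made-up content list. This is a sketch under assumptions, not the operator's exact behavior: the diff view does not preserve indentation, so the branch structure of `_convert_json` is ambiguous, and this version simply keeps every non-text-list item and assigns IDs by output position. The function name `flatten_content_list` and the sample items are hypothetical.

```python
def flatten_content_list(items: list) -> list:
    """Expand text lists into one entry per list item and assign contiguous IDs."""
    flat = []
    for item in items:
        if item.get("type") == "list" and item.get("sub_type") == "text":
            # Each list entry becomes its own text block.
            for entry in item.get("list_items", []):
                flat.append({"type": "text", "text": entry, "id": len(flat)})
        else:
            # Drop layout-only fields, keep everything else.
            kept = {k: v for k, v in item.items() if k not in ("bbox", "page_idx")}
            kept["id"] = len(flat)
            flat.append(kept)
    return flat


# Made-up input in the shape the operator reads (keys taken from the diff).
items = [
    {"type": "text", "text": "Title", "bbox": [0, 0, 1, 1], "page_idx": 0},
    {"type": "list", "sub_type": "text", "list_items": ["1. a", "2. b"]},
    {"type": "image", "img_path": "images/x.jpg"},
]
flat = flatten_content_list(items)
for block in flat:
    print(block)
```

Using the output position as the ID keeps IDs equal to list indices, which is what `_id_to_text` in the parser relies on when it looks blocks up with `input_json[int(id)]`.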
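This operator only converts formats. If it is going to exist as an operator, it should follow our operator naming conventions as well: for example, name the file mineru_to_llm_formatter and give the class the same name in CamelCase.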