|
5 | 5 | from pathlib import Path |
6 | 6 | from typing import Any, Dict, List, Optional, Union |
7 | 7 |
|
| 8 | +import ray |
| 9 | +from ray.data import Dataset |
| 10 | + |
8 | 11 | from graphgen.bases.base_reader import BaseReader |
9 | 12 | from graphgen.models.reader.txt_reader import TXTReader |
10 | 13 | from graphgen.utils import logger, pick_device |
@@ -62,19 +65,32 @@ def __init__( |
62 | 65 | self.parser = MinerUParser() |
63 | 66 | self.txt_reader = TXTReader() |
64 | 67 |
|
def read(
    self,
    input_path: Union[str, List[str]],
    parallelism: int = 4,
    **override,
) -> Dataset:
    """
    Parse one or more PDF files with MinerU and return the results as a Ray Dataset.

    input_path: a single PDF path or a list of PDF paths
    parallelism: number of blocks the list of paths is split into
    **override: override MinerU parameters (merged over the reader defaults)

    Returns a Dataset holding the items returned by ``_call_mineru`` for every
    input file, restricted to those accepted by ``_should_keep_item``.
    A file whose processing raises is logged and contributes no rows.
    """
    # Ensure input_path is a list
    if isinstance(input_path, str):
        input_path = [input_path]

    # Normalise the paths once, before the dataset is built, so that "~" and
    # relative paths do not depend on the working directory of whichever
    # process ends up running process_pdf.
    pdf_paths = [str(Path(p).expanduser().resolve()) for p in input_path]

    # Split the path list into `parallelism` blocks so that the per-file
    # work below can be scheduled block by block.
    paths_ds = ray.data.from_items(pdf_paths, override_num_blocks=parallelism)

    def process_pdf(row: Dict[str, Any]) -> List[Dict[str, Any]]:
        # The "item" key is the same one the original code read from each row.
        try:
            pdf_path = row["item"]
            kwargs = {**self._default_kwargs, **override}
            return self._call_mineru(Path(pdf_path), kwargs)
        except Exception as e:
            # Best effort: one broken PDF must not abort the whole batch.
            logger.error("Failed to process %s: %s", row, e)
            return []

    docs_ds = paths_ds.flat_map(process_pdf)
    docs_ds = docs_ds.filter(self._should_keep_item)

    return docs_ds
78 | 94 |
|
79 | 95 | def _call_mineru( |
80 | 96 | self, pdf_path: Path, kwargs: Dict[str, Any] |
@@ -161,18 +177,18 @@ def _try_load_cached_result( |
161 | 177 |
|
162 | 178 | base = os.path.dirname(json_file) |
163 | 179 | results = [] |
164 | | - for item in data: |
| 180 | + for it in data: |
165 | 181 | for key in ("img_path", "table_img_path", "equation_img_path"): |
166 | | - rel_path = item.get(key) |
| 182 | + rel_path = it.get(key) |
167 | 183 | if rel_path: |
168 | | - item[key] = str(Path(base).joinpath(rel_path).resolve()) |
169 | | - if item["type"] == "text": |
170 | | - item["content"] = item["text"] |
171 | | - del item["text"] |
| 184 | + it[key] = str(Path(base).joinpath(rel_path).resolve()) |
| 185 | + if it["type"] == "text": |
| 186 | + it["content"] = it["text"] |
| 187 | + del it["text"] |
172 | 188 | for key in ("page_idx", "bbox", "text_level"): |
173 | | - if item.get(key) is not None: |
174 | | - del item[key] |
175 | | - results.append(item) |
| 189 | + if it.get(key) is not None: |
| 190 | + del it[key] |
| 191 | + results.append(it) |
176 | 192 | return results |
177 | 193 |
|
178 | 194 | @staticmethod |
@@ -231,3 +247,22 @@ def _check_bin() -> None: |
231 | 247 | "MinerU is not installed or not found in PATH. Please install it from pip: \n" |
232 | 248 | "pip install -U 'mineru[core]'" |
233 | 249 | ) from exc |
| 250 | + |
| 251 | + |
if __name__ == "__main__":
    # Manual smoke test: parse one PDF and print every item that was kept.
    # Usage: python <this file> [path/to/file.pdf]
    import sys

    # Take the PDF from the command line when given; otherwise fall back to
    # the original demo file so the previous invocation keeps working.
    demo_pdf = (
        sys.argv[1]
        if len(sys.argv) > 1
        else "/home/PJLAB/chenzihong/Project/graphgen/resources/input_examples/pdf_demo.pdf"
    )

    reader = PDFReader(
        output_dir="./output",
        method="auto",
        backend="pipeline",
        device="cpu",
        lang="en",
        formula=True,
        table=True,
    )
    dataset = reader.read(
        demo_pdf,
        parallelism=2,
    )

    for item in dataset.take_all():
        print(item)
0 commit comments