Skip to content

Commit 97f7e75

Browse files
refactor: refactor pdf_reader using ray data
1 parent db8252c commit 97f7e75

File tree

3 files changed

+55
-22
lines changed

3 files changed

+55
-22
lines changed

graphgen/models/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
)
1919
from .reader import (
2020
CSVReader,
21-
JSONLReader,
2221
JSONReader,
2322
ParquetReader,
2423
PDFReader,

graphgen/models/reader/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from .csv_reader import CSVReader
22
from .json_reader import JSONReader
3-
from .jsonl_reader import JSONLReader
43
from .parquet_reader import ParquetReader
54
from .pdf_reader import PDFReader
65
from .pickle_reader import PickleReader

graphgen/models/reader/pdf_reader.py

Lines changed: 55 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55
from pathlib import Path
66
from typing import Any, Dict, List, Optional, Union
77

8+
import ray
9+
from ray.data import Dataset
10+
811
from graphgen.bases.base_reader import BaseReader
912
from graphgen.models.reader.txt_reader import TXTReader
1013
from graphgen.utils import logger, pick_device
@@ -62,19 +65,32 @@ def __init__(
6265
self.parser = MinerUParser()
6366
self.txt_reader = TXTReader()
6467

65-
def read(self, file_path: str, **override) -> List[Dict[str, Any]]:
66-
"""
67-
file_path
68-
**override: override MinerU parameters
69-
"""
70-
pdf_path = Path(file_path).expanduser().resolve()
71-
if not pdf_path.is_file():
72-
raise FileNotFoundError(pdf_path)
68+
def read(
69+
self,
70+
input_path: Union[str, List[str]],
71+
parallelism: int = 4,
72+
**override,
73+
) -> Dataset:
74+
75+
# Ensure input_path is a list
76+
if isinstance(input_path, str):
77+
input_path = [input_path]
7378

74-
kwargs = {**self._default_kwargs, **override}
79+
paths_ds = ray.data.from_items(input_path)
7580

76-
mineru_result = self._call_mineru(pdf_path, kwargs)
77-
return self.filter(mineru_result)
81+
def process_pdf(row: Dict[str, Any]) -> List[Dict[str, Any]]:
82+
try:
83+
pdf_path = row["item"]
84+
kwargs = {**self._default_kwargs, **override}
85+
return self._call_mineru(Path(pdf_path), kwargs)
86+
except Exception as e:
87+
logger.error("Failed to process %s: %s", row, e)
88+
return []
89+
90+
docs_ds = paths_ds.flat_map(process_pdf)
91+
docs_ds = docs_ds.filter(self._should_keep_item)
92+
93+
return docs_ds
7894

7995
def _call_mineru(
8096
self, pdf_path: Path, kwargs: Dict[str, Any]
@@ -161,18 +177,18 @@ def _try_load_cached_result(
161177

162178
base = os.path.dirname(json_file)
163179
results = []
164-
for item in data:
180+
for it in data:
165181
for key in ("img_path", "table_img_path", "equation_img_path"):
166-
rel_path = item.get(key)
182+
rel_path = it.get(key)
167183
if rel_path:
168-
item[key] = str(Path(base).joinpath(rel_path).resolve())
169-
if item["type"] == "text":
170-
item["content"] = item["text"]
171-
del item["text"]
184+
it[key] = str(Path(base).joinpath(rel_path).resolve())
185+
if it["type"] == "text":
186+
it["content"] = it["text"]
187+
del it["text"]
172188
for key in ("page_idx", "bbox", "text_level"):
173-
if item.get(key) is not None:
174-
del item[key]
175-
results.append(item)
189+
if it.get(key) is not None:
190+
del it[key]
191+
results.append(it)
176192
return results
177193

178194
@staticmethod
@@ -231,3 +247,22 @@ def _check_bin() -> None:
231247
"MinerU is not installed or not found in PATH. Please install it from pip: \n"
232248
"pip install -U 'mineru[core]'"
233249
) from exc
250+
251+
252+
if __name__ == "__main__":
253+
reader = PDFReader(
254+
output_dir="./output",
255+
method="auto",
256+
backend="pipeline",
257+
device="cpu",
258+
lang="en",
259+
formula=True,
260+
table=True,
261+
)
262+
dataset = reader.read(
263+
"/home/PJLAB/chenzihong/Project/graphgen/resources/input_examples/pdf_demo.pdf",
264+
parallelism=2,
265+
)
266+
267+
for item in dataset.take_all():
268+
print(item)

0 commit comments

Comments
 (0)