Skip to content

Commit bd2f7c4

Browse files
refactor: refactor txt_reader using ray data
1 parent bc487fb commit bd2f7c4

File tree

2 files changed

+68
-46
lines changed

2 files changed

+68
-46
lines changed

graphgen/bases/base_reader.py

Lines changed: 40 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import os
22
from abc import ABC, abstractmethod
3-
from typing import Any, Dict, List
3+
from typing import Any, Dict, List, Union
44

55
import requests
6+
from ray.data import Dataset
67

78

89
class BaseReader(ABC):
@@ -14,52 +15,50 @@ def __init__(self, text_column: str = "content"):
1415
self.text_column = text_column
1516

1617
@abstractmethod
def read(self, input_path: Union[str, List[str]]) -> Dataset:
    """
    Read data from one input file or from several input files.

    Concrete readers implement this method; the base class only declares
    the contract.

    :param input_path: Path to the input file or list of file paths.
    :return: Ray Dataset containing the read data.
    """
2425

25-
@staticmethod
26-
def filter(data: List[dict]) -> List[dict]:
26+
def _should_keep_item(self, item: Dict[str, Any]) -> bool:
2727
"""
28-
Filter out entries with empty or missing text in the specified column.
28+
Determine whether to keep the given item based on the text column.
2929
30-
:param data: List of dictionaries containing the data.
31-
:return: Filtered list of dictionaries.
30+
:param item: Dictionary representing a data entry.
31+
:return: True if the item should be kept, False otherwise.
3232
"""
33+
item_type = item.get("type")
34+
assert item_type in [
35+
"text",
36+
"image",
37+
"table",
38+
"equation",
39+
"protein",
40+
], f"Unsupported item type: {item_type}"
41+
if item_type == "text":
42+
content = item.get(self.text_column, "").strip()
43+
return bool(content)
44+
return True
3345

34-
def _image_exists(path_or_url: str, timeout: int = 3) -> bool:
35-
"""
36-
Check if an image exists at the given local path or URL.
37-
:param path_or_url: Local file path or remote URL of the image.
38-
:param timeout: Timeout for remote URL requests in seconds.
39-
:return: True if the image exists, False otherwise.
40-
"""
41-
if not path_or_url:
42-
return False
43-
if not path_or_url.startswith(("http://", "https://", "ftp://")):
44-
path = path_or_url.replace("file://", "", 1)
45-
path = os.path.abspath(path)
46-
return os.path.isfile(path)
47-
try:
48-
resp = requests.head(path_or_url, allow_redirects=True, timeout=timeout)
49-
return resp.status_code == 200
50-
except requests.RequestException:
51-
return False
52-
53-
filtered_data = []
54-
for item in data:
55-
if item.get("type") == "text":
56-
content = item.get("content", "").strip()
57-
if content:
58-
filtered_data.append(item)
59-
elif item.get("type") in ("image", "table", "equation"):
60-
img_path = item.get("img_path")
61-
if _image_exists(img_path):
62-
filtered_data.append(item)
63-
else:
64-
filtered_data.append(item)
65-
return filtered_data
46+
@staticmethod
47+
def _image_exists(path_or_url: str, timeout: int = 3) -> bool:
48+
"""
49+
Check if an image exists at the given local path or URL.
50+
:param path_or_url: Local file path or remote URL of the image.
51+
:param timeout: Timeout for remote URL requests in seconds.
52+
:return: True if the image exists, False otherwise.
53+
"""
54+
if not path_or_url:
55+
return False
56+
if not path_or_url.startswith(("http://", "https://", "ftp://")):
57+
path = path_or_url.replace("file://", "", 1)
58+
path = os.path.abspath(path)
59+
return os.path.isfile(path)
60+
try:
61+
resp = requests.head(path_or_url, allow_redirects=True, timeout=timeout)
62+
return resp.status_code == 200
63+
except requests.RequestException:
64+
return False
Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,33 @@
1-
from typing import Any, Dict, List
1+
from typing import List, Union
2+
3+
import ray
4+
from ray.data import Dataset
25

36
from graphgen.bases.base_reader import BaseReader
47

58

69
class TXTReader(BaseReader):
7-
def read(self, file_path: str) -> List[Dict[str, Any]]:
8-
with open(file_path, "r", encoding="utf-8") as f:
9-
docs = [{"type": "text", self.text_column: f.read()}]
10-
return self.filter(docs)
10+
def read(
    self,
    input_path: Union[str, List[str]],
    override_num_blocks: int = 4,
) -> Dataset:
    """
    Load text file(s) into a Ray Dataset of "text"-typed items.

    Each row produced by ``ray.data.read_text`` is reshaped into a dict
    with a "type" of "text" and the row's "text" value stored under
    ``self.text_column``; rows rejected by ``_should_keep_item`` are
    then filtered out.

    :param input_path: Path to the input text file or list of text files.
    :param override_num_blocks: Value forwarded to ``ray.data.read_text``
        as its ``override_num_blocks`` argument.
    :return: Ray Dataset containing the read text data.
    """

    def _to_text_item(row):
        # Look up the column name at call time, as the original lambda did.
        return {"type": "text", self.text_column: row["text"]}

    raw_ds = ray.data.read_text(
        input_path,
        encoding="utf-8",
        override_num_blocks=override_num_blocks,
    )
    return raw_ds.map(_to_text_item).filter(self._should_keep_item)

0 commit comments

Comments
 (0)