Skip to content

Commit 0422bd0

Browse files
refactor: refactor csv_reader using ray data
1 parent bd2f7c4 commit 0422bd0

File tree

3 files changed

+40
-11
lines changed

3 files changed

+40
-11
lines changed

graphgen/bases/base_reader.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from abc import ABC, abstractmethod
33
from typing import Any, Dict, List, Union
44

5+
import pandas as pd
56
import requests
67
from ray.data import Dataset
78

@@ -43,6 +44,21 @@ def _should_keep_item(self, item: Dict[str, Any]) -> bool:
4344
return bool(content)
4445
return True
4546

47+
def _validate_batch(self, batch: pd.DataFrame) -> pd.DataFrame:
    """
    Check that a batch carries the columns its documents require.

    :param batch: pandas DataFrame holding one batch of rows.
    :return: The same DataFrame object, unmodified.
    :raises ValueError: if the "type" column is absent, or if at least one
        row has type "text" while the ``self.text_column`` column is absent.
    """
    columns = batch.columns
    if "type" not in columns:
        raise ValueError(f"Missing 'type' column. Found: {list(batch.columns)}")

    # The text column is only mandatory when the batch actually holds a
    # text document; batches made of other types pass without it.
    has_text_rows = "text" in batch["type"].values
    if has_text_rows and self.text_column not in columns:
        raise ValueError(
            f"Missing '{self.text_column}' column for text documents"
        )

    return batch
61+
4662
@staticmethod
4763
def _image_exists(path_or_url: str, timeout: int = 3) -> bool:
4864
"""
Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
from typing import Any, Dict, List
1+
from typing import List, Union
22

3-
import pandas as pd
3+
import ray
4+
from ray.data import Dataset
45

56
from graphgen.bases.base_reader import BaseReader
67

@@ -13,13 +14,23 @@ class CSVReader(BaseReader):
1314
- if type is "text", "content" column must be present.
1415
"""
1516

16-
def read(self, file_path: str) -> List[Dict[str, Any]]:
17+
def read(
    self,
    input_path: Union[str, List[str]],
    override_num_blocks: Union[int, None] = None,
) -> Dataset:
    """
    Read CSV files and return Ray Dataset.

    Every batch is passed through ``self._validate_batch`` and every row
    through ``self._should_keep_item``.

    :param input_path: Path to CSV file or list of CSV files.
    :param override_num_blocks: Number of blocks for Ray Dataset reading.
        ``None`` (the default) is forwarded to Ray unchanged.
    :return: Ray Dataset containing validated and filtered data.
    """
    ds = ray.data.read_csv(input_path, override_num_blocks=override_num_blocks)

    # Batches are requested as pandas DataFrames because that is the type
    # _validate_batch is annotated to accept and return.
    # NOTE(review): Ray Data is presumably lazy here, so a ValueError from
    # validation would surface when the dataset is consumed, not on this
    # call - confirm against the Ray version in use.
    ds = ds.map_batches(self._validate_batch, batch_format="pandas")

    ds = ds.filter(self._should_keep_item)

    return ds

requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ requests
2020
fastapi
2121
trafilatura
2222
aiohttp
23+
ray
24+
diskcache
2325

2426
leidenalg
2527
igraph

0 commit comments

Comments
 (0)