Skip to content

Commit db8252c

Browse files
refactor: refactor parquet_reader using ray data
1 parent 36e80ef commit db8252c

File tree

3 files changed

+23
-42
lines changed

3 files changed

+23
-42
lines changed

graphgen/models/reader/json_reader.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
class JSONReader(BaseReader):
1010
"""
11-
Reader for JSON files.
11+
Reader for JSON and JSONL files.
1212
Columns:
1313
- type: The type of the document (e.g., "text", "image", etc.)
1414
- if type is "text", "content" column must be present.
@@ -21,7 +21,7 @@ def read(
2121
) -> Dataset:
2222
"""
2323
Read JSON file and return Ray Dataset.
24-
:param input_path: Path to JSON file or list of JSON files.
24+
:param input_path: Path to JSON/JSONL file or list of JSON/JSONL files.
2525
:param parallelism: Number of parallel workers for reading files.
2626
:return: Ray Dataset containing validated and filtered data.
2727
"""

graphgen/models/reader/jsonl_reader.py

Lines changed: 0 additions & 30 deletions
This file was deleted.
Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
from typing import Any, Dict, List
1+
from typing import List, Union
22

3-
import pandas as pd
3+
import ray
4+
from ray.data import Dataset
45

56
from graphgen.bases.base_reader import BaseReader
67

@@ -13,12 +14,22 @@ class ParquetReader(BaseReader):
1314
- if type is "text", "content" column must be present.
1415
"""
1516

16-
def read(self, file_path: str) -> List[Dict[str, Any]]:
17-
df = pd.read_parquet(file_path)
18-
data: List[Dict[str, Any]] = df.to_dict(orient="records")
17+
def read(
18+
self,
19+
input_path: Union[str, List[str]],
20+
override_num_blocks: int = None,
21+
) -> Dataset:
22+
"""
23+
Read Parquet files using Ray Data.
1924
20-
for doc in data:
21-
assert "type" in doc, f"Missing 'type' in document: {doc}"
22-
if doc.get("type") == "text" and self.text_column not in doc:
23-
raise ValueError(f"Missing '{self.text_column}' in document: {doc}")
24-
return self.filter(data)
25+
:param input_path: Path to Parquet file or list of Parquet files.
26+
:param override_num_blocks: Number of blocks for Ray Dataset reading.
27+
:return: Ray Dataset containing validated documents.
28+
"""
29+
if not ray.is_initialized():
30+
ray.init()
31+
32+
ds = ray.data.read_parquet(input_path, override_num_blocks=override_num_blocks)
33+
ds = ds.map_batches(self._validate_batch, batch_format="pandas")
34+
ds = ds.filter(self._should_keep_item)
35+
return ds

0 commit comments

Comments
 (0)