Commit bc487fb

feat: add parallel scan_files
1 parent a7a0155 commit bc487fb

File tree: 2 files changed, +246 −38 lines

Lines changed: 76 additions & 38 deletions
@@ -1,6 +1,13 @@
-from pathlib import Path
+from ray.data.datasource import (
+    Datasource, ReadTask
+)
+import pyarrow as pa
+from typing import List, Dict, Any, Optional, Union
+
 from typing import Iterator, List, Optional
 
+import ray
+
 from graphgen.models import (
     CSVReader,
     JSONLReader,
@@ -34,47 +41,78 @@ def _build_reader(suffix: str, cache_dir: str | None):
         return _MAPPING[suffix](output_dir=cache_dir)
     return _MAPPING[suffix]()
 
+class UnifiedFileDatasource(Datasource):
+    pass
+
 
 def read_files(
-    input_file: str,
+    input_path: str,
     allowed_suffix: Optional[List[str]] = None,
     cache_dir: Optional[str] = None,
-) -> Iterator[list[dict]]:
-    path = Path(input_file).expanduser()
-    if not path.exists():
-        raise FileNotFoundError(f"[Read] input_path not found: {input_file}")
+    parallelism: int = 4,
+    **ray_kwargs,
+) -> ray.data.Dataset:
+    """
+    Reads files from the specified input path, filtering by allowed suffixes,
+    and returns a Ray Dataset containing the read documents.
+    :param input_path: input file or directory path
+    :param allowed_suffix: list of allowed file suffixes (e.g., ['pdf', 'txt'])
+    :param cache_dir: directory to cache intermediate files (used for PDF reading)
+    :param parallelism: number of parallel workers for reading files
+    :param ray_kwargs: additional keyword arguments for Ray Dataset reading
+    :return: Ray Dataset containing the read documents
+    """
+
+    if not ray.is_initialized():
+        ray.init()
+
 
-    if allowed_suffix is None:
-        support_suffix = set(_MAPPING.keys())
-    else:
-        support_suffix = {s.lower().lstrip(".") for s in allowed_suffix}
+    return ray.data.read_datasource(
+        UnifiedFileDatasource(
+            paths=[input_path],
+            allowed_suffix=allowed_suffix,
+            cache_dir=cache_dir,
+            **ray_kwargs,  # Pass additional Ray kwargs here
+        ),
+        parallelism=parallelism,
+    )
 
-    # single file
-    if path.is_file():
-        suffix = path.suffix.lstrip(".").lower()
-        if suffix not in support_suffix:
-            logger.warning(
-                "[Read] Skip file %s (suffix '%s' not in allowed_suffix %s)",
-                path,
-                suffix,
-                support_suffix,
-            )
-            return
-        reader = _build_reader(suffix, cache_dir)
-        logger.info("[Read] Reading file %s", path)
-        yield reader.read(str(path))
-        return
 
-    # folder
-    logger.info("[Read] Streaming directory %s", path)
-    for p in path.rglob("*"):
-        if p.is_file() and p.suffix.lstrip(".").lower() in support_suffix:
-            try:
-                suffix = p.suffix.lstrip(".").lower()
-                reader = _build_reader(suffix, cache_dir)
-                logger.info("[Reader] Reading file %s", p)
-                docs = reader.read(str(p))
-                if docs:
-                    yield docs
-            except Exception:  # pylint: disable=broad-except
-                logger.exception("[Reader] Error reading %s", p)
+    # path = Path(input_file).expanduser()
+    # if not path.exists():
+    #     raise FileNotFoundError(f"[Read] input_path not found: {input_file}")
+    #
+    # if allowed_suffix is None:
+    #     support_suffix = set(_MAPPING.keys())
+    # else:
+    #     support_suffix = {s.lower().lstrip(".") for s in allowed_suffix}
+    #
+    # # single file
+    # if path.is_file():
+    #     suffix = path.suffix.lstrip(".").lower()
+    #     if suffix not in support_suffix:
+    #         logger.warning(
+    #             "[Read] Skip file %s (suffix '%s' not in allowed_suffix %s)",
+    #             path,
+    #             suffix,
+    #             support_suffix,
+    #         )
+    #         return
+    #     reader = _build_reader(suffix, cache_dir)
+    #     logger.info("[Read] Reading file %s", path)
+    #     yield reader.read(str(path))
+    #     return
+    #
+    # # folder
+    # logger.info("[Read] Streaming directory %s", path)
+    # for p in path.rglob("*"):
+    #     if p.is_file() and p.suffix.lstrip(".").lower() in support_suffix:
+    #         try:
+    #             suffix = p.suffix.lstrip(".").lower()
+    #             reader = _build_reader(suffix, cache_dir)
+    #             logger.info("[Reader] Reading file %s", p)
+    #             docs = reader.read(str(p))
+    #             if docs:
+    #                 yield docs
+    #         except Exception:  # pylint: disable=broad-except
+    #             logger.exception("[Reader] Error reading %s", p)
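
A minimal usage sketch (not part of the commit) for the new read_files() signature. The import path below is assumed, since the changed file's path is not shown on this page, and UnifiedFileDatasource is still only a stub here, so the snippet illustrates the intended call shape rather than something runnable against bc487fb.

from graphgen.operators import read_files  # assumed import path

ds = read_files(
    input_path="data/docs",           # file or directory to read (hypothetical path)
    allowed_suffix=["pdf", "txt"],    # only these suffixes are picked up
    cache_dir="/tmp/graphgen_cache",  # intermediate cache used by the PDF reader
    parallelism=8,                    # number of parallel read tasks
)

# read_files returns a ray.data.Dataset, so the usual Dataset operations apply:
print(ds.count())
for batch in ds.iter_batches(batch_size=32):
    ...  # process batches of documents
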
Lines changed: 170 additions & 0 deletions
@@ -0,0 +1,170 @@
+import os
+import time
+from typing import List, Dict, Any, Set, Union
+from pathlib import Path
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from diskcache import Cache
+from graphgen.utils import logger
+
+class ParallelDirScanner:
+    def __init__(self,
+                 cache_dir: str,
+                 allowed_suffix,
+                 rescan: bool = False,
+                 max_workers: int = 4
+                 ):
+        self.cache = Cache(cache_dir)
+        self.allowed_suffix = set(allowed_suffix) if allowed_suffix else None
+        self.rescan = rescan
+        self.max_workers = max_workers
+
+    def scan(self, paths: Union[str, List[str]], recursive: bool = True) -> Dict[str, Any]:
+        if isinstance(paths, str):
+            paths = [paths]
+
+        results = {}
+        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
+            future_to_path = {
+                executor.submit(self._scan_dir, Path(p).resolve(), recursive, set()): p
+                for p in paths if os.path.exists(p)
+            }
+
+            for future in as_completed(future_to_path):
+                path = future_to_path[future]
+                try:
+                    results[path] = future.result()
+                except Exception as e:
+                    logger.error("Error scanning path %s: %s", path, e)
+                    results[path] = {'error': str(e), 'files': [], 'dirs': [], 'stats': {}}
+
+        return results
+
+    def _scan_dir(self, path: Path, recursive: bool, visited: Set[str]) -> Dict[str, Any]:
+        path_str = str(path)
+
+        # Avoid cycles due to symlinks
+        if path_str in visited:
+            logger.warning("Skipping already visited path: %s", path_str)
+            return self._empty_result(path_str)
+
+        # cache check
+        cache_key = f"scan::{path_str}::recursive::{recursive}"
+        cached = self.cache.get(cache_key)
+        if cached and not self.rescan:
+            logger.info("Using cached scan result for path: %s", path_str)
+            return cached['data']
+
+        logger.info("Scanning path: %s", path_str)
+        files, dirs = [], []
+        stats = {'total_size': 0, 'file_count': 0, 'dir_count': 0, 'errors': 0}
+
+        try:
+            with os.scandir(path_str) as entries:
+                for entry in entries:
+                    try:
+                        entry_stat = entry.stat(follow_symlinks=False)
+
+                        if entry.is_dir():
+                            dirs.append({
+                                'path': entry.path,
+                                'name': entry.name,
+                                'mtime': entry_stat.st_mtime
+                            })
+                            stats['dir_count'] += 1
+                        else:
+                            # allowed-suffix filter (Path.suffix keeps the leading dot, e.g. '.pdf')
+                            if self.allowed_suffix:
+                                suffix = Path(entry.name).suffix.lower()
+                                if suffix not in self.allowed_suffix:
+                                    continue
+
+                            files.append({
+                                'path': entry.path,
+                                'name': entry.name,
+                                'size': entry_stat.st_size,
+                                'mtime': entry_stat.st_mtime
+                            })
+                            stats['total_size'] += entry_stat.st_size
+                            stats['file_count'] += 1
+
+                    except OSError:
+                        stats['errors'] += 1
+
+        except (PermissionError, FileNotFoundError, OSError) as e:
+            logger.error("Failed to scan directory %s: %s", path_str, e)
+            return {'error': str(e), 'files': [], 'dirs': [], 'stats': stats}
+
+        if recursive:
+            sub_visited = visited | {path_str}
+            sub_results = self._scan_subdirs(dirs, sub_visited)
+
+            for sub_data in sub_results.values():
+                files.extend(sub_data.get('files', []))
+                stats['total_size'] += sub_data['stats'].get('total_size', 0)
+                stats['file_count'] += sub_data['stats'].get('file_count', 0)
+
+        result = {'path': path_str, 'files': files, 'dirs': dirs, 'stats': stats}
+        self._cache_result(cache_key, result, path)
+        return result
+
+    def _scan_subdirs(self, dir_list: List[Dict], visited: Set[str]) -> Dict[str, Any]:
+        """
+        Scan subdirectories in parallel.
+        :param dir_list: directory entries (dicts with a 'path' key) to scan
+        :param visited: set of already-visited paths, used as a symlink-cycle guard
+        :return: mapping of subdirectory path to its scan result
+        """
+        results = {}
+        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
+            futures = {
+                executor.submit(self._scan_dir, Path(d['path']), True, visited): d['path']
+                for d in dir_list
+            }
+
+            for future in as_completed(futures):
+                path = futures[future]
+                try:
+                    results[path] = future.result()
+                except Exception as e:
+                    logger.error("Error scanning subdirectory %s: %s", path, e)
+                    results[path] = {'error': str(e), 'files': [], 'dirs': [], 'stats': {}}
+
+        return results
+
+    def _cache_result(self, key: str, result: Dict, path: Path):
+        """Cache the scan result together with the directory mtime and timestamp"""
+        try:
+            self.cache.set(key, {
+                'data': result,
+                'dir_mtime': path.stat().st_mtime,
+                'cached_at': time.time()
+            })
+            logger.info("Cached scan result for: %s", path)
+        except OSError:
+            pass
+
+    def invalidate(self, path: str):
+        """Invalidate cached scan results for a specific path"""
+        path = Path(path).resolve()
+        keys = [k for k in self.cache if k.startswith(f"scan::{path}")]  # same "scan::" prefix as _scan_dir
+        for k in keys:
+            self.cache.delete(k)
+        logger.info("Invalidated cache for path: %s", path)
+
+    def close(self):
+        self.cache.close()
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, *args):
+        self.close()
+
+    @staticmethod
+    def _empty_result(path: str) -> Dict[str, Any]:
+        return {
+            'path': path,
+            'files': [],
+            'dirs': [],
+            'stats': {'total_size': 0, 'file_count': 0, 'dir_count': 0, 'errors': 0}
+        }
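
A minimal usage sketch (not part of the commit) for the new ParallelDirScanner. The import path is assumed, since the new file's name is not shown on this page; the class is a context manager, so the diskcache handle is closed on exit.

from graphgen.models import ParallelDirScanner  # assumed import path

with ParallelDirScanner(
    cache_dir="/tmp/scan_cache",      # diskcache directory (hypothetical path)
    allowed_suffix={".pdf", ".txt"},  # matched against Path.suffix, which keeps the leading dot
    rescan=False,                     # reuse cached scan results when available
    max_workers=8,
) as scanner:
    results = scanner.scan(["./data", "./docs"], recursive=True)

    for root, info in results.items():
        stats = info["stats"]
        print(root, stats.get("file_count", 0), "files,",
              stats.get("total_size", 0), "bytes")

    # Drop cached entries for a root whose contents changed on disk:
    scanner.invalidate("./data")

Note that the scanner's allowed_suffix keeps the leading dot, since it is compared against Path(entry.name).suffix, unlike the old read_files logic (now commented out), which stripped it with lstrip(".").
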
