```diff
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import json
+
+import asyncio
 import logging
 import os
-import time
-from abc import ABC, abstractmethod
-from dataclasses import dataclass
 from functools import lru_cache
 from pathlib import Path
 from typing import Any, Callable, Optional, Union
-from urllib.parse import urlparse
 
 from torch.utils.data import Dataset
 
-from litdata.constants import _ASYNCIO_AVAILABLE, _FSSPEC_AVAILABLE, _TQDM_AVAILABLE, _ZSTD_AVAILABLE
+from litdata.raw.indexer import BaseIndexer, FileIndexer, FileMetadata
 from litdata.streaming.downloader import Downloader, get_downloader
 from litdata.streaming.resolver import Dir, _resolve_dir
 from litdata.utilities.dataset_utilities import generate_md5_hash, get_default_cache_dir
 
-if not _ASYNCIO_AVAILABLE:
-    raise ModuleNotFoundError(
-        "The 'asyncio' package is required for streaming datasets. Please install it with `pip install asyncio`."
-    )
-else:
-    import asyncio
-
 logger = logging.getLogger(__name__)
-SUPPORTED_PROVIDERS = ("s3", "gs", "azure")
-
-
-@dataclass
-class FileMetadata:
-    """Metadata for a single file in the dataset."""
-
-    path: str
-    size: int
-
-    def to_dict(self) -> dict[str, Any]:
-        return {"path": self.path, "size": self.size}
-
-    @classmethod
-    def from_dict(cls, data: dict[str, Any]) -> "FileMetadata":
-        return cls(path=data["path"], size=data["size"])
-
-
-class BaseIndexer(ABC):
-    """Abstract base class for file indexing strategies."""
-
-    @abstractmethod
-    def discover_files(self, input_dir: str, storage_options: Optional[dict[str, Any]]) -> list[FileMetadata]:
-        """Discover dataset files and return their metadata."""
-
-    def build_or_load_index(
-        self, input_dir: str, cache_dir: str, storage_options: Optional[dict[str, Any]]
-    ) -> list[FileMetadata]:
-        """Build or load a ZSTD-compressed index of file metadata."""
-        if not _ZSTD_AVAILABLE:
-            raise ModuleNotFoundError(str(_ZSTD_AVAILABLE))
-
-        import zstd
-
-        index_path = Path(cache_dir) / "index.json.zstd"
-
-        # Try loading cached index if it exists
-        if index_path.exists():
-            try:
-                with open(index_path, "rb") as f:
-                    compressed_data = f.read()
-                metadata = json.loads(zstd.decompress(compressed_data).decode("utf-8"))
-
-                return [FileMetadata.from_dict(file_data) for file_data in metadata["files"]]
-            except (FileNotFoundError, json.JSONDecodeError, zstd.ZstdError, KeyError) as e:
-                logger.warning(f"Failed to load cached index from {index_path}: {e}")
-
-        # Build fresh index
-        logger.info(f"Building index for {input_dir} at {index_path}")
-        files = self.discover_files(input_dir, storage_options)
-        if not files:
-            raise ValueError(f"No files found in {input_dir}")
-
-        # Cache the index with ZSTD compression
-        # TODO: upload the index to cloud storage
-        try:
-            metadata = {
-                "source": input_dir,
-                "files": [file.to_dict() for file in files],
-                "created_at": time.time(),
-            }
-            with open(index_path, "wb") as f:
-                f.write(zstd.compress(json.dumps(metadata).encode("utf-8")))
-        except (OSError, zstd.ZstdError) as e:
-            logger.warning(f"Error caching index to {index_path}: {e}")
-
-        logger.info(f"Built index with {len(files)} files from {input_dir} at {index_path}")
-        return files
-
-
-class FileIndexer(BaseIndexer):
-    """Indexes files recursively from cloud or local storage with optional extension filtering."""
-
-    def __init__(
-        self,
-        max_depth: int = 5,
-        extensions: Optional[list[str]] = None,
-    ):
-        self.max_depth = max_depth
-        self.extensions = [ext.lower() for ext in (extensions or [])]
-
-    def discover_files(self, input_dir: str, storage_options: Optional[dict[str, Any]]) -> list[FileMetadata]:
-        """Discover dataset files and return their metadata."""
-        parsed_url = urlparse(input_dir)
-
-        if parsed_url.scheme in SUPPORTED_PROVIDERS:
-            return self._discover_cloud_files(input_dir, storage_options)
-
-        if not parsed_url.scheme or parsed_url.scheme == "file":
-            return self._discover_local_files(input_dir)
-
-        raise ValueError(
-            f"Unsupported input directory scheme: {parsed_url.scheme}. Supported schemes are: {SUPPORTED_PROVIDERS}"
-        )
-
-    def _discover_cloud_files(self, input_dir: str, storage_options: Optional[dict[str, Any]]) -> list[FileMetadata]:
-        """Recursively list files in a cloud storage bucket."""
-        if not _FSSPEC_AVAILABLE:
-            raise ModuleNotFoundError(str(_FSSPEC_AVAILABLE))
-        import fsspec
-
-        obj = urlparse(input_dir)
-
-        # TODO: Research on switching to 'obstore' for file listing to potentially improve performance.
-        # Currently using 'fsspec' due to some issues with 'obstore' when handling multiple instances.
-        fs = fsspec.filesystem(obj.scheme, **(storage_options or {}))
-        files = fs.find(input_dir, maxdepth=self.max_depth, detail=True, withdirs=False)
-
-        if _TQDM_AVAILABLE:
-            from tqdm.auto import tqdm
-
-            pbar = tqdm(desc="Discovering files", total=len(files))
-
-        metadatas = []
-        for _, file_info in files.items():
-            if file_info.get("type") != "file":
-                continue
-
-            file_path = file_info["name"]
-            if self._should_include_file(file_path):
-                metadata = FileMetadata(
-                    path=f"{obj.scheme}://{file_path}",
-                    size=file_info.get("size", 0),
-                )
-                metadatas.append(metadata)
-            if _TQDM_AVAILABLE:
-                pbar.update(1)
-        if _TQDM_AVAILABLE:
-            pbar.close()
-        return metadatas
-
-    def _discover_local_files(self, input_dir: str) -> list[FileMetadata]:
-        """Recursively list files in the local filesystem."""
-        path = Path(input_dir)
-        metadatas = []
-
-        for file_path in path.rglob("*"):
-            if not file_path.is_file():
-                continue
-
-            if self._should_include_file(str(file_path)):
-                metadata = FileMetadata(
-                    path=str(file_path),
-                    size=file_path.stat().st_size,
-                )
-                metadatas.append(metadata)
-
-        return metadatas
-
-    def _should_include_file(self, file_path: str) -> bool:
-        """Return True if file matches allowed extensions."""
-        file_ext = Path(file_path).suffix.lower()
-        return not self.extensions or file_ext in self.extensions
 
 
 class CacheManager:
```
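For reference, the `FileMetadata`, `BaseIndexer`, and `FileIndexer` definitions deleted above are now imported from `litdata.raw.indexer` instead of living in this module. Below is a minimal usage sketch, assuming the relocated classes keep the constructor and `discover_files` signatures shown in the removed code; the bucket URI and extension list are placeholders, not values from this PR:

```python
from litdata.raw.indexer import FileIndexer

# Hypothetical example: index only image files, descending at most three directory levels.
indexer = FileIndexer(max_depth=3, extensions=[".jpg", ".png"])

# `discover_files` takes a local path or a URI for one of the supported providers
# ("s3", "gs", "azure") plus optional fsspec-style storage options.
# "s3://my-bucket/images" is a placeholder path.
files = indexer.discover_files("s3://my-bucket/images", storage_options=None)

for file in files:
    # Each FileMetadata records the file's path and size.
    print(file.path, file.size)
```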
|