
Commit e390deb

ref: Move raw dataset code to litdata/raw, expose StreamingRawDataset at top-level (Lightning-AI#671)
* ref(constants): remove `asyncio` from requirement cache, already in the std lib
* ref: organize raw dataset files to separate folder
* ref: update imports and enhance __all__ exports in __init__.py
* ref: clarify comments for cloud and local file discovery in FileIndexer
* ref: move tests for raw dataset to `raw` folder with respective files
* refactor: update import statements for StreamingRawDataset and FileMetadata in usage examples
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* add init files
* expose classes at module level

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 5a57d85 commit e390deb

File tree

10 files changed: +415 -364 lines changed


README.md

Lines changed: 11 additions & 10 deletions
@@ -221,8 +221,8 @@ pip install "litdata[extra]" gcsfs
 
 **Usage Example:**
 ```python
-from litdata.streaming.raw_dataset import StreamingRawDataset
 from torch.utils.data import DataLoader
+from litdata import StreamingRawDataset
 
 dataset = StreamingRawDataset("s3://bucket/files/")
 
@@ -239,18 +239,19 @@ for batch in loader:
 You can also customize how files are grouped by subclassing `StreamingRawDataset` and overriding the `setup` method. This is useful for pairing related files (e.g., image and mask, audio and transcript) or any custom grouping logic.
 
 ```python
-from litdata.streaming.raw_dataset import StreamingRawDataset, FileMetadata
-from torch.utils.data import DataLoader
 from typing import Union
+from torch.utils.data import DataLoader
+from litdata import StreamingRawDataset
+from litdata.raw.indexer import FileMetadata
 
 class SegmentationRawDataset(StreamingRawDataset):
-    def setup(self, files: list[FileMetadata]) -> Union[list[FileMetadata], list[list[FileMetadata]]]:
-        # TODO: Implement your custom grouping logic here.
-        # For example, group files by prefix, extension, or any rule you need.
-        # Return a list of groups, where each group is a list of FileMetadata.
-        # Example:
-        # return [[image, mask], ...]
-        pass
+    def setup(self, files: list[FileMetadata]) -> Union[list[FileMetadata], list[list[FileMetadata]]]:
+        # TODO: Implement your custom grouping logic here.
+        # For example, group files by prefix, extension, or any rule you need.
+        # Return a list of groups, where each group is a list of FileMetadata.
+        # Example:
+        # return [[image, mask], ...]
+        pass
 
 # Initialize the custom dataset
 dataset = SegmentationRawDataset("s3://bucket/files/")
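The README's `setup` example ends in a TODO. As one possible concrete grouping rule (the `_image`/`_mask` file-naming convention below is an illustrative assumption, not part of litdata), the override could pair files by their shared stem:

```python
from pathlib import Path
from typing import Union

from litdata import StreamingRawDataset
from litdata.raw.indexer import FileMetadata


class SegmentationRawDataset(StreamingRawDataset):
    def setup(self, files: list[FileMetadata]) -> Union[list[FileMetadata], list[list[FileMetadata]]]:
        # Hypothetical convention: "<stem>_image.png" is paired with "<stem>_mask.png".
        pairs: dict[str, dict[str, FileMetadata]] = {}
        for file in files:
            name = Path(file.path).stem           # e.g. "0001_image"
            stem, _, role = name.rpartition("_")  # ("0001", "_", "image")
            pairs.setdefault(stem, {})[role] = file
        # Each group is a list of FileMetadata, as the return annotation allows.
        return [[p["image"], p["mask"]] for p in pairs.values() if {"image", "mask"} <= p.keys()]
```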

src/litdata/__init__.py

Lines changed: 6 additions & 5 deletions
@@ -12,10 +12,10 @@
 # limitations under the License.
 import warnings
 
-from lightning_utilities.core.imports import RequirementCache
-
 from litdata.__about__ import *  # noqa: F403
+from litdata.constants import _LIGHTNING_SDK_AVAILABLE
 from litdata.processing.functions import map, merge_datasets, optimize, walk
+from litdata.raw.dataset import StreamingRawDataset
 from litdata.streaming.combined import CombinedStreamingDataset
 from litdata.streaming.dataloader import StreamingDataLoader
 from litdata.streaming.dataset import StreamingDataset
@@ -32,9 +32,9 @@
     category=UserWarning,
 )
 
-
 __all__ = [
     "StreamingDataset",
+    "StreamingRawDataset",
     "CombinedStreamingDataset",
     "StreamingDataLoader",
     "TokensLoader",
@@ -48,7 +48,8 @@
     "index_hf_dataset",
     "breakpoint",
 ]
-if RequirementCache("lightning_sdk"):
+
+if _LIGHTNING_SDK_AVAILABLE:
     from lightning_sdk import Machine  # noqa: F401
 
-    __all__ + ["Machine"]
+    __all__.append("Machine")
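The last hunk here also fixes a subtle bug: the statement `__all__ + ["Machine"]` creates a new list and immediately discards it, so `Machine` was never actually added to the module's exports, while `__all__.append("Machine")` mutates the list in place. A minimal illustration of the difference:

```python
exports = ["StreamingDataset"]

exports + ["Machine"]      # builds ["StreamingDataset", "Machine"] and throws it away
print(exports)             # ['StreamingDataset']

exports.append("Machine")  # mutates the existing list
print(exports)             # ['StreamingDataset', 'Machine']
```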

src/litdata/constants.py

Lines changed: 0 additions & 1 deletion
@@ -45,7 +45,6 @@
 _TORCH_VISION_AVAILABLE = RequirementCache("torchvision")
 _AV_AVAILABLE = RequirementCache("av")
 _OBSTORE_AVAILABLE = RequirementCache("obstore")
-_ASYNCIO_AVAILABLE = RequirementCache("asyncio")
 
 _DEBUG = bool(int(os.getenv("DEBUG_LITDATA", "0")))
 _PRINT_DEBUG_LOGS = bool(int(os.getenv("PRINT_DEBUG_LOGS", "0")))
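For context on this deletion: `asyncio` has shipped with the CPython standard library since Python 3.4, so probing for it with `RequirementCache` always succeeded and the constant carried no information. A trivial check, assuming nothing beyond a standard Python install:

```python
import asyncio  # standard library; no pip install or availability check required


async def main() -> None:
    await asyncio.sleep(0)


asyncio.run(main())
```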

src/litdata/raw/__init__.py

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
+# Copyright The Lightning AI team.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from litdata.raw.dataset import StreamingRawDataset
+from litdata.raw.indexer import FileMetadata
+
+__all__ = ["FileMetadata", "StreamingRawDataset"]
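With this subpackage in place, the raw-dataset classes are importable both from the top-level package (via the `__init__.py` change above) and from `litdata.raw`, replacing the old `litdata.streaming.raw_dataset` path removed from the README. A small sketch of the resulting import surface; the bucket URI is the placeholder from the README example:

```python
from litdata import StreamingRawDataset      # new top-level export
from litdata.raw import FileMetadata         # re-exported by this subpackage
from litdata.raw.indexer import FileIndexer  # direct module path used inside the package

dataset = StreamingRawDataset("s3://bucket/files/")  # placeholder URI
```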
Lines changed: 3 additions & 166 deletions
@@ -10,185 +10,22 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import json
+
+import asyncio
 import logging
 import os
-import time
-from abc import ABC, abstractmethod
-from dataclasses import dataclass
 from functools import lru_cache
 from pathlib import Path
 from typing import Any, Callable, Optional, Union
-from urllib.parse import urlparse
 
 from torch.utils.data import Dataset
 
-from litdata.constants import _ASYNCIO_AVAILABLE, _FSSPEC_AVAILABLE, _TQDM_AVAILABLE, _ZSTD_AVAILABLE
+from litdata.raw.indexer import BaseIndexer, FileIndexer, FileMetadata
 from litdata.streaming.downloader import Downloader, get_downloader
 from litdata.streaming.resolver import Dir, _resolve_dir
 from litdata.utilities.dataset_utilities import generate_md5_hash, get_default_cache_dir
 
-if not _ASYNCIO_AVAILABLE:
-    raise ModuleNotFoundError(
-        "The 'asyncio' package is required for streaming datasets. Please install it with `pip install asyncio`."
-    )
-else:
-    import asyncio
-
 logger = logging.getLogger(__name__)
-SUPPORTED_PROVIDERS = ("s3", "gs", "azure")
-
-
-@dataclass
-class FileMetadata:
-    """Metadata for a single file in the dataset."""
-
-    path: str
-    size: int
-
-    def to_dict(self) -> dict[str, Any]:
-        return {"path": self.path, "size": self.size}
-
-    @classmethod
-    def from_dict(cls, data: dict[str, Any]) -> "FileMetadata":
-        return cls(path=data["path"], size=data["size"])
-
-
-class BaseIndexer(ABC):
-    """Abstract base class for file indexing strategies."""
-
-    @abstractmethod
-    def discover_files(self, input_dir: str, storage_options: Optional[dict[str, Any]]) -> list[FileMetadata]:
-        """Discover dataset files and return their metadata."""
-
-    def build_or_load_index(
-        self, input_dir: str, cache_dir: str, storage_options: Optional[dict[str, Any]]
-    ) -> list[FileMetadata]:
-        """Build or load a ZSTD-compressed index of file metadata."""
-        if not _ZSTD_AVAILABLE:
-            raise ModuleNotFoundError(str(_ZSTD_AVAILABLE))
-
-        import zstd
-
-        index_path = Path(cache_dir) / "index.json.zstd"
-
-        # Try loading cached index if it exists
-        if index_path.exists():
-            try:
-                with open(index_path, "rb") as f:
-                    compressed_data = f.read()
-                metadata = json.loads(zstd.decompress(compressed_data).decode("utf-8"))
-
-                return [FileMetadata.from_dict(file_data) for file_data in metadata["files"]]
-            except (FileNotFoundError, json.JSONDecodeError, zstd.ZstdError, KeyError) as e:
-                logger.warning(f"Failed to load cached index from {index_path}: {e}")
-
-        # Build fresh index
-        logger.info(f"Building index for {input_dir} at {index_path}")
-        files = self.discover_files(input_dir, storage_options)
-        if not files:
-            raise ValueError(f"No files found in {input_dir}")
-
-        # Cache the index with ZSTD compression
-        # TODO: upload the index to cloud storage
-        try:
-            metadata = {
-                "source": input_dir,
-                "files": [file.to_dict() for file in files],
-                "created_at": time.time(),
-            }
-            with open(index_path, "wb") as f:
-                f.write(zstd.compress(json.dumps(metadata).encode("utf-8")))
-        except (OSError, zstd.ZstdError) as e:
-            logger.warning(f"Error caching index to {index_path}: {e}")
-
-        logger.info(f"Built index with {len(files)} files from {input_dir} at {index_path}")
-        return files
-
-
-class FileIndexer(BaseIndexer):
-    """Indexes files recursively from cloud or local storage with optional extension filtering."""
-
-    def __init__(
-        self,
-        max_depth: int = 5,
-        extensions: Optional[list[str]] = None,
-    ):
-        self.max_depth = max_depth
-        self.extensions = [ext.lower() for ext in (extensions or [])]
-
-    def discover_files(self, input_dir: str, storage_options: Optional[dict[str, Any]]) -> list[FileMetadata]:
-        """Discover dataset files and return their metadata."""
-        parsed_url = urlparse(input_dir)
-
-        if parsed_url.scheme in SUPPORTED_PROVIDERS:
-            return self._discover_cloud_files(input_dir, storage_options)
-
-        if not parsed_url.scheme or parsed_url.scheme == "file":
-            return self._discover_local_files(input_dir)
-
-        raise ValueError(
-            f"Unsupported input directory scheme: {parsed_url.scheme}. Supported schemes are: {SUPPORTED_PROVIDERS}"
-        )
-
-    def _discover_cloud_files(self, input_dir: str, storage_options: Optional[dict[str, Any]]) -> list[FileMetadata]:
-        """Recursively list files in a cloud storage bucket."""
-        if not _FSSPEC_AVAILABLE:
-            raise ModuleNotFoundError(str(_FSSPEC_AVAILABLE))
-        import fsspec
-
-        obj = urlparse(input_dir)
-
-        # TODO: Research on switching to 'obstore' for file listing to potentially improve performance.
-        # Currently using 'fsspec' due to some issues with 'obstore' when handling multiple instances.
-        fs = fsspec.filesystem(obj.scheme, **(storage_options or {}))
-        files = fs.find(input_dir, maxdepth=self.max_depth, detail=True, withdirs=False)
-
-        if _TQDM_AVAILABLE:
-            from tqdm.auto import tqdm
-
-            pbar = tqdm(desc="Discovering files", total=len(files))
-
-        metadatas = []
-        for _, file_info in files.items():
-            if file_info.get("type") != "file":
-                continue
-
-            file_path = file_info["name"]
-            if self._should_include_file(file_path):
-                metadata = FileMetadata(
-                    path=f"{obj.scheme}://{file_path}",
-                    size=file_info.get("size", 0),
-                )
-                metadatas.append(metadata)
-            if _TQDM_AVAILABLE:
-                pbar.update(1)
-        if _TQDM_AVAILABLE:
-            pbar.close()
-        return metadatas
-
-    def _discover_local_files(self, input_dir: str) -> list[FileMetadata]:
-        """Recursively list files in the local filesystem."""
-        path = Path(input_dir)
-        metadatas = []
-
-        for file_path in path.rglob("*"):
-            if not file_path.is_file():
-                continue
-
-            if self._should_include_file(str(file_path)):
-                metadata = FileMetadata(
-                    path=str(file_path),
-                    size=file_path.stat().st_size,
-                )
-                metadatas.append(metadata)
-
-        return metadatas
-
-    def _should_include_file(self, file_path: str) -> bool:
-        """Return True if file matches allowed extensions."""
-        file_ext = Path(file_path).suffix.lower()
-        return not self.extensions or file_ext in self.extensions
 
 
 class CacheManager:
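The indexer code removed here now lives in `litdata.raw.indexer` (see the new import at the top of this hunk). Assuming the relocated `FileIndexer` keeps the constructor and `discover_files` signature shown in the removed lines, it can still be used on its own; the directory path and extension filter below are illustrative:

```python
from litdata.raw.indexer import FileIndexer

# Only index image files, up to 5 directory levels deep (the default shown above).
indexer = FileIndexer(max_depth=5, extensions=[".jpg", ".png"])

# Local paths use the filesystem walker; s3://, gs:// and azure:// URIs go through
# the fsspec-based cloud listing with the given storage options.
files = indexer.discover_files("data/raw/", storage_options=None)
for file in files:
    print(file.path, file.size)
```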
