Commit 97d71ab

Data Processor: Resolve several bugs found while publishing a Studio (#19309)
1 parent 93c1ab0 commit 97d71ab

14 files changed: +853 additions, -89 deletions


requirements/app/app.txt

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-lightning-cloud == 0.5.59 # Must be pinned to ensure compatibility
+lightning-cloud == 0.5.61 # Must be pinned to ensure compatibility
 packaging
 typing-extensions >=4.4.0, <4.8.0
 deepdiff >=5.7.0, <6.6.0

src/lightning/data/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -1,9 +1,13 @@
+from lightning.data.streaming.combined import CombinedStreamingDataset
+from lightning.data.streaming.dataloader import StreamingDataLoader
 from lightning.data.streaming.dataset import StreamingDataset
 from lightning.data.streaming.functions import map, optimize
 
 __all__ = [
     "LightningDataset",
     "StreamingDataset",
+    "CombinedStreamingDataset",
+    "StreamingDataLoader",
     "LightningIterableDataset",
     "map",
     "optimize",

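With these exports, the combined dataset and the streaming dataloader become importable directly from `lightning.data`. A minimal usage sketch, assuming hypothetical S3 dataset locations and a `CombinedStreamingDataset` constructor that takes a list of datasets (neither is confirmed by this diff):

```python
# Hypothetical usage; the s3:// URIs are placeholders, not part of this commit.
from lightning.data import CombinedStreamingDataset, StreamingDataLoader, StreamingDataset

datasets = [
    StreamingDataset("s3://my-bucket/optimized-dataset-a"),
    StreamingDataset("s3://my-bucket/optimized-dataset-b"),
]
combined = CombinedStreamingDataset(datasets)

dataloader = StreamingDataLoader(combined, batch_size=8)
for batch in dataloader:
    ...  # training step
```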
src/lightning/data/streaming/cache.py

Lines changed: 5 additions & 12 deletions
@@ -13,7 +13,6 @@
 
 import logging
 import os
-from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 from lightning.data.streaming.constants import (
@@ -23,6 +22,7 @@
 )
 from lightning.data.streaming.item_loader import BaseItemLoader
 from lightning.data.streaming.reader import BinaryReader
+from lightning.data.streaming.resolver import Dir, _resolve_dir
 from lightning.data.streaming.sampler import ChunkedIndex
 from lightning.data.streaming.serializers import Serializer
 from lightning.data.streaming.writer import BinaryWriter
@@ -31,17 +31,6 @@
 
 logger = logging.Logger(__name__)
 
-if _LIGHTNING_CLOUD_LATEST:
-    from lightning_cloud.resolver import _resolve_dir
-
-
-@dataclass
-class Dir:
-    """Holds a directory path and possibly its associated remote URL."""
-
-    path: str
-    url: Optional[str] = None
-
 
 class Cache:
     def __init__(
@@ -76,6 +65,7 @@ def __init__(
 
         input_dir = _resolve_dir(input_dir)
         self._cache_dir = input_dir.path
+        assert self._cache_dir
         self._writer = BinaryWriter(
             self._cache_dir,
             chunk_size=chunk_size,
@@ -108,15 +98,18 @@ def filled(self) -> bool:
         """Returns whether the caching phase is done."""
         if self._is_done:
             return True
+        assert self._cache_dir
         self._is_done = os.path.exists(os.path.join(self._cache_dir, _INDEX_FILENAME))
         return self._is_done
 
     @property
     def cache_dir(self) -> str:
+        assert self._cache_dir
         return self._cache_dir
 
     @property
     def checkpoint_dir(self) -> str:
+        assert self._cache_dir
         checkpoint_dir = os.path.join(self._cache_dir, "checkpoints")
         return self._try_create(checkpoint_dir)
 

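The `Dir` dataclass and `_resolve_dir` are now imported from `lightning.data.streaming.resolver` rather than being defined locally or pulled from `lightning_cloud`. A rough sketch of the shape this implies, assuming the resolver's `Dir` keeps the fields removed above but with an optional `path` (the new `assert self._cache_dir` checks suggest `path` can now be `None`); this is an assumption, not a copy of the resolver module:

```python
# Assumed shape of lightning.data.streaming.resolver.Dir; not taken from the resolver module itself.
from dataclasses import dataclass
from typing import Optional


@dataclass
class Dir:
    """Holds a directory path and possibly its associated remote URL."""

    path: Optional[str] = None
    url: Optional[str] = None
```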
src/lightning/data/streaming/combined.py

Lines changed: 4 additions & 0 deletions
@@ -44,6 +44,10 @@ def __init__(
 
         self._iterator: Optional[_CombinedDatasetIterator] = None
 
+    def __len__(self) -> int:
+        assert self._weights
+        return int(sum(w * len(d) for w, d in zip(self._weights, self._datasets)))
+
     def __iter__(self) -> Iterator[Any]:
         assert self._weights
         self._iterator = _CombinedDatasetIterator(self._datasets, self._seed, self._weights)

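The new `__len__` reports a weighted length: each wrapped dataset contributes its length scaled by its sampling weight, and the sum is truncated to an int. A small numeric illustration of that formula (the lengths and weights below are made up):

```python
# Illustration of the weighted-length formula only; not library code.
weights = [0.5, 0.5]
dataset_lengths = [100, 50]

combined_length = int(sum(w * n for w, n in zip(weights, dataset_lengths)))
assert combined_length == 75  # 0.5 * 100 + 0.5 * 50
```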
src/lightning/data/streaming/constants.py

Lines changed: 1 addition & 1 deletion
@@ -26,7 +26,7 @@
 # This is required for full pytree serialization / deserialization support
 _TORCH_GREATER_EQUAL_2_1_0 = RequirementCache("torch>=2.1.0")
 _VIZ_TRACKER_AVAILABLE = RequirementCache("viztracer")
-_LIGHTNING_CLOUD_LATEST = RequirementCache("lightning-cloud>=0.5.59")
+_LIGHTNING_CLOUD_LATEST = RequirementCache("lightning-cloud>=0.5.61")
 _BOTO3_AVAILABLE = RequirementCache("boto3")
 
 # DON'T CHANGE ORDER

src/lightning/data/streaming/data_processor.py

Lines changed: 42 additions & 26 deletions
@@ -1,3 +1,4 @@
+import concurrent
 import json
 import logging
 import os
@@ -27,6 +28,7 @@
     _LIGHTNING_CLOUD_LATEST,
     _TORCH_GREATER_EQUAL_2_1_0,
 )
+from lightning.data.streaming.resolver import _resolve_dir
 from lightning.data.utilities.broadcast import broadcast_object
 from lightning.data.utilities.packing import _pack_greedily
 
@@ -35,7 +37,6 @@
 
 if _LIGHTNING_CLOUD_LATEST:
     from lightning_cloud.openapi import V1DatasetType
-    from lightning_cloud.resolver import _resolve_dir
     from lightning_cloud.utils.dataset import _create_dataset
 
 
@@ -120,7 +121,9 @@ def _download_data_target(input_dir: Dir, cache_dir: str, queue_in: Queue, queue
         index, paths = r
 
         # 5. Check whether all the files are already downloaded
-        if all(os.path.exists(p.replace(input_dir.path, cache_dir) if input_dir else p) for p in paths):
+        if input_dir.path and all(
+            os.path.exists(p.replace(input_dir.path, cache_dir) if input_dir else p) for p in paths
+        ):
             queue_out.put(index)
             continue
 
@@ -131,9 +134,10 @@ def _download_data_target(input_dir: Dir, cache_dir: str, queue_in: Queue, queue
 
         # 7. Download all the required paths to unblock the current index
         for path in paths:
-            local_path = path.replace(input_dir.path, cache_dir)
+            if input_dir.path:
+                local_path = path.replace(input_dir.path, cache_dir)
 
-            if input_dir.url:
+            if input_dir.url and input_dir.path:
                 path = path.replace(input_dir.path, input_dir.url)
 
             obj = parse.urlparse(path)
@@ -168,7 +172,7 @@ def _remove_target(input_dir: Dir, cache_dir: str, queue_in: Queue) -> None:
         # 3. Iterate through the paths and delete them sequentially.
         for path in paths:
             if input_dir:
-                if not path.startswith(cache_dir):
+                if not path.startswith(cache_dir) and input_dir.path is not None:
                     path = path.replace(input_dir.path, cache_dir)
 
                 if os.path.exists(path):
@@ -199,11 +203,13 @@ def _upload_fn(upload_queue: Queue, remove_queue: Queue, cache_dir: str, output_
         if obj.scheme == "s3":
             try:
                 s3.client.upload_file(
-                    local_filepath, obj.netloc, os.path.join(obj.path.lstrip("/"), os.path.basename(local_filepath))
+                    local_filepath,
+                    obj.netloc,
+                    os.path.join(str(obj.path).lstrip("/"), os.path.basename(local_filepath)),
                 )
             except Exception as e:
                 print(e)
-        elif os.path.isdir(output_dir.path):
+        elif output_dir.path and os.path.isdir(output_dir.path):
             shutil.copyfile(local_filepath, os.path.join(output_dir.path, os.path.basename(local_filepath)))
         else:
             raise ValueError(f"The provided {output_dir.path} isn't supported.")
@@ -254,20 +260,30 @@ def _map_items_to_workers_weighted(
     return [worker_items[worker_id] for worker_id in worker_ids_this_node]
 
 
+def _get_num_bytes(item: Any, base_path: str) -> int:
+    flattened_item, _ = tree_flatten(item)
+
+    num_bytes = 0
+    for element in flattened_item:
+        if isinstance(element, str) and element.startswith(base_path) and os.path.exists(element):
+            file_bytes = os.path.getsize(element)
+            if file_bytes == 0:
+                raise RuntimeError(f"The file {element} has 0 bytes!")
+            num_bytes += file_bytes
+    return num_bytes
+
+
 def _get_item_filesizes(items: List[Any], base_path: str = "") -> List[int]:
     """Computes the total size in bytes of all file paths for every datastructure in the given list."""
     item_sizes = []
-    for item in items:
-        flattened_item, _ = tree_flatten(item)
-
-        num_bytes = 0
-        for element in flattened_item:
-            if isinstance(element, str) and element.startswith(base_path) and os.path.exists(element):
-                file_bytes = os.path.getsize(element)
-                if file_bytes == 0:
-                    raise RuntimeError(f"The file {element} has 0 bytes!")
-                num_bytes += file_bytes
-        item_sizes.append(num_bytes)
+
+    cpu_count = os.cpu_count() or 1
+
+    # Parallelize to accelerate retrieving the number of file bytes to read for each item
+    with concurrent.futures.ThreadPoolExecutor(max_workers=cpu_count * 2 if cpu_count > 4 else cpu_count) as executor:
+        futures = [executor.submit(_get_num_bytes, item, base_path) for item in items]
+        for future in futures:
+            item_sizes.append(future.result())
     return item_sizes
 
 
@@ -358,7 +374,7 @@ def _loop(self) -> None:
         for uploader in self.uploaders:
             uploader.join()
 
-        if self.remove and self.input_dir.path is not None:
+        if self.remove:
             assert self.remover
             self.remove_queue.put(None)
             self.remover.join()
@@ -487,7 +503,7 @@ def _start_downloaders(self) -> None:
             self.to_download_queues[downloader_index].put(None)
 
     def _start_remover(self) -> None:
-        if not self.remove or self.input_dir.path is None:
+        if not self.remove:
             return
 
         self.remover = Process(
@@ -696,9 +712,9 @@ def _upload_index(self, output_dir: Dir, cache_dir: str, num_nodes: int, node_ra
         if obj.scheme == "s3":
             s3 = S3Client()
             s3.client.upload_file(
-                local_filepath, obj.netloc, os.path.join(obj.path.lstrip("/"), os.path.basename(local_filepath))
+                local_filepath, obj.netloc, os.path.join(str(obj.path).lstrip("/"), os.path.basename(local_filepath))
             )
-        elif os.path.isdir(output_dir.path):
+        elif output_dir.path and os.path.isdir(output_dir.path):
             shutil.copyfile(local_filepath, os.path.join(output_dir.path, os.path.basename(local_filepath)))
 
         if num_nodes == 1 or node_rank is None:
@@ -710,16 +726,16 @@ def _upload_index(self, output_dir: Dir, cache_dir: str, num_nodes: int, node_ra
         if num_nodes == node_rank + 1:
             # Get the index file locally
             for node_rank in range(num_nodes - 1):
-                remote_filepath = os.path.join(
-                    output_dir.url if output_dir.url else output_dir.path, f"{node_rank}-{_INDEX_FILENAME}"
-                )
+                output_dir_path = output_dir.url if output_dir.url else output_dir.path
+                assert output_dir_path
+                remote_filepath = os.path.join(output_dir_path, f"{node_rank}-{_INDEX_FILENAME}")
                 node_index_filepath = os.path.join(cache_dir, os.path.basename(remote_filepath))
                 if obj.scheme == "s3":
                     obj = parse.urlparse(remote_filepath)
                     _wait_for_file_to_exist(s3, obj)
                     with open(node_index_filepath, "wb") as f:
                         s3.client.download_fileobj(obj.netloc, obj.path.lstrip("/"), f)
-                elif os.path.isdir(output_dir.path):
+                elif output_dir.path and os.path.isdir(output_dir.path):
                     shutil.copyfile(remote_filepath, node_index_filepath)
 
             merge_cache = Cache(cache_dir, chunk_bytes=1)

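The file-size pass is now parallelized: `_get_num_bytes` handles a single item, and `_get_item_filesizes` fans the items out over a thread pool sized from `os.cpu_count()`, oversubscribing threads on larger machines since the work is I/O bound. A self-contained sketch of the same pattern; `total_file_bytes` and `file_sizes` are illustrative stand-ins, not the library functions:

```python
import concurrent.futures
import os
from typing import List


def total_file_bytes(path: str) -> int:
    # Stand-in for _get_num_bytes: size of a single file, rejecting empty files.
    size = os.path.getsize(path)
    if size == 0:
        raise RuntimeError(f"The file {path} has 0 bytes!")
    return size


def file_sizes(paths: List[str]) -> List[int]:
    cpu_count = os.cpu_count() or 1
    # Mirror the commit's heuristic: double the threads when more than 4 cores are available.
    max_workers = cpu_count * 2 if cpu_count > 4 else cpu_count
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(total_file_bytes, p) for p in paths]
        return [future.result() for future in futures]
```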
src/lightning/data/streaming/dataset.py

Lines changed: 6 additions & 20 deletions
@@ -13,7 +13,6 @@
 
 import hashlib
 import os
-from dataclasses import dataclass
 from time import time
 from typing import Any, Dict, List, Optional, Tuple, Union
 
@@ -24,24 +23,21 @@
 from lightning.data.streaming.constants import (
     _DEFAULT_CACHE_DIR,
     _INDEX_FILENAME,
-    _LIGHTNING_CLOUD_LATEST,
 )
 from lightning.data.streaming.item_loader import BaseItemLoader
+from lightning.data.streaming.resolver import Dir, _resolve_dir
 from lightning.data.streaming.sampler import ChunkedIndex
 from lightning.data.streaming.serializers import Serializer
 from lightning.data.streaming.shuffle import FullShuffle, NoShuffle, Shuffle
 from lightning.data.utilities.env import Environment, _DistributedEnv, _WorkerEnv
 
-if _LIGHTNING_CLOUD_LATEST:
-    from lightning_cloud.resolver import Dir, _resolve_dir
-
 
 class StreamingDataset(IterableDataset):
     """The streaming dataset can be used once your data have been optimised using the DatasetOptimiser class."""
 
     def __init__(
         self,
-        input_dir: Union[str, "RemoteDir"],
+        input_dir: Union[str, "Dir"],
         item_loader: Optional[BaseItemLoader] = None,
         shuffle: bool = False,
         drop_last: bool = False,
@@ -66,12 +62,10 @@ def __init__(
         if not isinstance(shuffle, bool):
             raise ValueError(f"Shuffle should be a boolean. Found {shuffle}")
 
-        if isinstance(input_dir, RemoteDir):
-            input_dir = Dir(path=input_dir.cache_dir, url=input_dir.remote)
-
         input_dir = _resolve_dir(input_dir)
 
         self.input_dir = input_dir
+
         self.item_loader = item_loader
         self.shuffle: bool = shuffle
         self.drop_last = drop_last
@@ -368,8 +362,8 @@ def _validate_state_dict(self) -> None:
         )
 
 
-def _try_create_cache_dir(input_dir: str, shard_rank: int = 0) -> Optional[str]:
-    hash_object = hashlib.md5(input_dir.encode())
+def _try_create_cache_dir(input_dir: Optional[str], shard_rank: int = 0) -> Optional[str]:
+    hash_object = hashlib.md5((input_dir or "").encode())
     if "LIGHTNING_CLUSTER_ID" not in os.environ or "LIGHTNING_CLOUD_PROJECT_ID" not in os.environ:
         cache_dir = os.path.join(_DEFAULT_CACHE_DIR, hash_object.hexdigest(), str(shard_rank))
         os.makedirs(cache_dir, exist_ok=True)
@@ -379,7 +373,7 @@ def _try_create_cache_dir(input_dir: str, shard_rank: int = 0) -> Optional[str]:
     return cache_dir
 
 
-def _should_replace_path(path: str) -> bool:
+def _should_replace_path(path: Optional[str]) -> bool:
    """Whether the input path is a special path to be replaced."""
     if path is None or path == "":
         return True
@@ -391,14 +385,6 @@ def _is_in_dataloader_worker() -> bool:
     return get_worker_info() is not None
 
 
-@dataclass
-class RemoteDir:
-    """Holds a remote URL to a directory and a cache directory where the data will be downloaded."""
-
-    cache_dir: str
-    remote: str
-
-
 def is_integer(value: str) -> bool:
     try:
         int(value)

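`_try_create_cache_dir` now accepts an optional input directory and hashes the empty string when it is `None`. A sketch of the local-path derivation it performs outside the Lightning cluster environment; the `_DEFAULT_CACHE_DIR` value below is a placeholder, not the constant from the library:

```python
# Hedged sketch of the cache-dir derivation; _DEFAULT_CACHE_DIR here is a placeholder value.
import hashlib
import os

_DEFAULT_CACHE_DIR = os.path.join(os.path.expanduser("~"), ".cache", "lightning-data")


def local_cache_dir(input_dir, shard_rank: int = 0) -> str:
    # None is hashed as the empty string, matching the new Optional[str] handling.
    hash_object = hashlib.md5((input_dir or "").encode())
    cache_dir = os.path.join(_DEFAULT_CACHE_DIR, hash_object.hexdigest(), str(shard_rank))
    os.makedirs(cache_dir, exist_ok=True)
    return cache_dir
```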