Skip to content

Commit 7d9273c

Browse files
tchatonjustusschockpre-commit-ci[bot]thomas
authored andcommitted
Improve DatasetOptimizer API (#18827)
Co-authored-by: Justus Schock <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: thomas <[email protected]> (cherry picked from commit e59dc41)
1 parent 96190ce commit 7d9273c

File tree

9 files changed

+241
-93
lines changed

9 files changed

+241
-93
lines changed

requirements/app/app.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
lightning-cloud ==0.5.42 # Must be pinned to ensure compatibility
1+
lightning-cloud ==0.5.43 # Must be pinned to ensure compatibility
22
packaging
33
typing-extensions >=4.0.0, <4.8.0
44
deepdiff >=5.7.0, <6.6.0

src/lightning/app/cli/commands/cp.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def _upload_files(live, client: LightningClient, local_src: str, remote_dst: str
113113
else:
114114
upload_paths = [local_src]
115115

116-
upload_urls = []
116+
_upload_urls = []
117117

118118
clusters = client.projects_service_list_project_cluster_bindings(project_id)
119119

@@ -129,9 +129,11 @@ def _upload_files(live, client: LightningClient, local_src: str, remote_dst: str
129129
body=ProjectIdStorageBody(cluster_id=cluster.cluster_id, filename=filename),
130130
async_req=True,
131131
)
132-
upload_urls.append(response)
132+
_upload_urls.append(response)
133133

134-
upload_urls = [upload_url.get().upload_url for upload_url in upload_urls]
134+
upload_urls = []
135+
for upload_url in _upload_urls:
136+
upload_urls.extend(upload_url.get().urls)
135137

136138
live.stop()
137139

src/lightning/data/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
11
from lightning.data.datasets import LightningDataset, LightningIterableDataset
22
from lightning.data.streaming.dataloader import StreamingDataLoader
33
from lightning.data.streaming.dataset import StreamingDataset
4+
from lightning.data.streaming.dataset_optimizer import DatasetOptimizer
45

5-
__all__ = ["LightningDataset", "StreamingDataset", "StreamingDataLoader", "LightningIterableDataset"]
6+
__all__ = [
7+
"LightningDataset",
8+
"StreamingDataset",
9+
"StreamingDataLoader",
10+
"LightningIterableDataset",
11+
"DatasetOptimizer",
12+
]

src/lightning/data/streaming/dataset.py

Lines changed: 110 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,36 +11,135 @@
1111
# See the License for the specific language governing permissions and
1212
# limitations under the License.
1313

14-
from typing import Any, Literal, Optional, Union
14+
from typing import Any, List, Literal, Optional, Union
1515

16-
from torch.utils.data import Dataset
16+
import numpy as np
17+
from torch.utils.data import IterableDataset
1718

19+
from lightning.data.datasets.env import _DistributedEnv, _WorkerEnv
1820
from lightning.data.streaming import Cache
21+
from lightning.data.streaming.item_loader import BaseItemLoader
22+
from lightning.data.streaming.sampler import ChunkedIndex
1923

2024

21-
class StreamingDataset(Dataset):
25+
class StreamingDataset(IterableDataset):
2226
"""The streaming dataset can be used once your data have been optimised using the DatasetOptimiser class."""
2327

2428
def __init__(
25-
self, name: str, version: Optional[Union[int, Literal["latest"]]] = "latest", cache_dir: Optional[str] = None
29+
self,
30+
name: str,
31+
version: Optional[Union[int, Literal["latest"]]] = "latest",
32+
cache_dir: Optional[str] = None,
33+
item_loader: Optional[BaseItemLoader] = None,
34+
shuffle: bool = True,
35+
seed: int = 42,
2636
) -> None:
2737
"""The streaming dataset can be used once your data have been optimised using the DatasetOptimiser class.
2838
2939
Arguments:
3040
name: The name of the optimised dataset.
3141
version: The version of the dataset to use.
3242
cache_dir: The cache dir where the data would be stored.
43+
item_loader: The logic to load an item from a chunk.
44+
shuffle: Whether to shuffle the data.
45+
seed: Random seed for shuffling.
3346
3447
"""
3548
super().__init__()
36-
self.cache = Cache(name=name, version=version, cache_dir=cache_dir)
49+
self.cache = Cache(name=name, version=version, cache_dir=cache_dir, item_loader=item_loader, chunk_bytes=1)
50+
51+
self.cache._reader._try_load_config()
52+
53+
if not self.cache.filled:
54+
raise ValueError(f"The provided dataset `{name}` isn't filled up.")
55+
56+
self.shuffle = shuffle
57+
self.distributed_env = _DistributedEnv.detect()
58+
self.worker_env: Optional[_WorkerEnv] = None
59+
60+
chunk_intervals = self.cache.get_chunk_interval()
61+
self.L = sum([(interval[-1] - interval[0]) for interval in chunk_intervals])
62+
63+
self.worker_chunks: List[int] = []
64+
self.worker_intervals: List[List[int]] = []
65+
self.current_indexes: List[int] = []
66+
self.chunk_index = 0
67+
self.index = 0
68+
self.has_triggered_download = False
69+
self.min_items_per_replica: Optional[int] = None
70+
self.seed = seed
71+
self.num_iter = 0
72+
self.random_state = None
3773

3874
def __len__(self) -> int:
39-
return len(self.cache)
75+
return self.L
76+
77+
def __iter__(self) -> "StreamingDataset":
78+
self.random_state = np.random.RandomState(seed=self.seed + self.num_iter) # type: ignore
79+
chunk_intervals = self.cache.get_chunk_interval()
80+
indexes = range(len(chunk_intervals))
81+
shuffled_indexes = self.random_state.permutation(indexes) if self.shuffle else list(indexes)
82+
shuffled_chunk_intervals = np.asarray(chunk_intervals)[shuffled_indexes]
83+
84+
chunks_per_replica: List[List[int]] = [[] for _ in range(self.distributed_env.world_size)]
85+
intervals_per_replica: List[List[List[int]]] = [[] for _ in range(self.distributed_env.world_size)]
86+
for index, (chunk_index, chunk_interval) in enumerate(zip(shuffled_indexes, shuffled_chunk_intervals)):
87+
replica_index = index % self.distributed_env.world_size
88+
chunks_per_replica[replica_index].append(chunk_index)
89+
intervals_per_replica[replica_index].append(chunk_interval)
90+
91+
current_chunks = chunks_per_replica[self.distributed_env.global_rank % self.distributed_env.world_size]
92+
current_intervals = intervals_per_replica[self.distributed_env.global_rank % self.distributed_env.world_size]
93+
94+
if self.worker_env is None:
95+
self.worker_env = _WorkerEnv.detect()
96+
97+
self.worker_chunks = []
98+
self.worker_intervals = []
99+
100+
for i, (chunk_index, chunk_interval) in enumerate(zip(current_chunks, current_intervals)):
101+
if i % self.worker_env.world_size != self.worker_env.rank:
102+
continue
103+
self.worker_chunks.append(chunk_index)
104+
self.worker_intervals.append(chunk_interval)
105+
106+
self.current_indexes = []
107+
self.chunk_index = 0
108+
self.num_iter += 1
109+
110+
return self
111+
112+
def __getitem__(self, index: Union[ChunkedIndex, int]) -> Any:
113+
if isinstance(index, int):
114+
index = ChunkedIndex(index, self.cache._get_chunk_index_from_index(index))
115+
return self.cache[index]
116+
117+
def __next__(self) -> Any:
118+
# Lazily re-populate the interval to reduce memory usage.
119+
if len(self.current_indexes) == 0:
120+
if self.chunk_index == len(self.worker_intervals):
121+
raise StopIteration
122+
123+
interval = self.worker_intervals[self.chunk_index]
124+
current_indexes = np.arange(0, interval[1] - interval[0])
125+
if self.shuffle:
126+
current_indexes = self.random_state.permutation(current_indexes)
127+
self.current_indexes = current_indexes.tolist()
128+
self.chunk_index += 1
129+
130+
# Get the first index
131+
index = self.current_indexes.pop(0)
132+
133+
# Call the `__getitem__` method.
134+
data = self.__getitem__(
135+
ChunkedIndex(
136+
index=index,
137+
chunk_index=self.worker_chunks[self.chunk_index - 1],
138+
chunk_indexes=None if self.has_triggered_download else self.worker_chunks,
139+
)
140+
)
40141

41-
def __getitem__(self, idx: int) -> Any:
42-
return self.cache[idx]
142+
self.has_triggered_download = True
143+
self.index += 1
43144

44-
def getitem(self, obj: Any) -> Any:
45-
"""Override the getitem with your own logic to transform the cache object."""
46-
return obj
145+
return data

src/lightning/data/streaming/dataset_optimizer.py

Lines changed: 45 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,15 @@
33
import signal
44
import traceback
55
import types
6-
from abc import ABC, abstractmethod
76
from enum import Enum
87
from multiprocessing import Process, Queue
98
from pathlib import Path
109
from queue import Empty
1110
from shutil import copyfile
11+
from textwrap import dedent
1212
from threading import Thread
1313
from time import sleep, time
14-
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple
14+
from typing import Any, Callable, Dict, List, Literal, Optional, Protocol, Tuple, TypeVar, runtime_checkable
1515
from urllib import parse
1616

1717
from tqdm.auto import tqdm
@@ -167,7 +167,7 @@ def __init__(
167167
start_index: int,
168168
dataset_name: str,
169169
node_rank: int,
170-
dataset_optimizer: "DatasetOptimizer",
170+
prepare_item: Callable,
171171
src_dir: str,
172172
remote_src_dir: str,
173173
remote_dst_dir: Optional[str],
@@ -187,7 +187,7 @@ def __init__(
187187
self.start_index = start_index
188188
self.dataset_name = dataset_name
189189
self.node_rank = node_rank
190-
self.prepare_item = dataset_optimizer.prepare_item
190+
self.prepare_item = prepare_item
191191
self.src_dir = src_dir
192192
self.remote_src_dir = remote_src_dir
193193
self.remote_dst_dir = remote_dst_dir
@@ -432,57 +432,21 @@ class WorkerType(Enum):
432432
PROCESS = "process"
433433

434434

435-
class DatasetOptimizer(ABC):
436-
@abstractmethod
437-
def prepare_dataset_structure(self, src_dir: str, filepaths: List[str]) -> List[Any]:
438-
"""This function is meant to return a list of item metadata. Each item metadata should be enough to prepare a
439-
single item when called with the prepare_item.
435+
T = TypeVar("T")
440436

441-
Example::
442437

443-
# For a classification use case
444-
445-
def prepare_dataset_structure(self, src_dir, filepaths)
446-
import numpy as np
447-
448-
filepaths = ['class_a/file_1.ext', ..., 'class_b/file_1.ext', ...]
449-
classes = np.unique([filepath.split("/")[0] for filepath in filepaths])
450-
classes_to_idx_map = {c: idx for idx, c in enumerate(classes)}
451-
452-
# Return pair with the filepath to the obj and its class
453-
# [('class_a/file_1.ext', 0), ... ('class_b/file_1.ext', 1)]
454-
return [(filepath, classes_to_idx_map[filepath.split("/")[0]]) for filepath in filepaths]
455-
456-
Example::
457-
458-
# For a image segmentation use case
459-
460-
def prepare_dataset_structure(self, src_dir, filepaths)
461-
import numpy as np
462-
463-
filepaths = ['file_1.JPEG', 'file_1.mask', .... 'file_N.JPEG', 'file_N.mask', ...]
464-
465-
# [('file_1.JPEG', 'file_1.mask'), ... ('file_N.JPEG', 'file_N.mask')]
466-
return [(x[i], x[i+1]) for i in range(len(filepaths) -1)]
467-
468-
def prepare_item(self, obj):
469-
image_filepath, mask_filepath = obj
470-
471-
image = load_and_resize(image_filepath)
472-
mask = load_and_resize(mask_filepath)
473-
return (image, mask)
474-
475-
"""
438+
@runtime_checkable
439+
class _OptimizableDataset(Protocol):
440+
@staticmethod
441+
def prepare_dataset_structure(root: str, filepaths: List[str]) -> List[T]:
476442
pass
477443

478-
def prepare_item(self, metadata_item: Any) -> Any:
479-
"""Using some metadata, prepare the associated item.
444+
@staticmethod
445+
def prepare_item(item_metadata: T) -> Any:
446+
return item_metadata
480447

481-
The output of this function will be binarised
482-
483-
"""
484-
return metadata_item
485448

449+
class DatasetOptimizer:
486450
def __init__(
487451
self,
488452
name: str,
@@ -547,9 +511,29 @@ def __init__(
547511
)
548512
self.random_seed = random_seed
549513

550-
def run(self) -> None:
514+
def run(self, optimizable_dataset: _OptimizableDataset) -> None:
551515
"""The `DatasetChunker.run(...)` method is used to trigger the data processing from your dataset into
552516
chunks."""
517+
if not isinstance(optimizable_dataset, _OptimizableDataset):
518+
raise ValueError(
519+
dedent(
520+
"""The provided argument to the DatasetOptimizer.run(...) needs to have the following format:
521+
522+
Example:
523+
524+
class YourDataset:
525+
526+
@staticmethod
527+
def prepare_dataset_structure(root: str, filepaths: List[str]) -> List[T]:
528+
return [...]
529+
530+
@staticmethod
531+
def prepare_item(item_metadata: T) -> Any:
532+
return ...
533+
"""
534+
)
535+
)
536+
553537
t0 = time()
554538
print(f"Setup started for `{self.name}` with fast_dev_run={self.fast_dev_run}.")
555539

@@ -564,7 +548,7 @@ def run(self) -> None:
564548
seed_everything(self.random_seed)
565549

566550
# Call the setup method of the user
567-
user_items = self.prepare_dataset_structure(self.src_dir, filepaths)
551+
user_items: List[Any] = optimizable_dataset.prepare_dataset_structure(self.src_dir, filepaths)
568552

569553
if not isinstance(user_items, list):
570554
raise ValueError("The setup_fn should return a list of item metadata.")
@@ -588,9 +572,9 @@ def run(self) -> None:
588572
signal.signal(signal.SIGINT, self._signal_handler)
589573

590574
if self.worker_type == WorkerType.THREAD.value:
591-
self._create_thread_workers(begins, workers_user_items)
575+
self._create_thread_workers(optimizable_dataset, begins, workers_user_items)
592576
else:
593-
self._create_process_workers(begins, workers_user_items)
577+
self._create_process_workers(optimizable_dataset, begins, workers_user_items)
594578

595579
print("Workers are ready ! Starting data processing...")
596580

@@ -634,7 +618,9 @@ def _exit_on_error(self, error: str) -> None:
634618
w.join(0)
635619
raise RuntimeError(f"We found the following error {error}.")
636620

637-
def _create_thread_workers(self, begins: List[int], workers_user_items: List[List[Any]]) -> None:
621+
def _create_thread_workers(
622+
self, optimizable_dataset: _OptimizableDataset, begins: List[int], workers_user_items: List[List[Any]]
623+
) -> None:
638624
current_total = 0
639625
total = sum([len(w) for w in workers_user_items])
640626
with tqdm(total=total, smoothing=0) as pbar:
@@ -649,7 +635,7 @@ def _create_thread_workers(self, begins: List[int], workers_user_items: List[Lis
649635
begins[worker_idx],
650636
self.name,
651637
_get_node_rank(),
652-
self,
638+
optimizable_dataset.prepare_item,
653639
self.src_dir,
654640
self.remote_src_dir,
655641
self.remote_dst_dir,
@@ -676,7 +662,9 @@ def _create_thread_workers(self, begins: List[int], workers_user_items: List[Lis
676662
if current_total == total:
677663
break
678664

679-
def _create_process_workers(self, begins: List[int], workers_user_items: List[List[Any]]) -> None:
665+
def _create_process_workers(
666+
self, optimizable_dataset: _OptimizableDataset, begins: List[int], workers_user_items: List[List[Any]]
667+
) -> None:
680668
self.progress_queue = Queue()
681669
workers: List[DataWorkerProcess] = []
682670
stop_queues: List[Queue] = []
@@ -688,7 +676,7 @@ def _create_process_workers(self, begins: List[int], workers_user_items: List[Li
688676
begins[worker_idx],
689677
self.name,
690678
_get_node_rank(),
691-
self,
679+
optimizable_dataset.prepare_item,
692680
self.src_dir,
693681
self.remote_src_dir,
694682
self.remote_dst_dir,

0 commit comments

Comments
 (0)