
Commit 80efc01

tchaton, pre-commit-ci[bot], and thomas authored and committed

Add dataset creation (#18940)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: thomas <[email protected]>
(cherry picked from commit faa64c5)
1 parent f218819 commit 80efc01

File tree: 8 files changed, +112 −21 lines

requirements/app/app.txt

Lines changed: 1 addition & 1 deletion

@@ -1,4 +1,4 @@
-lightning-cloud == 0.5.48 # Must be pinned to ensure compatibility
+lightning-cloud == 0.5.50 # Must be pinned to ensure compatibility
 packaging
 typing-extensions >=4.0.0, <4.8.0
 deepdiff >=5.7.0, <6.6.0

src/lightning/data/streaming/cache.py

Lines changed: 4 additions & 4 deletions

@@ -19,7 +19,7 @@
 from lightning.data.datasets.env import _DistributedEnv
 from lightning.data.streaming.constants import (
     _INDEX_FILENAME,
-    _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_48,
+    _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_50,
     _TORCH_GREATER_EQUAL_2_1_0,
 )
 from lightning.data.streaming.item_loader import BaseItemLoader
@@ -29,7 +29,7 @@

 logger = logging.Logger(__name__)

-if _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_48:
+if _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_50:
     from lightning_cloud.resolver import _resolve_dir


@@ -67,8 +67,8 @@ def __init__(
         if not _TORCH_GREATER_EQUAL_2_1_0:
             raise ModuleNotFoundError("PyTorch version 2.1 or higher is required to use the cache.")

-        if not _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_48:
-            raise ModuleNotFoundError("Lightning Cloud 0.5.48 or higher is required to use the cache.")
+        if not _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_50:
+            raise ModuleNotFoundError("Lightning Cloud 0.5.50 or higher is required to use the cache.")

         input_dir = _resolve_dir(input_dir)
         self._cache_dir = input_dir.path

src/lightning/data/streaming/config.py

Lines changed: 19 additions & 1 deletion

@@ -21,7 +21,7 @@
 from lightning.data.streaming.sampler import ChunkedIndex

 if _TORCH_GREATER_EQUAL_2_1_0:
-    from torch.utils._pytree import treespec_loads
+    from torch.utils._pytree import tree_unflatten, treespec_loads


 class ChunksConfig:
@@ -83,6 +83,24 @@ def data_format(self) -> Any:
             raise RuntimeError("The config should be defined.")
         return self._config["data_format"]

+    @property
+    def data_format_unflattened(self) -> Any:
+        if self._config is None:
+            raise RuntimeError("The config should be defined.")
+        return tree_unflatten(self._config["data_format"], self._config["data_spec"])
+
+    @property
+    def compression(self) -> Any:
+        if self._config is None:
+            raise RuntimeError("The config should be defined.")
+        return self._config["compression"]
+
+    @property
+    def chunk_bytes(self) -> int:
+        if self._config is None:
+            raise RuntimeError("The config should be defined.")
+        return self._config["chunk_bytes"]
+
     @property
     def config(self) -> Dict[str, Any]:
         if self._config is None:
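
The three new properties expose metadata the chunk writer already records in the index; `data_format_unflattened` rebuilds the nested per-sample format from its flattened leaves and the stored pytree spec. A minimal sketch of that pytree round trip, assuming only torch>=2.1 (the sample structure below is illustrative, not taken from the library):

```python
# Sketch of the pytree round trip behind `data_format_unflattened`,
# assuming torch>=2.1 (torch.utils._pytree is a private PyTorch API).
from torch.utils._pytree import tree_flatten, tree_unflatten, treespec_dumps, treespec_loads

# Hypothetical per-sample structure whose leaves name serializer formats.
sample_format = {"image": "jpeg", "label": "int"}

# Flatten into leaves plus a spec describing the container layout.
leaves, spec = tree_flatten(sample_format)  # leaves e.g. ["jpeg", "int"]
spec_str = treespec_dumps(spec)             # JSON string, suitable for persisting in an index file

# Later, rebuild the nested structure from what was persisted.
restored = tree_unflatten(leaves, treespec_loads(spec_str))
assert restored == sample_format
```

That `data_format_unflattened` calls `tree_unflatten` directly on `self._config["data_spec"]` suggests the spec is already deserialized by the time the config dict is populated.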

src/lightning/data/streaming/constants.py

Lines changed: 1 addition & 1 deletion

@@ -21,7 +21,7 @@
 # This is required for full pytree serialization / deserialization support
 _TORCH_GREATER_EQUAL_2_1_0 = RequirementCache("torch>=2.1.0")
 _VIZ_TRACKER_AVAILABLE = RequirementCache("viztracer")
-_LIGHTNING_CLOUD_GREATER_EQUAL_0_5_48 = RequirementCache("lightning-cloud>=0.5.48")
+_LIGHTNING_CLOUD_GREATER_EQUAL_0_5_50 = RequirementCache("lightning-cloud>=0.5.50")
 _BOTO3_AVAILABLE = RequirementCache("boto3")

 # DON'T CHANGE ORDER
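
As with the neighbouring constants, the renamed flag is a `RequirementCache`, which only evaluates truthy when the installed package satisfies the requirement string; that is what lets the `lightning_cloud` imports elsewhere stay optional. A small sketch of the gating pattern, assuming `lightning-utilities` is installed (the guarded import is illustrative):

```python
# Sketch of the RequirementCache gating pattern used by these constants,
# assuming lightning-utilities is installed.
from lightning_utilities.core.imports import RequirementCache

_LIGHTNING_CLOUD_GREATER_EQUAL_0_5_50 = RequirementCache("lightning-cloud>=0.5.50")

if _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_50:
    # Import only when the required version is actually installed.
    from lightning_cloud.resolver import _resolve_dir  # noqa: F401
else:
    # str() on the cache explains which requirement is missing or too old.
    print(str(_LIGHTNING_CLOUD_GREATER_EQUAL_0_5_50))
```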

src/lightning/data/streaming/data_processor.py

Lines changed: 61 additions & 9 deletions

@@ -1,10 +1,12 @@
+import json
 import logging
 import os
 import signal
 import tempfile
 import traceback
 import types
 from abc import abstractmethod
+from dataclasses import dataclass
 from multiprocessing import Process, Queue
 from queue import Empty
 from shutil import copyfile, rmtree
@@ -23,7 +25,7 @@
     _BOTO3_AVAILABLE,
     _DEFAULT_FAST_DEV_RUN_ITEMS,
     _INDEX_FILENAME,
-    _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_48,
+    _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_50,
     _TORCH_GREATER_EQUAL_2_1_0,
 )
 from lightning.fabric.accelerators.cuda import is_cuda_available
@@ -35,10 +37,12 @@
 from lightning.fabric.utilities.distributed import group as _group

 if _TORCH_GREATER_EQUAL_2_1_0:
-    from torch.utils._pytree import tree_flatten, tree_unflatten
+    from torch.utils._pytree import tree_flatten, tree_unflatten, treespec_loads

-if _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_48:
+if _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_50:
+    from lightning_cloud.openapi import V1DatasetType
     from lightning_cloud.resolver import _resolve_dir
+    from lightning_cloud.utils.dataset import _create_dataset


 if _BOTO3_AVAILABLE:
@@ -191,7 +195,7 @@ def _upload_fn(upload_queue: Queue, remove_queue: Queue, cache_dir: str, output_
                 )
             except Exception as e:
                 print(e)
-        if os.path.isdir(output_dir.path):
+        elif os.path.isdir(output_dir.path):
             copyfile(local_filepath, os.path.join(output_dir.path, os.path.basename(local_filepath)))
         else:
             raise ValueError(f"The provided {output_dir.path} isn't supported.")
@@ -506,6 +510,16 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
         Process.__init__(self)


+@dataclass
+class _Result:
+    size: Optional[int] = None
+    num_bytes: Optional[str] = None
+    data_format: Optional[str] = None
+    compression: Optional[str] = None
+    num_chunks: Optional[int] = None
+    num_bytes_per_chunk: Optional[List[int]] = None
+
+
 T = TypeVar("T")


@@ -545,8 +559,8 @@ def listdir(self, path: str) -> List[str]:
     def __init__(self) -> None:
         self._name: Optional[str] = None

-    def _done(self, delete_cached_files: bool, output_dir: Dir) -> None:
-        pass
+    def _done(self, size: int, delete_cached_files: bool, output_dir: Dir) -> _Result:
+        return _Result(size=size)


 class DataChunkRecipe(DataRecipe):
@@ -576,7 +590,7 @@ def prepare_structure(self, input_dir: Optional[str]) -> List[T]:
     def prepare_item(self, item_metadata: T) -> Any:  # type: ignore
         """The return of this `prepare_item` method is persisted in chunked binary files."""

-    def _done(self, delete_cached_files: bool, output_dir: Dir) -> None:
+    def _done(self, size: int, delete_cached_files: bool, output_dir: Dir) -> _Result:
         num_nodes = _get_num_nodes()
         cache_dir = _get_cache_dir()

@@ -589,6 +603,26 @@ def _done(self, delete_cached_files: bool, output_dir: Dir) -> None:
         merge_cache._merge_no_wait(node_rank if num_nodes > 1 else None)
         self._upload_index(output_dir, cache_dir, num_nodes, node_rank)

+        if num_nodes == node_rank + 1:
+            with open(os.path.join(cache_dir, _INDEX_FILENAME)) as f:
+                config = json.load(f)
+
+            size = sum([c["dim"] if c["dim"] is not None else c["chunk_size"] for c in config["chunks"]])
+            num_bytes = sum([c["chunk_bytes"] for c in config["chunks"]])
+            data_format = tree_unflatten(config["config"]["data_format"], treespec_loads(config["config"]["data_spec"]))
+
+            return _Result(
+                size=size,
+                num_bytes=num_bytes,
+                data_format=data_format,
+                compression=config["config"]["compression"],
+                num_chunks=len(config["chunks"]),
+                num_bytes_per_chunk=[c["chunk_size"] for c in config["chunks"]],
+            )
+        return _Result(
+            size=size,
+        )
+
     def _upload_index(self, output_dir: Dir, cache_dir: str, num_nodes: int, node_rank: Optional[int]) -> None:
         """This method upload the index file to the remote cloud directory."""
         if output_dir.path is None and output_dir.url is None:
@@ -764,13 +798,31 @@ def run(self, data_recipe: DataRecipe) -> None:
                 has_failed = True
                 break

+        num_nodes = _get_num_nodes()
         # TODO: Understand why it hangs.
-        if _get_num_nodes() == 1:
+        if num_nodes == 1:
             for w in self.workers:
                 w.join(0)

         print("Workers are finished.")
-        data_recipe._done(self.delete_cached_files, self.output_dir)
+        result = data_recipe._done(num_items, self.delete_cached_files, self.output_dir)
+
+        if num_nodes == _get_node_rank() + 1:
+            _create_dataset(
+                input_dir=self.input_dir.path,
+                storage_dir=self.output_dir.path,
+                dataset_type=V1DatasetType.CHUNKED
+                if isinstance(data_recipe, DataChunkRecipe)
+                else V1DatasetType.TRANSFORMED,
+                empty=False,
+                size=result.size,
+                num_bytes=result.num_bytes,
+                data_format=result.data_format,
+                compression=result.compression,
+                num_chunks=result.num_chunks,
+                num_bytes_per_chunk=result.num_bytes_per_chunk,
+            )
+
         print("Finished data processing!")

         # TODO: Understand why it is required to avoid long shutdown.
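
Taken together, `_done` now reads the merged index file on the last node, condenses it into a `_Result`, and `run` forwards that summary to `_create_dataset` so the newly written dataset gets registered. A standalone sketch of the aggregation step on the index layout implied by this diff (the index content below is made up; the real file is produced by the chunk writer and loaded with `json.load`):

```python
# Sketch of how DataChunkRecipe._done condenses the merged index.json into
# the metadata passed to _create_dataset. The index dict here is hypothetical.
index = {
    "chunks": [
        {"chunk_size": 2, "dim": None, "chunk_bytes": 1024},
        {"chunk_size": 1, "dim": None, "chunk_bytes": 512},
    ],
    "config": {"compression": None, "data_format": ["jpeg"], "data_spec": "<serialized treespec>"},
}

# Same aggregation as in the diff: total items (per-chunk `dim` when set, else
# `chunk_size`), total bytes, number of chunks, and per-chunk counts.
size = sum(c["dim"] if c["dim"] is not None else c["chunk_size"] for c in index["chunks"])
num_bytes = sum(c["chunk_bytes"] for c in index["chunks"])
num_chunks = len(index["chunks"])
num_bytes_per_chunk = [c["chunk_size"] for c in index["chunks"]]  # mirrors the diff, which reuses chunk_size here

print(size, num_bytes, num_chunks, num_bytes_per_chunk)  # -> 3 1536 2 [2, 1]
```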

src/lightning/data/streaming/dataset.py

Lines changed: 2 additions & 2 deletions

@@ -19,12 +19,12 @@

 from lightning.data.datasets.env import _DistributedEnv, _WorkerEnv
 from lightning.data.streaming import Cache
-from lightning.data.streaming.constants import _INDEX_FILENAME, _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_48
+from lightning.data.streaming.constants import _INDEX_FILENAME, _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_50
 from lightning.data.streaming.item_loader import BaseItemLoader
 from lightning.data.streaming.sampler import ChunkedIndex
 from lightning.data.streaming.shuffle import FullShuffle, NoShuffle, Shuffle

-if _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_48:
+if _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_50:
     from lightning_cloud.resolver import _resolve_dir

src/lightning/data/streaming/functions.py

Lines changed: 2 additions & 2 deletions

@@ -18,10 +18,10 @@
 from types import FunctionType
 from typing import Any, Callable, Optional, Sequence, Union

-from lightning.data.streaming.constants import _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_48, _TORCH_GREATER_EQUAL_2_1_0
+from lightning.data.streaming.constants import _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_50, _TORCH_GREATER_EQUAL_2_1_0
 from lightning.data.streaming.data_processor import DataChunkRecipe, DataProcessor, DataTransformRecipe

-if _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_48:
+if _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_50:
     from lightning_cloud.resolver import _assert_dir_has_index_file, _assert_dir_is_empty, _execute, _resolve_dir

 if _TORCH_GREATER_EQUAL_2_1_0:

tests/tests_data/streaming/test_data_processor.py

Lines changed: 22 additions & 1 deletion

@@ -102,7 +102,7 @@ def copy_file(local_filepath, *args):
         called = True
         from shutil import copyfile

-        copyfile(local_filepath, os.path.join(remote_output_dir.path, os.path.basename(local_filepath)))
+        copyfile(local_filepath, os.path.join(remote_output_dir, os.path.basename(local_filepath)))

     s3_client.client.upload_file = copy_file

@@ -420,8 +420,14 @@ def _broadcast_object(self, obj: Any) -> Any:
 def test_data_processsor_distributed(fast_dev_run, delete_cached_files, tmpdir, monkeypatch):
     """This test ensures the data optimizer works in a fully distributed settings."""

+    seed_everything(42)
+
     monkeypatch.setattr(data_processor_module.os, "_exit", mock.MagicMock())

+    _create_dataset_mock = mock.MagicMock()
+
+    monkeypatch.setattr(data_processor_module, "_create_dataset", _create_dataset_mock)
+
     from PIL import Image

     input_dir = os.path.join(tmpdir, "dataset")
@@ -501,6 +507,21 @@ def test_data_processsor_distributed(fast_dev_run, delete_cached_files, tmpdir,

     assert sorted(os.listdir(remote_output_dir)) == expected

+    _create_dataset_mock.assert_called()
+
+    assert _create_dataset_mock._mock_mock_calls[0].kwargs == {
+        "input_dir": str(input_dir),
+        "storage_dir": str(remote_output_dir),
+        "dataset_type": "CHUNKED",
+        "empty": False,
+        "size": 30,
+        "num_bytes": 26657,
+        "data_format": "jpeg",
+        "compression": None,
+        "num_chunks": 16,
+        "num_bytes_per_chunk": [2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2],
+    }
+

 class TextTokenizeRecipe(DataChunkRecipe):
     def prepare_structure(self, input_dir: str) -> List[Any]:
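
The new assertions inspect the keyword arguments of the first call recorded on the mocked `_create_dataset`. A minimal illustration of that pattern with a plain `MagicMock` (the call and expected values below are hypothetical; `mock_calls[0].kwargs` is the public spelling of what the test reaches via `_mock_mock_calls`):

```python
# Minimal illustration of asserting on a recorded call's kwargs, as the test does.
from unittest import mock

fake_create_dataset = mock.MagicMock()

# Code under test would call the mock with keyword arguments.
fake_create_dataset(storage_dir="/data/out", empty=False)

fake_create_dataset.assert_called()
assert fake_create_dataset.mock_calls[0].kwargs == {"storage_dir": "/data/out", "empty": False}
```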
