
Commit 452f434

tchaton authored and lexierule committed
Lightning Data: Refactor files (#19424)
(cherry picked from commit ac9d63f)
1 parent: 8aad87e

File tree

11 files changed: +94 −90 lines


src/lightning/data/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
+from lightning.data.processing.functions import map, optimize, walk
 from lightning.data.streaming.combined import CombinedStreamingDataset
 from lightning.data.streaming.dataloader import StreamingDataLoader
 from lightning.data.streaming.dataset import StreamingDataset
-from lightning.data.streaming.functions import map, optimize, walk
 
 __all__ = [
     "LightningDataset",

src/lightning/data/streaming/functions.py renamed to src/lightning/data/processing/functions.py

Lines changed: 1 addition & 1 deletion
@@ -22,9 +22,9 @@
 
 import torch
 
+from lightning.data.processing.data_processor import DataChunkRecipe, DataProcessor, DataTransformRecipe
 from lightning.data.processing.readers import BaseReader
 from lightning.data.streaming.constants import _IS_IN_STUDIO, _TORCH_GREATER_EQUAL_2_1_0
-from lightning.data.streaming.data_processor import DataChunkRecipe, DataProcessor, DataTransformRecipe
 from lightning.data.streaming.resolver import (
     Dir,
     _assert_dir_has_index_file,

src/lightning/data/processing/readers.py

Lines changed: 1 addition & 1 deletion
@@ -5,8 +5,8 @@
 
 from lightning_utilities.core.imports import RequirementCache
 
-from lightning.data.streaming.shuffle import _associate_chunks_and_internals_to_ranks
 from lightning.data.utilities.env import _DistributedEnv
+from lightning.data.utilities.shuffle import _associate_chunks_and_internals_to_ranks
 
 _POLARS_AVAILABLE = RequirementCache("polars")
 _PYARROW_AVAILABLE = RequirementCache("pyarrow")

src/lightning/data/streaming/__init__.py

Lines changed: 0 additions & 4 deletions
@@ -13,18 +13,14 @@
 
 from lightning.data.streaming.cache import Cache
 from lightning.data.streaming.combined import CombinedStreamingDataset
-from lightning.data.streaming.data_processor import DataChunkRecipe, DataProcessor, DataTransformRecipe
 from lightning.data.streaming.dataloader import StreamingDataLoader
 from lightning.data.streaming.dataset import StreamingDataset
 from lightning.data.streaming.item_loader import TokensLoader
 
 __all__ = [
     "Cache",
-    "DataProcessor",
     "StreamingDataset",
     "CombinedStreamingDataset",
     "StreamingDataLoader",
-    "DataTransformRecipe",
-    "DataChunkRecipe",
     "TokensLoader",
 ]
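
For downstream code this is the one visible break in the commit: DataProcessor and the recipe classes are no longer re-exported from lightning.data.streaming. A before/after migration sketch, using the new path shown in the functions.py diff above:

# before this commit:
#   from lightning.data.streaming import DataChunkRecipe, DataProcessor, DataTransformRecipe
# after:
from lightning.data.processing.data_processor import DataChunkRecipe, DataProcessor, DataTransformRecipe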

src/lightning/data/streaming/shuffle.py

Lines changed: 1 addition & 73 deletions
@@ -19,6 +19,7 @@
 
 from lightning.data.streaming import Cache
 from lightning.data.utilities.env import _DistributedEnv
+from lightning.data.utilities.shuffle import _associate_chunks_and_internals_to_ranks, _intra_node_chunk_shuffle
 
 
 class Shuffle(ABC):
@@ -129,76 +130,3 @@ def get_chunks_and_intervals_per_ranks(self, distributed_env: _DistributedEnv, c
 
     def __call__(self, array: np.ndarray, num_chunks: int, current_epoch: int, chunk_index: int) -> List[int]:
         return np.random.RandomState([self.seed, num_chunks * current_epoch, chunk_index]).permutation(array).tolist()
-
-
-def _intra_node_chunk_shuffle(
-    distributed_env: _DistributedEnv,
-    chunks_per_ranks: List[List[int]],
-    seed: int,
-    current_epoch: int,
-) -> List[int]:
-    chunk_indexes_per_nodes: Any = [[] for _ in range(distributed_env.num_nodes)]
-    for rank, chunks_per_rank in enumerate(chunks_per_ranks):
-        chunk_indexes_per_nodes[0 if distributed_env.num_nodes == 1 else rank // distributed_env.num_nodes].extend(
-            chunks_per_rank
-        )
-
-    # shuffle the chunks associated to the node
-    for i in range(len(chunk_indexes_per_nodes)):
-        # permute the indexes within the node
-        chunk_indexes_per_nodes[i] = np.random.RandomState(seed=seed + current_epoch).permutation(
-            chunk_indexes_per_nodes[i]
-        )
-
-    return [index for chunks in chunk_indexes_per_nodes for index in chunks]
-
-
-def _associate_chunks_and_internals_to_ranks(
-    distributed_env: _DistributedEnv,
-    indexes: Any,
-    chunk_intervals: Any,
-    drop_last: bool,
-) -> Tuple[List[List[int]], List[Any]]:
-    num_items = sum([(interval[-1] - interval[0]) for interval in chunk_intervals])
-    num_items_per_ranks: List[int] = [
-        num_items // distributed_env.world_size + num_items % distributed_env.world_size
-        if rank == distributed_env.world_size - 1 and not drop_last
-        else num_items // distributed_env.world_size
-        for rank in range(distributed_env.world_size)
-    ]
-    chunks_per_ranks: List[List[int]] = [[] for _ in range(distributed_env.world_size)]
-    intervals_per_ranks: List[List[List[int]]] = [[] for _ in range(distributed_env.world_size)]
-
-    # 4. Assign the chunk & intervals to each rank
-    for chunk_index, chunk_interval in zip(indexes, chunk_intervals):
-        rank = 0
-
-        while True:
-            if rank == len(num_items_per_ranks):
-                break
-
-            items_left_to_assign = num_items_per_ranks[rank]
-
-            if items_left_to_assign == 0:
-                rank += 1
-                continue
-
-            items_in_chunk = chunk_interval[-1] - chunk_interval[0]
-
-            if items_in_chunk == 0:
-                break
-
-            if items_in_chunk > items_left_to_assign:
-                chunks_per_ranks[rank].append(chunk_index)
-                begin, end = chunk_interval
-                intervals_per_ranks[rank].append([begin, begin + items_left_to_assign])
-                chunk_interval = (begin + items_left_to_assign, end)
-                num_items_per_ranks[rank] = 0
-                rank += 1
-            else:
-                chunks_per_ranks[rank].append(chunk_index)
-                intervals_per_ranks[rank].append(chunk_interval)
-                num_items_per_ranks[rank] -= items_in_chunk
-                break
-
-    return chunks_per_ranks, intervals_per_ranks
src/lightning/data/utilities/shuffle.py

Lines changed: 78 additions & 0 deletions

@@ -0,0 +1,78 @@
+from typing import Any, List, Tuple
+
+import numpy as np
+
+from lightning.data.utilities.env import _DistributedEnv
+
+
+def _intra_node_chunk_shuffle(
+    distributed_env: _DistributedEnv,
+    chunks_per_ranks: List[List[int]],
+    seed: int,
+    current_epoch: int,
+) -> List[int]:
+    chunk_indexes_per_nodes: Any = [[] for _ in range(distributed_env.num_nodes)]
+    for rank, chunks_per_rank in enumerate(chunks_per_ranks):
+        chunk_indexes_per_nodes[0 if distributed_env.num_nodes == 1 else rank // distributed_env.num_nodes].extend(
+            chunks_per_rank
+        )
+
+    # shuffle the chunks associated to the node
+    for i in range(len(chunk_indexes_per_nodes)):
+        # permute the indexes within the node
+        chunk_indexes_per_nodes[i] = np.random.RandomState(seed=seed + current_epoch).permutation(
+            chunk_indexes_per_nodes[i]
+        )
+
+    return [index for chunks in chunk_indexes_per_nodes for index in chunks]
+
+
+def _associate_chunks_and_internals_to_ranks(
+    distributed_env: _DistributedEnv,
+    indexes: Any,
+    chunk_intervals: Any,
+    drop_last: bool,
+) -> Tuple[List[List[int]], List[Any]]:
+    num_items = sum([(interval[-1] - interval[0]) for interval in chunk_intervals])
+    num_items_per_ranks: List[int] = [
+        num_items // distributed_env.world_size + num_items % distributed_env.world_size
+        if rank == distributed_env.world_size - 1 and not drop_last
+        else num_items // distributed_env.world_size
+        for rank in range(distributed_env.world_size)
+    ]
+    chunks_per_ranks: List[List[int]] = [[] for _ in range(distributed_env.world_size)]
+    intervals_per_ranks: List[List[List[int]]] = [[] for _ in range(distributed_env.world_size)]
+
+    # 4. Assign the chunk & intervals to each rank
+    for chunk_index, chunk_interval in zip(indexes, chunk_intervals):
+        rank = 0
+
+        while True:
+            if rank == len(num_items_per_ranks):
+                break
+
+            items_left_to_assign = num_items_per_ranks[rank]
+
+            if items_left_to_assign == 0:
+                rank += 1
+                continue
+
+            items_in_chunk = chunk_interval[-1] - chunk_interval[0]
+
+            if items_in_chunk == 0:
+                break
+
+            if items_in_chunk > items_left_to_assign:
+                chunks_per_ranks[rank].append(chunk_index)
+                begin, end = chunk_interval
+                intervals_per_ranks[rank].append([begin, begin + items_left_to_assign])
+                chunk_interval = (begin + items_left_to_assign, end)
+                num_items_per_ranks[rank] = 0
+                rank += 1
+            else:
+                chunks_per_ranks[rank].append(chunk_index)
+                intervals_per_ranks[rank].append(chunk_interval)
+                num_items_per_ranks[rank] -= items_in_chunk
+                break
+
+    return chunks_per_ranks, intervals_per_ranks
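
To make the relocated helpers concrete, here is a small worked example of the assignment logic: with two ranks and 10 total items, each rank is budgeted 5 items, so a 7-item chunk is split across ranks at the budget boundary. A minimal sketch, assuming `_DistributedEnv` can be constructed positionally as (world_size, global_rank, num_nodes); construct it however your environment actually provides it:

from lightning.data.utilities.env import _DistributedEnv
from lightning.data.utilities.shuffle import (
    _associate_chunks_and_internals_to_ranks,
    _intra_node_chunk_shuffle,
)

env = _DistributedEnv(2, 0, 1)  # assumed ctor: 2 ranks, global rank 0, 1 node

# Two chunks holding 7 and 3 items -> 10 items total, budget of 5 per rank.
chunks, intervals = _associate_chunks_and_internals_to_ranks(
    env, indexes=[0, 1], chunk_intervals=[(0, 7), (0, 3)], drop_last=False
)
assert chunks == [[0], [0, 1]]  # chunk 0 is split across both ranks
assert intervals == [[[0, 5]], [(5, 7), (0, 3)]]  # rank 0 reads items 0..4 of chunk 0

# Within a node, all ranks' chunk ids are pooled and permuted together,
# deterministically per (seed, epoch).
shuffled = _intra_node_chunk_shuffle(env, chunks_per_ranks=[[0], [1]], seed=42, current_epoch=0)
assert sorted(shuffled) == [0, 1]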

tests/tests_data/streaming/test_data_processor.py renamed to tests/tests_data/processing/test_data_processor.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,9 @@
1010
import pytest
1111
import torch
1212
from lightning import seed_everything
13-
from lightning.data.streaming import data_processor as data_processor_module
14-
from lightning.data.streaming import functions, resolver
15-
from lightning.data.streaming.cache import Cache, Dir
16-
from lightning.data.streaming.data_processor import (
13+
from lightning.data.processing import data_processor as data_processor_module
14+
from lightning.data.processing import functions
15+
from lightning.data.processing.data_processor import (
1716
DataChunkRecipe,
1817
DataProcessor,
1918
DataTransformRecipe,
@@ -26,7 +25,9 @@
2625
_wait_for_disk_usage_higher_than_threshold,
2726
_wait_for_file_to_exist,
2827
)
29-
from lightning.data.streaming.functions import LambdaDataTransformRecipe, map, optimize
28+
from lightning.data.processing.functions import LambdaDataTransformRecipe, map, optimize
29+
from lightning.data.streaming import resolver
30+
from lightning.data.streaming.cache import Cache, Dir
3031
from lightning_utilities.core.imports import RequirementCache
3132

3233
_PIL_AVAILABLE = RequirementCache("PIL")
@@ -162,7 +163,7 @@ def fn(*_, **__):
162163

163164

164165
@pytest.mark.skipif(condition=sys.platform == "win32", reason="Not supported on windows")
165-
@mock.patch("lightning.data.streaming.data_processor._wait_for_disk_usage_higher_than_threshold")
166+
@mock.patch("lightning.data.processing.data_processor._wait_for_disk_usage_higher_than_threshold")
166167
def test_download_data_target(wait_for_disk_usage_higher_than_threshold_mock, tmpdir):
167168
input_dir = os.path.join(tmpdir, "input_dir")
168169
os.makedirs(input_dir, exist_ok=True)
@@ -201,7 +202,7 @@ def fn(*_, **__):
201202

202203
def test_wait_for_disk_usage_higher_than_threshold():
203204
disk_usage_mock = mock.Mock(side_effect=[mock.Mock(free=10e9), mock.Mock(free=10e9), mock.Mock(free=10e11)])
204-
with mock.patch("lightning.data.streaming.data_processor.shutil.disk_usage", disk_usage_mock):
205+
with mock.patch("lightning.data.processing.data_processor.shutil.disk_usage", disk_usage_mock):
205206
_wait_for_disk_usage_higher_than_threshold("/", 10, sleep_time=0)
206207
assert disk_usage_mock.call_count == 3
207208
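
One detail worth flagging for similar migrations: mock.patch resolves its target string at patch time, so the strings must follow the module move even though the patched functions themselves are untouched. A minimal sketch of the pattern, assuming the post-move layout:

from unittest import mock

# Patch the name where it is looked up (the new processing module),
# not where shutil.disk_usage is defined.
with mock.patch("lightning.data.processing.data_processor.shutil.disk_usage") as disk_usage:
    disk_usage.return_value = mock.Mock(free=10e11)
    ...  # code under test that checks free disk space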

tests/tests_data/streaming/test_functions.py renamed to tests/tests_data/processing/test_functions.py

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@
 
 import pytest
 from lightning.data import walk
-from lightning.data.streaming.functions import _get_input_dir
+from lightning.data.processing.functions import _get_input_dir
 
 
 @pytest.mark.skipif(sys.platform == "win32", reason="currently not supported for windows.")

tests/tests_data/streaming/test_dataset.py

Lines changed: 2 additions & 1 deletion
@@ -20,7 +20,8 @@
 import pytest
 import torch
 from lightning import seed_everything
-from lightning.data.streaming import Cache, functions
+from lightning.data.processing import functions
+from lightning.data.streaming import Cache
 from lightning.data.streaming.dataloader import StreamingDataLoader
 from lightning.data.streaming.dataset import (
     _INDEX_FILENAME,
