Skip to content

Commit 75510dd

Browse files
authored
StreamingDataset: Add intra node shuffling to accelerate second epoch (#19296)
1 parent 4004f85 commit 75510dd

File tree

11 files changed

+338
-100
lines changed

11 files changed

+338
-100
lines changed

requirements/data/data.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@ lightning-utilities >=0.8.0, <0.10.0
55
# to be able to include also PL 2.0 and preserve `>` needed for CI min version bypass
66
torch >0.14.0, <2.2.0
77
lightning-cloud
8+
filelock

src/lightning/data/streaming/combined.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def __init__(
4646

4747
def __len__(self) -> int:
4848
assert self._weights
49-
return int(sum(w * len(d) for w, d in zip(self._weights, self._datasets)))
49+
return int(min([1 / w * len(d) for w, d in zip(self._weights, self._datasets) if w > 0]))
5050

5151
def __iter__(self) -> Iterator[Any]:
5252
assert self._weights

src/lightning/data/streaming/dataset.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from lightning.data.streaming.sampler import ChunkedIndex
3030
from lightning.data.streaming.serializers import Serializer
3131
from lightning.data.streaming.shuffle import FullShuffle, NoShuffle, Shuffle
32-
from lightning.data.utilities.env import Environment, _DistributedEnv, _WorkerEnv
32+
from lightning.data.utilities.env import _DistributedEnv, _WorkerEnv
3333

3434

3535
class StreamingDataset(IterableDataset):
@@ -91,13 +91,9 @@ def __init__(
9191
self._state_dict: Optional[Dict[str, Any]] = None
9292

9393
def _create_cache(self, worker_env: _WorkerEnv) -> Cache:
94-
env = Environment(dist_env=self.distributed_env, worker_env=worker_env)
95-
9694
if _should_replace_path(self.input_dir.path):
97-
# FIXME: Remove the `shard_rank` from the cache_path to enable reloading chunks for the second epoch
98-
# without paying the cost of re-download
9995
cache_path = _try_create_cache_dir(
100-
input_dir=self.input_dir.path if self.input_dir.path else self.input_dir.url, shard_rank=env.shard_rank
96+
input_dir=self.input_dir.path if self.input_dir.path else self.input_dir.url
10197
)
10298
if cache_path is not None:
10399
self.input_dir.path = cache_path
@@ -362,13 +358,13 @@ def _validate_state_dict(self) -> None:
362358
)
363359

364360

365-
def _try_create_cache_dir(input_dir: Optional[str], shard_rank: int = 0) -> Optional[str]:
361+
def _try_create_cache_dir(input_dir: Optional[str]) -> Optional[str]:
366362
hash_object = hashlib.md5((input_dir or "").encode())
367363
if "LIGHTNING_CLUSTER_ID" not in os.environ or "LIGHTNING_CLOUD_PROJECT_ID" not in os.environ:
368-
cache_dir = os.path.join(_DEFAULT_CACHE_DIR, hash_object.hexdigest(), str(shard_rank))
364+
cache_dir = os.path.join(_DEFAULT_CACHE_DIR, hash_object.hexdigest())
369365
os.makedirs(cache_dir, exist_ok=True)
370366
return cache_dir
371-
cache_dir = os.path.join("/cache", "chunks", hash_object.hexdigest(), str(shard_rank))
367+
cache_dir = os.path.join("/cache", "chunks", hash_object.hexdigest())
372368
os.makedirs(cache_dir, exist_ok=True)
373369
return cache_dir
374370

src/lightning/data/streaming/downloader.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
from typing import Any, Dict, List
1717
from urllib import parse
1818

19+
from filelock import FileLock, Timeout
20+
1921
from lightning.data.streaming.client import S3Client
2022

2123

@@ -50,21 +52,28 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None:
5052

5153
extra_args: Dict[str, Any] = {}
5254

53-
# Issue: https://github.com/boto/boto3/issues/3113
54-
self._client.client.download_file(
55-
obj.netloc,
56-
obj.path.lstrip("/"),
57-
local_filepath,
58-
ExtraArgs=extra_args,
59-
Config=TransferConfig(use_threads=False),
60-
)
55+
try:
56+
with FileLock(local_filepath + ".lock", timeout=1):
57+
if not os.path.exists(local_filepath):
58+
# Issue: https://github.com/boto/boto3/issues/3113
59+
self._client.client.download_file(
60+
obj.netloc,
61+
obj.path.lstrip("/"),
62+
local_filepath,
63+
ExtraArgs=extra_args,
64+
Config=TransferConfig(use_threads=False),
65+
)
66+
except Timeout:
67+
# another process is responsible for downloading that file, continue
68+
pass
6169

6270

6371
class LocalDownloader(Downloader):
6472
def download_file(self, remote_filepath: str, local_filepath: str) -> None:
6573
if not os.path.exists(remote_filepath):
6674
raise FileNotFoundError(f"The provided remote_path doesn't exist: {remote_filepath}")
67-
if remote_filepath != local_filepath:
75+
76+
if remote_filepath != local_filepath and not os.path.exists(local_filepath):
6877
shutil.copy(remote_filepath, local_filepath)
6978

7079

src/lightning/data/streaming/item_loader.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -181,22 +181,23 @@ def pre_load_chunk(self, chunk_index: int, chunk_filepath: str) -> None:
181181
if chunk_filepath not in self._chunk_filepaths:
182182
self._chunk_filepaths[chunk_filepath] = True
183183

184-
self._load_chunk(chunk_index, chunk_filepath)
184+
if os.path.exists(chunk_filepath) and os.stat(chunk_filepath).st_size > 0:
185+
self._load_chunk(chunk_index, chunk_filepath)
185186

186187
def load_item_from_chunk(self, index: int, chunk_index: int, chunk_filepath: str, begin: int) -> torch.Tensor:
187188
if chunk_filepath in self._chunk_filepaths and not os.path.isfile(chunk_filepath):
188189
del self._chunk_filepaths[chunk_filepath]
189190

190191
if chunk_filepath not in self._chunk_filepaths:
191-
first_exists = exists = os.path.exists(chunk_filepath)
192+
first_exists = exists = os.path.exists(chunk_filepath) and os.stat(chunk_filepath).st_size > 0
192193

193194
while not exists:
194195
sleep(0.1)
195-
exists = os.path.exists(chunk_filepath)
196+
exists = os.path.exists(chunk_filepath) and os.stat(chunk_filepath).st_size > 0
196197

197198
# Wait to avoid any corruption when the file appears
198199
if not first_exists:
199-
sleep(0.001)
200+
sleep(0.1)
200201

201202
self._chunk_filepaths[chunk_filepath] = True
202203

src/lightning/data/streaming/reader.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ def __init__(
5151
self,
5252
config: ChunksConfig,
5353
item_loader: BaseItemLoader,
54+
distributed_env: _DistributedEnv,
5455
max_cache_size: Optional[int] = None,
5556
max_pre_download: int = 2,
5657
) -> None:
@@ -59,15 +60,17 @@ def __init__(
5960
self._item_loader = item_loader
6061
self._max_pre_download = max_pre_download
6162
self._pre_download_counter = 0
63+
self._distributed_env = distributed_env
6264

6365
self._chunks_index_to_be_deleted: List[int] = []
6466
self._max_cache_size = max_cache_size
6567
self._parent_cache_dir = os.path.dirname(self._config._cache_dir)
6668
self._to_download_queue: multiprocessing.Queue = multiprocessing.Queue()
6769
self._to_delete_queue: multiprocessing.Queue = multiprocessing.Queue()
6870

69-
# FIXME: This should be divided by the number of nodes to provide a more granular support with scaling out
70-
self._delete_chunks_when_processed = self._config.num_bytes > max_cache_size if max_cache_size else False
71+
# Check whether a dataset slice fits on the node
72+
num_bytes_per_nodes = self._config.num_bytes // self._distributed_env.num_nodes
73+
self._delete_chunks_when_processed = num_bytes_per_nodes > max_cache_size if max_cache_size else False
7174
self._has_exited = False
7275

7376
def download(self, chunk_indexes: List[int]) -> None:
@@ -229,7 +232,9 @@ def read(self, index: ChunkedIndex) -> Any:
229232
if self._config and self._config._remote_dir:
230233
# Create and start the prepare chunks thread
231234
if self._prepare_thread is None and self._config:
232-
self._prepare_thread = PrepareChunksThread(self._config, self._item_loader, self._max_cache_size)
235+
self._prepare_thread = PrepareChunksThread(
236+
self._config, self._item_loader, self._distributed_env, self._max_cache_size
237+
)
233238
self._prepare_thread.start()
234239
if index.chunk_indexes:
235240
self._prepare_thread.download(index.chunk_indexes)

src/lightning/data/streaming/shuffle.py

Lines changed: 95 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -97,55 +97,108 @@ def get_chunks_and_intervals_per_ranks(self, distributed_env: _DistributedEnv, c
9797
# 2. Shuffle them
9898
indexes = range(len(chunk_intervals))
9999

100-
# FIXME: Shuffling should be done only within the nodes to benefit
101-
# from cache if the dataset doesn't fit on the node.
102-
shuffled_indexes = np.random.RandomState(seed=self.seed + current_epoch).permutation(indexes)
103-
shuffled_chunk_intervals = np.asarray(chunk_intervals)[shuffled_indexes]
100+
# If we have multiple nodes, the seed_shift is constant here.
101+
# Here is why. When you are running epoch 1, we need to shuffle the chunks
102+
# and associate them with each rank. This is done here.
103+
# When you are running epoch 2 or more, we need to keep the same shuffling
104+
# as in epoch 1 because we shuffle a second time within the node.
105+
# This is done slightly further down in this function.
106+
seed_shift = 1 if distributed_env.num_nodes > 1 else current_epoch
107+
shuffled_indexes = np.random.RandomState(seed=self.seed + seed_shift).permutation(indexes)
108+
shuffled_chunk_intervals = np.asarray(chunk_intervals)[shuffled_indexes].tolist()
104109

105110
# 3. Compute the items budget of each rank
106-
num_items = sum([(interval[-1] - interval[0]) for interval in chunk_intervals])
107-
num_items_per_ranks: List[int] = [
108-
num_items // distributed_env.world_size + num_items % distributed_env.world_size
109-
if rank == distributed_env.world_size - 1 and not self.drop_last
110-
else num_items // distributed_env.world_size
111-
for rank in range(distributed_env.world_size)
112-
]
113-
chunks_per_ranks: List[List[int]] = [[] for _ in range(distributed_env.world_size)]
114-
intervals_per_ranks: List[List[List[int]]] = [[] for _ in range(distributed_env.world_size)]
115-
116-
# 4. Assign the chunk & intervals to each rank
117-
for chunk_index, chunk_interval in zip(shuffled_indexes, shuffled_chunk_intervals):
118-
rank = 0
119-
120-
while True:
121-
if rank == len(num_items_per_ranks):
122-
break
123-
124-
items_left_to_assign = num_items_per_ranks[rank]
111+
chunks_per_ranks, intervals_per_ranks = _associate_chunks_and_internals_to_ranks(
112+
distributed_env, shuffled_indexes, shuffled_chunk_intervals, self.drop_last
113+
)
125114

126-
if items_left_to_assign == 0:
127-
rank += 1
128-
continue
115+
# For the first epoch, no need of further shuffling
116+
if current_epoch == 1 or distributed_env.num_nodes == 1:
117+
return chunks_per_ranks, intervals_per_ranks
129118

130-
items_in_chunk = chunk_interval[-1] - chunk_interval[0]
119+
# Perform shuffle within the nodes to avoid cache miss.
120+
# Note: It is possible for the overlapping chunks to change due to the changing order.
121+
shuffled_indexes = _intra_node_chunk_shuffle(distributed_env, chunks_per_ranks, self.seed, current_epoch)
122+
shuffled_chunk_intervals = np.asarray(chunk_intervals)[shuffled_indexes].tolist()
131123

132-
if items_in_chunk == 0:
133-
break
134-
135-
if items_in_chunk > items_left_to_assign:
136-
chunks_per_ranks[rank].append(chunk_index)
137-
begin, end = chunk_interval
138-
intervals_per_ranks[rank].append([begin, begin + items_left_to_assign])
139-
chunk_interval = (begin + items_left_to_assign, end)
140-
num_items_per_ranks[rank] = 0
141-
rank += 1
142-
else:
143-
chunks_per_ranks[rank].append(chunk_index)
144-
intervals_per_ranks[rank].append(chunk_interval)
145-
num_items_per_ranks[rank] -= items_in_chunk
146-
break
124+
chunks_per_ranks, intervals_per_ranks = _associate_chunks_and_internals_to_ranks(
125+
distributed_env, shuffled_indexes, shuffled_chunk_intervals, self.drop_last
126+
)
147127

148128
return chunks_per_ranks, intervals_per_ranks
149129

150130
def __call__(self, array: np.ndarray, num_chunks: int, current_epoch: int, chunk_index: int) -> List[int]:
151131
return np.random.RandomState([self.seed, num_chunks * current_epoch, chunk_index]).permutation(array).tolist()
132+
133+
134+
def _intra_node_chunk_shuffle(
135+
distributed_env: _DistributedEnv,
136+
chunks_per_ranks: List[List[int]],
137+
seed: int,
138+
current_epoch: int,
139+
) -> List[int]:
140+
chunk_indexes_per_nodes: Any = [[] for _ in range(distributed_env.num_nodes)]
141+
for rank, chunks_per_rank in enumerate(chunks_per_ranks):
142+
chunk_indexes_per_nodes[0 if distributed_env.num_nodes == 1 else rank // distributed_env.num_nodes].extend(
143+
chunks_per_rank
144+
)
145+
146+
# shuffle the chunks associated to the node
147+
for i in range(len(chunk_indexes_per_nodes)):
148+
# permute the indexes within the node
149+
chunk_indexes_per_nodes[i] = np.random.RandomState(seed=seed + current_epoch).permutation(
150+
chunk_indexes_per_nodes[i]
151+
)
152+
153+
return [index for chunks in chunk_indexes_per_nodes for index in chunks]
154+
155+
156+
def _associate_chunks_and_internals_to_ranks(
157+
distributed_env: _DistributedEnv,
158+
indexes: Any,
159+
chunk_intervals: Any,
160+
drop_last: bool,
161+
) -> Tuple[List[List[int]], List[Any]]:
162+
num_items = sum([(interval[-1] - interval[0]) for interval in chunk_intervals])
163+
num_items_per_ranks: List[int] = [
164+
num_items // distributed_env.world_size + num_items % distributed_env.world_size
165+
if rank == distributed_env.world_size - 1 and not drop_last
166+
else num_items // distributed_env.world_size
167+
for rank in range(distributed_env.world_size)
168+
]
169+
chunks_per_ranks: List[List[int]] = [[] for _ in range(distributed_env.world_size)]
170+
intervals_per_ranks: List[List[List[int]]] = [[] for _ in range(distributed_env.world_size)]
171+
172+
# 4. Assign the chunk & intervals to each rank
173+
for chunk_index, chunk_interval in zip(indexes, chunk_intervals):
174+
rank = 0
175+
176+
while True:
177+
if rank == len(num_items_per_ranks):
178+
break
179+
180+
items_left_to_assign = num_items_per_ranks[rank]
181+
182+
if items_left_to_assign == 0:
183+
rank += 1
184+
continue
185+
186+
items_in_chunk = chunk_interval[-1] - chunk_interval[0]
187+
188+
if items_in_chunk == 0:
189+
break
190+
191+
if items_in_chunk > items_left_to_assign:
192+
chunks_per_ranks[rank].append(chunk_index)
193+
begin, end = chunk_interval
194+
intervals_per_ranks[rank].append([begin, begin + items_left_to_assign])
195+
chunk_interval = (begin + items_left_to_assign, end)
196+
num_items_per_ranks[rank] = 0
197+
rank += 1
198+
else:
199+
chunks_per_ranks[rank].append(chunk_index)
200+
intervals_per_ranks[rank].append(chunk_interval)
201+
num_items_per_ranks[rank] -= items_in_chunk
202+
break
203+
204+
return chunks_per_ranks, intervals_per_ranks

src/lightning/data/utilities/env.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,10 @@ class _DistributedEnv:
1313
1414
"""
1515

16-
def __init__(self, world_size: int, global_rank: int):
16+
def __init__(self, world_size: int, global_rank: int, num_nodes: int):
1717
self.world_size = world_size
1818
self.global_rank = global_rank
19+
self.num_nodes = num_nodes
1920

2021
@classmethod
2122
def detect(cls) -> "_DistributedEnv":
@@ -37,7 +38,14 @@ def detect(cls) -> "_DistributedEnv":
3738
if world_size is None or world_size == -1:
3839
world_size = 1
3940

40-
return cls(world_size=world_size, global_rank=global_rank)
41+
# TODO: Add support for other accelerators
42+
num_nodes = (world_size // torch.cuda.device_count()) if torch.cuda.is_available() else 1
43+
44+
if num_nodes > 1:
45+
# validate the world size is divisble by the number of GPUs
46+
assert world_size % torch.cuda.device_count() == 0
47+
48+
return cls(world_size=world_size, global_rank=global_rank, num_nodes=num_nodes)
4149

4250
def __repr__(self) -> str:
4351
return f"{self.__class__.__name__}(world_size: {self.world_size}, global_rank: {self.global_rank}\n)"
@@ -113,7 +121,8 @@ def from_args(
113121
the current training process
114122
115123
"""
116-
dist_env = _DistributedEnv(dist_world_size, global_rank)
124+
num_nodes = (dist_world_size // torch.cuda.device_count()) if torch.cuda.is_available() else 1
125+
dist_env = _DistributedEnv(dist_world_size, global_rank, num_nodes)
117126
worker_env = _WorkerEnv(num_workers, current_worker_rank)
118127
return cls(dist_env=dist_env, worker_env=worker_env)
119128

0 commit comments

Comments
 (0)