
Commit c148282

tchaton, awaelchli, thomas, and pre-commit-ci[bot] authored and committed
Prevent leaking the thread to the workers (#18891)
Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: thomas <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
(cherry picked from commit 2526c90)
1 parent b2a8ddd commit c148282

File tree

4 files changed: +51, −4 lines


src/lightning/data/streaming/downloader.py

Lines changed: 10 additions & 2 deletions
@@ -11,6 +11,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+import shutil
 from abc import ABC, abstractmethod
 from typing import Any, Dict, List, Type
 from urllib import parse
@@ -63,8 +64,15 @@ def download_file(cls, remote_filepath: str, local_filepath: str) -> None:
         )


-# TODO: Add fsspec support
-_DOWNLOADERS = {"s3://": S3Downloader}
+class LocalDownloader(Downloader):
+    @classmethod
+    def download_file(cls, remote_filepath: str, local_filepath: str) -> None:
+        if not os.path.exists(remote_filepath):
+            raise FileNotFoundError(f"The provided remote_filepath doesn't exist: {remote_filepath}")
+        shutil.copy(remote_filepath, local_filepath)
+
+
+_DOWNLOADERS = {"s3://": S3Downloader, "": LocalDownloader}


 def get_downloader_cls(remote_dir: str) -> Type[Downloader]:
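For context, a quick usage sketch of the new local fallback (not part of the commit; the paths are made up, and the prefix matching inside get_downloader_cls is assumed from the "" entry in _DOWNLOADERS):

```python
from lightning.data.streaming.downloader import LocalDownloader, get_downloader_cls

# A remote_dir without an "s3://" prefix is assumed to fall through to the "" entry
# and resolve to LocalDownloader, which simply copies the file on the local filesystem.
downloader_cls = get_downloader_cls("/data/remote_dir")  # hypothetical local "remote" directory
assert downloader_cls is LocalDownloader

downloader_cls.download_file(
    "/data/remote_dir/chunk-0.bin",  # hypothetical source chunk
    "/tmp/cache/chunk-0.bin",        # hypothetical cache destination
)
```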

src/lightning/data/streaming/item_loader.py

Lines changed: 7 additions & 2 deletions
@@ -51,6 +51,9 @@ def load_item_from_chunk(self, index: int, chunk_index: int, chunk_filepath: str
 class PyTreeLoader(BaseItemLoader):
     """The Pytree Loader is the default loader of the Cache object."""

+    def __init__(self) -> None:
+        self._chunk_filepaths: Dict[str, bool] = {}
+
     def generate_intervals(self) -> List[Tuple[int, int]]:
         intervals = []
         begin = 0
@@ -64,8 +67,10 @@ def generate_intervals(self) -> List[Tuple[int, int]]:
     def load_item_from_chunk(self, index: int, chunk_index: int, chunk_filepath: str, begin: int) -> bytes:
         offset = (1 + (index - begin) if index >= begin else index + 1) * 4

-        while not os.path.exists(chunk_filepath):
-            sleep(0.0001)
+        if chunk_filepath not in self._chunk_filepaths:
+            while not os.path.exists(chunk_filepath):
+                sleep(0.001)
+            self._chunk_filepaths[chunk_filepath] = True

         with open(chunk_filepath, "rb", 0) as fp:
             fp.seek(offset)
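The change above memoizes which chunk files have already been seen, so the existence poll runs only on a chunk's first access instead of on every item. A standalone sketch of that pattern (illustrative names only, not Lightning API):

```python
import os
from time import sleep
from typing import Dict

_seen_chunk_files: Dict[str, bool] = {}


def wait_for_chunk(path: str) -> None:
    # Poll for the file only the first time this chunk path is requested;
    # later calls assume the download has already completed.
    if path not in _seen_chunk_files:
        while not os.path.exists(path):
            sleep(0.001)
        _seen_chunk_files[path] = True
```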

src/lightning/data/streaming/reader.py

Lines changed: 5 additions & 0 deletions
@@ -172,3 +172,8 @@ def get_chunk_intervals(self) -> List[Tuple[int, int]]:
             raise Exception("The reader index isn't defined.")

         return self.config.intervals
+
+    def __getstate__(self) -> Dict[str, Any]:
+        state = self.__dict__.copy()
+        state["_prepare_thread"] = None
+        return state
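The __getstate__ hook above is the heart of the fix: a live threading.Thread can't be pickled or deep-copied, so the reader drops it from its serialized state and every copy (for example, a DataLoader worker) lazily starts its own thread. A minimal sketch of the pattern, assuming only that the reader keeps a _prepare_thread attribute (_ReaderLike below is illustrative, not Lightning code):

```python
import copy
import threading
from typing import Any, Dict, Optional


class _ReaderLike:
    """Illustrative stand-in for the reader; only the thread handling is mirrored."""

    def __init__(self) -> None:
        self._prepare_thread: Optional[threading.Thread] = None

    def read(self) -> None:
        # Lazily start a background thread on first use, like the real reader.
        if self._prepare_thread is None:
            self._prepare_thread = threading.Thread(target=lambda: None, daemon=True)
            self._prepare_thread.start()

    def __getstate__(self) -> Dict[str, Any]:
        # Same trick as the commit: never serialize the live thread.
        state = self.__dict__.copy()
        state["_prepare_thread"] = None
        return state


reader = _ReaderLike()
reader.read()                        # the parent process now owns a live thread
worker_copy = copy.deepcopy(reader)  # without __getstate__, this would trip over the thread's lock
assert worker_copy._prepare_thread is None
```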

tests/tests_data/streaming/test_dataset.py

Lines changed: 29 additions & 0 deletions
@@ -206,3 +206,32 @@ def test_streaming_dataset_distributed_full_shuffle_even(drop_last, tmpdir):
     assert len(process_2_1) == 611

     assert len([i for i in process_1_1 if i in process_2_1]) == 0
+
+
+def test_streaming_dataset_deepcopy(tmpdir, monkeypatch):
+    seed_everything(42)
+
+    remote_dir = os.path.join(tmpdir, "remote_dir")
+
+    os.makedirs(remote_dir, exist_ok=True)
+
+    cache = Cache(remote_dir, chunk_size=10)
+    for i in range(10):
+        cache[i] = i
+
+    cache.done()
+    cache.merge()
+
+    monkeypatch.setattr(cache_module, "_find_remote_dir", lambda x, y: (str(remote_dir), True))
+
+    dataset = StreamingDataset(name="choco", cache_dir=tmpdir, shuffle=True)
+    assert dataset.cache._reader._prepare_thread is None
+    _ = dataset[0]
+    assert dataset.cache._reader._prepare_thread
+    dataloader = DataLoader(dataset, num_workers=1)
+
+    batches = []
+    for batch in dataloader:
+        batches.append(batch)
+
+    assert len(batches) == 10
