Skip to content

Commit 08c9e51

Browse files
tchatonthomas
andauthored
Resolve path for StreamingDataset (#19094)
Co-authored-by: thomas <[email protected]>
1 parent bf7e54d commit 08c9e51

File tree

2 files changed

+17
-2
lines changed

2 files changed

+17
-2
lines changed

src/lightning/data/streaming/dataset.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def __init__(
110110
def _create_cache(self, worker_env: _WorkerEnv) -> Cache:
111111
env = Environment(dist_env=self.distributed_env, worker_env=worker_env)
112112

113-
if self.input_dir.path is None:
113+
if _should_replace_path(self.input_dir.path):
114114
cache_path = _try_create_cache_dir(
115115
input_dir=self.input_dir.path if self.input_dir.path else self.input_dir.url, shard_rank=env.shard_rank
116116
)
@@ -427,6 +427,13 @@ def _collect_distributed_state_dict(state_dict: Dict[str, Any], world_size: int)
427427
return state_dict_out
428428

429429

430+
def _should_replace_path(path: str) -> bool:
431+
if path is None or path == "":
432+
return True
433+
434+
return "/datasets/" in path or "_connections/" in path
435+
436+
430437
def _is_in_dataloader_worker() -> bool:
431438
return get_worker_info() is not None
432439

tests/tests_data/streaming/test_dataset.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from lightning import seed_everything
2525
from lightning.data.streaming import Cache, functions
2626
from lightning.data.streaming.constants import _TIME_FORMAT
27-
from lightning.data.streaming.dataset import StreamingDataset, _try_create_cache_dir
27+
from lightning.data.streaming.dataset import StreamingDataset, _should_replace_path, _try_create_cache_dir
2828
from lightning.data.streaming.item_loader import TokensLoader
2929
from lightning.data.streaming.shuffle import FullShuffle, NoShuffle
3030
from lightning.data.utilities.env import _DistributedEnv
@@ -59,6 +59,14 @@ def test_streaming_dataset(tmpdir, monkeypatch):
5959
assert len(dataloader) == 6
6060

6161

62+
def test_should_replace_path():
63+
assert _should_replace_path(None)
64+
assert _should_replace_path("")
65+
assert _should_replace_path(".../datasets/...")
66+
assert _should_replace_path(".../_connections/...")
67+
assert not _should_replace_path("something_else")
68+
69+
6270
@pytest.mark.parametrize("drop_last", [False, True])
6371
def test_streaming_dataset_distributed_no_shuffle(drop_last, tmpdir):
6472
seed_everything(42)

0 commit comments

Comments
 (0)