Commit 5a0d2ef

map operator: Add support for non absolute input_dir and output_dir (#19378)
1 parent 34a34a0 commit 5a0d2ef

File tree (8 files changed: +91, -48 lines)

.github/workflows/ci-tests-data.yml
src/lightning/data/streaming/constants.py
src/lightning/data/streaming/data_processor.py
src/lightning/data/streaming/functions.py
src/lightning/data/streaming/writer.py
src/lightning/pytorch/core/module.py
tests/tests_data/streaming/test_data_processor.py
tests/tests_data/streaming/test_writer.py
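In short: this commit lets the `map` operator accept relative `input_dir`/`output_dir` values. Candidate paths are resolved with `pathlib.Path.resolve()` before matching, the default cache root falls back to the system temp directory when not running inside a Lightning Studio (detected via the new `_IS_IN_STUDIO` constant), files are moved rather than copied into local output directories, and the dataset is only registered on the platform when the output points at remote storage. A minimal sketch of the newly supported call pattern, adapted from the test added in tests/tests_data/streaming/test_data_processor.py below (the `from lightning.data import map` import path is this sketch's assumption):

import os

from lightning.data import map  # assumed public import path for the map operator


def process(path: str, output_dir: str) -> None:
    # By the time a worker runs, `path` has been resolved to an absolute
    # path even though the caller passed a relative one.
    with open(os.path.join(output_dir, os.path.basename(path)), "w") as f:
        f.write("Hello World")


if __name__ == "__main__":
    for i in range(5):
        with open(f"{i}.txt", "w") as f:
            f.write("Hello World")

    # Relative inputs and a relative output_dir are now both accepted.
    map(process, [f"{i}.txt" for i in range(5)], output_dir="./output_dir", num_workers=1)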

.github/workflows/ci-tests-data.yml

Lines changed: 2 additions & 2 deletions

@@ -87,15 +87,15 @@ jobs:
         # ls -lh $PYPI_CACHE_DIR

       - name: Install package & dependencies
-        timeout-minutes: 30
+        timeout-minutes: 5
         run: |
           pip install -e ".[data-dev]" -U --prefer-binary -f ${TORCH_URL}
           pip list

       - name: Testing Data
         working-directory: tests/tests_data
         # NOTE: do not include coverage report here, see: https://github.com/nedbat/coveragepy/issues/1003
-        timeout-minutes: 10
+        timeout-minutes: 25
         run: |
           python -m coverage run --source lightning \
             -m pytest -v --timeout=60 --durations=60

src/lightning/data/streaming/constants.py

Lines changed: 1 addition & 0 deletions

@@ -57,3 +57,4 @@
 _NUMPY_DTYPES_MAPPING = {i: np.dtype(v) for i, v in enumerate(_NUMPY_SCTYPES)}

 _TIME_FORMAT = "%Y-%m-%d_%H-%M-%S.%fZ"
+_IS_IN_STUDIO = bool(os.getenv("LIGHTNING_CLOUD_PROJECT_ID", None)) and bool(os.getenv("LIGHTNING_CLUSTER_ID", None))

src/lightning/data/streaming/data_processor.py

Lines changed: 32 additions & 15 deletions
@@ -10,6 +10,7 @@
 from abc import abstractmethod
 from dataclasses import dataclass
 from multiprocessing import Process, Queue
+from pathlib import Path
 from queue import Empty
 from time import sleep, time
 from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union
@@ -25,6 +26,7 @@
     _BOTO3_AVAILABLE,
     _DEFAULT_FAST_DEV_RUN_ITEMS,
     _INDEX_FILENAME,
+    _IS_IN_STUDIO,
     _LIGHTNING_CLOUD_LATEST,
     _TORCH_GREATER_EQUAL_2_1_0,
 )
@@ -66,17 +68,21 @@ def _get_home_folder() -> str:
     return os.getenv("DATA_OPTIMIZER_HOME_FOLDER", os.path.expanduser("~"))


+def _get_default_cache() -> str:
+    return "/cache" if _IS_IN_STUDIO else tempfile.gettempdir()
+
+
 def _get_cache_dir(name: Optional[str] = None) -> str:
     """Returns the cache directory used by the Cache to store the chunks."""
-    cache_dir = os.getenv("DATA_OPTIMIZER_CACHE_FOLDER", "/cache/chunks")
+    cache_dir = os.getenv("DATA_OPTIMIZER_CACHE_FOLDER", f"{_get_default_cache()}/chunks")
     if name is None:
         return cache_dir
     return os.path.join(cache_dir, name.lstrip("/"))


 def _get_cache_data_dir(name: Optional[str] = None) -> str:
     """Returns the cache data directory used by the DataProcessor workers to download the files."""
-    cache_dir = os.getenv("DATA_OPTIMIZER_DATA_CACHE_FOLDER", "/cache/data")
+    cache_dir = os.getenv("DATA_OPTIMIZER_DATA_CACHE_FOLDER", f"{_get_default_cache()}/data")
     if name is None:
         return os.path.join(cache_dir)
     return os.path.join(cache_dir, name.lstrip("/"))
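Taken together with the new `_IS_IN_STUDIO` constant, the default cache location now degrades gracefully outside a Studio. A standalone sketch of the fallback chain (mirrors the two functions above; `/tmp` is just a typical Linux temp dir):

import os
import tempfile

# Studio detection, as in constants.py: both env vars must be set and non-empty.
_IS_IN_STUDIO = bool(os.getenv("LIGHTNING_CLOUD_PROJECT_ID")) and bool(os.getenv("LIGHTNING_CLUSTER_ID"))


def _get_default_cache() -> str:
    return "/cache" if _IS_IN_STUDIO else tempfile.gettempdir()


def _get_cache_dir(name=None):
    # DATA_OPTIMIZER_CACHE_FOLDER still overrides the default entirely.
    cache_dir = os.getenv("DATA_OPTIMIZER_CACHE_FOLDER", f"{_get_default_cache()}/chunks")
    return cache_dir if name is None else os.path.join(cache_dir, name.lstrip("/"))


print(_get_cache_dir())         # e.g. /tmp/chunks outside a Studio
print(_get_cache_dir("my-ds"))  # e.g. /tmp/chunks/my-ds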
@@ -222,18 +228,20 @@ def _upload_fn(upload_queue: Queue, remove_queue: Queue, cache_dir: str, output_
                 )
             except Exception as e:
                 print(e)
-        elif output_dir.path and os.path.isdir(output_dir.path):
+
+        elif output_dir.path:
             if tmpdir is None:
-                shutil.copyfile(local_filepath, os.path.join(output_dir.path, os.path.basename(local_filepath)))
+                output_filepath = os.path.join(output_dir.path, os.path.basename(local_filepath))
             else:
                 output_filepath = os.path.join(output_dir.path, local_filepath.replace(tmpdir, "")[1:])
-                os.makedirs(os.path.dirname(output_filepath), exist_ok=True)
-                shutil.copyfile(local_filepath, output_filepath)
+
+            os.makedirs(os.path.dirname(output_filepath), exist_ok=True)
+            shutil.move(local_filepath, output_filepath)
         else:
             raise ValueError(f"The provided {output_dir.path} isn't supported.")

         # Inform the remover to delete the file
-        if remove_queue:
+        if remove_queue and os.path.exists(local_filepath):
             remove_queue.put([local_filepath])
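Two behavior changes in the hunk above: the destination directory tree is always created, and the file is moved (`shutil.move`) instead of copied, so the new `os.path.exists` guard keeps the remover from being handed a path the move already consumed. A condensed sketch of the local-output branch (`place_output` is an invented helper name for illustration):

import os
import shutil
from typing import Optional


def place_output(local_filepath: str, output_dir: str, tmpdir: Optional[str] = None) -> str:
    """Move a worker-produced file into output_dir, preserving any tmpdir-relative subpath."""
    if tmpdir is None:
        output_filepath = os.path.join(output_dir, os.path.basename(local_filepath))
    else:
        # Keep the path relative to tmpdir, e.g. /tmp/x/0/0.txt -> output_dir/0/0.txt
        output_filepath = os.path.join(output_dir, local_filepath.replace(tmpdir, "")[1:])

    os.makedirs(os.path.dirname(output_filepath), exist_ok=True)
    shutil.move(local_filepath, output_filepath)  # the source no longer exists afterwards
    return output_filepath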

@@ -290,7 +298,10 @@ def _get_num_bytes(item: Any, base_path: str) -> int:

     num_bytes = 0
     for element in flattened_item:
-        if isinstance(element, str) and element.startswith(base_path) and os.path.exists(element):
+        if isinstance(element, str):
+            element = Path(element).resolve()
+            if not element.exists():
+                continue
             file_bytes = os.path.getsize(element)
             if file_bytes == 0:
                 raise RuntimeError(f"The file {element} has 0 bytes!")
@@ -475,16 +486,22 @@ def _collect_paths(self) -> None:
         for item in self.items:
             flattened_item, spec = tree_flatten(item)

+            def is_path(element: Any) -> bool:
+                if not isinstance(element, str):
+                    return False
+
+                element: str = str(Path(element).resolve())
+                return (
+                    element.startswith(self.input_dir.path)
+                    if self.input_dir.path is not None
+                    else os.path.exists(element)
+                )
+
             # For speed reasons, we assume starting with `self.input_dir` is enough to be a real file.
             # Other alternative would be too slow.
             # TODO: Try using dictionary for higher accurary.
             indexed_paths = {
-                index: element
-                for index, element in enumerate(flattened_item)
-                if isinstance(element, str)
-                and (
-                    element.startswith(self.input_dir.path) if self.input_dir is not None else os.path.exists(element)
-                )  # For speed reasons
+                index: str(Path(element).resolve()) for index, element in enumerate(flattened_item) if is_path(element)
             }

             if len(indexed_paths) == 0:
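The new `is_path` predicate is the heart of the relative-path support: every candidate is resolved before being matched against `input_dir` (or, failing that, checked for existence). A standalone version of the same logic (written as a free function here; in the diff it is a closure over `self.input_dir`):

import os
from pathlib import Path
from typing import Any, Optional


def is_path(element: Any, input_dir: Optional[str]) -> bool:
    if not isinstance(element, str):
        return False
    # Resolve first, so a relative "0.txt" becomes "/current/dir/0.txt"
    # before the prefix match or existence check.
    resolved = str(Path(element).resolve())
    return resolved.startswith(input_dir) if input_dir is not None else os.path.exists(resolved)


assert is_path(__file__, None)  # an existing path, no input_dir configured
assert not is_path(42, None)    # non-strings are never treated as paths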
@@ -947,7 +964,7 @@ def run(self, data_recipe: DataRecipe) -> None:
         print("Workers are finished.")
         result = data_recipe._done(len(user_items), self.delete_cached_files, self.output_dir)

-        if num_nodes == node_rank + 1:
+        if num_nodes == node_rank + 1 and self.output_dir.url:
             _create_dataset(
                 input_dir=self.input_dir.path,
                 storage_dir=self.output_dir.path,

src/lightning/data/streaming/functions.py

Lines changed: 0 additions & 3 deletions

@@ -65,9 +65,6 @@ def _get_input_dir(inputs: Sequence[Any]) -> Optional[str]:
     if "/.project" in absolute_path:
         return "/" + os.path.join(*str(list(indexed_paths.values())[0]).split("/")[:4])

-    if indexed_paths[0] != absolute_path:
-        raise ValueError(f"The provided path should be absolute. Found {indexed_paths[0]} instead of {absolute_path}.")
-
     return "/" + os.path.join(*str(absolute_path).split("/")[:4])
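Since inputs are now resolved before they reach `_get_input_dir`, the strict absolute-path check became dead weight and is removed; what remains is the root derivation, sketched here (`input_root` is a hypothetical name for the inlined expression):

import os
from pathlib import Path


def input_root(path: str) -> str:
    # Resolve, then keep the first three directory components
    # (split("/")[:4] includes the empty string before the leading "/").
    absolute_path = str(Path(path).resolve())
    return "/" + os.path.join(*absolute_path.split("/")[:4])


print(input_root("data/file.txt"))  # e.g. /home/user/project, depending on the cwd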

src/lightning/data/streaming/writer.py

Lines changed: 1 addition & 0 deletions

@@ -81,6 +81,7 @@ def __init__(
         if self._compression:
             if len(_COMPRESSORS) == 0:
                 raise ValueError("No compresion algorithms are installed.")
+
             if self._compression not in _COMPRESSORS:
                 raise ValueError(
                     f"The provided compression {self._compression} isn't available in {sorted(_COMPRESSORS)}"

src/lightning/pytorch/core/module.py

Lines changed: 5 additions & 9 deletions

@@ -750,13 +750,11 @@ def validation_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
     .. code-block:: python

         # if you have one val dataloader:
-        def validation_step(self, batch, batch_idx):
-            ...
+        def validation_step(self, batch, batch_idx): ...


         # if you have multiple val dataloaders:
-        def validation_step(self, batch, batch_idx, dataloader_idx=0):
-            ...
+        def validation_step(self, batch, batch_idx, dataloader_idx=0): ...

     Examples::

@@ -819,13 +817,11 @@ def test_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
     .. code-block:: python

         # if you have one test dataloader:
-        def test_step(self, batch, batch_idx):
-            ...
+        def test_step(self, batch, batch_idx): ...


         # if you have multiple test dataloaders:
-        def test_step(self, batch, batch_idx, dataloader_idx=0):
-            ...
+        def test_step(self, batch, batch_idx, dataloader_idx=0): ...

     Examples::

@@ -989,7 +985,7 @@ def configure_optimizers(self):
                 "lr_scheduler": {
                     "scheduler": ReduceLROnPlateau(optimizer, ...),
                     "monitor": "metric_to_track",
-                    "frequency": "indicates how often the metric is updated"
+                    "frequency": "indicates how often the metric is updated",
                     # If "monitor" references validation metrics, then "frequency" should be set to a
                     # multiple of "trainer.check_val_every_n_epoch".
                 },
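The trailing comma makes the docstring's `lr_scheduler` dict a valid Python literal. For context, a runnable version of the documented pattern (the metric name and hyperparameters are placeholders):

import lightning as L
from torch import nn
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau


class LitModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(8, 1)

    def configure_optimizers(self):
        optimizer = Adam(self.parameters(), lr=1e-3)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": ReduceLROnPlateau(optimizer, mode="min"),
                "monitor": "val_loss",  # a metric this module logs
                "frequency": 1,  # how often the monitored metric is re-checked
            },
        }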

tests/tests_data/streaming/test_data_processor.py

Lines changed: 35 additions & 17 deletions

@@ -2,6 +2,7 @@
 import random
 import sys
 from functools import partial
+from pathlib import Path
 from typing import Any, List
 from unittest import mock

@@ -502,7 +503,7 @@ def test_data_processsor_distributed(fast_dev_run, delete_cached_files, tmpdir,
         "chunk-1-3.bin",
     ]

-    assert sorted(os.listdir(cache_dir)) == fast_dev_run_disabled_chunks_0
+    assert sorted(os.listdir(remote_output_dir)) == fast_dev_run_disabled_chunks_0

     cache_dir = os.path.join(tmpdir, "cache_2")
     monkeypatch.setenv("DATA_OPTIMIZER_CACHE_FOLDER", cache_dir)

@@ -531,26 +532,11 @@ def test_data_processsor_distributed(fast_dev_run, delete_cached_files, tmpdir,
         "index.json",
     ]

-    assert sorted(os.listdir(cache_dir)) == fast_dev_run_disabled_chunks_1
-
     expected = sorted(fast_dev_run_disabled_chunks_0 + fast_dev_run_disabled_chunks_1 + ["1-index.json"])

     assert sorted(os.listdir(remote_output_dir)) == expected

-    _create_dataset_mock.assert_called()
-
-    assert _create_dataset_mock._mock_mock_calls[0].kwargs == {
-        "input_dir": str(input_dir),
-        "storage_dir": str(remote_output_dir),
-        "dataset_type": "CHUNKED",
-        "empty": False,
-        "size": 30,
-        "num_bytes": 26657,
-        "data_format": "jpeg",
-        "compression": None,
-        "num_chunks": 16,
-        "num_bytes_per_chunk": [2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 1, 2, 2, 2, 2],
-    }
+    _create_dataset_mock.assert_not_called()


 class TextTokenizeRecipe(DataChunkRecipe):

@@ -951,6 +937,36 @@ def test_data_processing_map_without_input_dir_and_folder(monkeypatch, tmpdir):
     assert os.path.exists(os.path.join(output_dir, "0", "0.JPEG"))


+def map_fn_map_non_absolute(path, output_dir):
+    absolute_path = str(Path(path).absolute())
+    assert absolute_path == path, (absolute_path, path)
+
+    with open(os.path.join(output_dir, os.path.basename(path)), "w") as f:
+        f.write("Hello World")
+
+
+@pytest.mark.skipif(sys.platform == "win32", reason="not supported on windows")
+def test_data_processing_map_non_absolute_path(monkeypatch, tmpdir):
+    monkeypatch.chdir(str(tmpdir))
+
+    for i in range(5):
+        with open(f"./{i}.txt", "w") as f:
+            f.write("Hello World")
+
+    assert sorted(os.listdir(tmpdir)) == ["0.txt", "1.txt", "2.txt", "3.txt", "4.txt"]
+
+    map(
+        map_fn_map_non_absolute,
+        [f"{i}.txt" for i in range(5)],
+        output_dir="./output_dir",
+        num_workers=1,
+        reorder_files=True,
+    )
+
+    assert sorted(os.listdir(tmpdir)) == ["0.txt", "1.txt", "2.txt", "3.txt", "4.txt", "output_dir"]
+    assert sorted(os.listdir(os.path.join(tmpdir, "output_dir"))) == ["0.txt", "1.txt", "2.txt", "3.txt", "4.txt"]
+
+
 @pytest.mark.skipif(condition=sys.platform == "win32", reason="Not supported on windows")
 def test_map_error_when_not_empty(monkeypatch, tmpdir):
     boto3 = mock.MagicMock()

@@ -967,6 +983,8 @@ def test_map_error_when_not_empty(monkeypatch, tmpdir):
         error_when_not_empty=True,
     )

+    monkeypatch.setattr(data_processor_module, "_IS_IN_STUDIO", True)
+
     with pytest.raises(OSError, match="cache"):
         map(
             map_fn,

tests/tests_data/streaming/test_writer.py

Lines changed: 15 additions & 2 deletions

@@ -18,6 +18,7 @@
 import numpy as np
 import pytest
 from lightning import seed_everything
+from lightning.data.streaming.compression import _ZSTD_AVAILABLE
 from lightning.data.streaming.reader import BinaryReader
 from lightning.data.streaming.sampler import ChunkedIndex
 from lightning.data.streaming.writer import BinaryWriter

@@ -31,7 +32,13 @@ def test_binary_writer_with_ints_and_chunk_bytes(tmpdir):
     with pytest.raises(FileNotFoundError, match="The provided cache directory `dontexists` doesn't exist."):
         BinaryWriter("dontexists", {})

-    with pytest.raises(ValueError, match="No compresion algorithms are installed."):
+    match = (
+        "The provided compression something_else isn't available"
+        if _ZSTD_AVAILABLE
+        else "No compresion algorithms are installed."
+    )
+
+    with pytest.raises(ValueError, match=match):
         BinaryWriter(tmpdir, {"i": "int"}, compression="something_else")

     binary_writer = BinaryWriter(tmpdir, chunk_bytes=90)

@@ -69,7 +76,13 @@ def test_binary_writer_with_ints_and_chunk_size(tmpdir):
     with pytest.raises(FileNotFoundError, match="The provided cache directory `dontexists` doesn't exist."):
         BinaryWriter("dontexists", {})

-    with pytest.raises(ValueError, match="No compresion algorithms are installed."):
+    match = (
+        "The provided compression something_else isn't available"
+        if _ZSTD_AVAILABLE
+        else "No compresion algorithms are installed."
+    )
+
+    with pytest.raises(ValueError, match=match):
         BinaryWriter(tmpdir, {"i": "int"}, compression="something_else")

     binary_writer = BinaryWriter(tmpdir, chunk_size=25)
