Skip to content

Commit 8280519

Browse files
authored
Data Processor: Add is_last argument to know when the last item for the current worker is being processed (#19383)
1 parent 5a0d2ef commit 8280519

File tree

3 files changed

46 additions and 43 deletions

3 files changed

46 additions and 43 deletions

src/lightning/data/streaming/data_processor.py

Lines changed: 4 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -608,7 +608,7 @@ def _handle_data_chunk_recipe_end(self) -> None:
608608
def _handle_data_transform_recipe(self, index: int) -> None:
609609
# Don't use a context manager to avoid deleting files that are being uploaded.
610610
output_dir = tempfile.mkdtemp()
611-
item_data = self.data_recipe.prepare_item(self.items[index], str(output_dir))
611+
item_data = self.data_recipe.prepare_item(self.items[index], str(output_dir), len(self.items) - 1 == index)
612612
if item_data is not None:
613613
raise ValueError(
614614
"When using a `DataTransformRecipe`, the `prepare_item` shouldn't return anything."
@@ -649,33 +649,9 @@ def prepare_structure(self, input_dir: Optional[str]) -> List[T]:
649649
pass
650650

651651
@abstractmethod
652-
def prepare_item(self, *args: Any) -> Any:
652+
def prepare_item(self, *args: Any, **kwargs: Any) -> Any:
653653
pass
654654

655-
def listdir(self, path: str) -> List[str]:
656-
home = _get_home_folder()
657-
filepath = os.path.join(home, ".cache", f"{self._name}/filepaths.txt")
658-
659-
if os.path.exists(filepath):
660-
lines = []
661-
with open(filepath) as f:
662-
for line in f.readlines():
663-
lines.append(line.replace("\n", ""))
664-
return lines
665-
666-
filepaths = []
667-
for dirpath, _, filenames in os.walk(path):
668-
for filename in filenames:
669-
filepaths.append(os.path.join(dirpath, filename))
670-
671-
os.makedirs(os.path.dirname(filepath), exist_ok=True)
672-
673-
with open(filepath, "w") as f:
674-
for filepath in filepaths:
675-
f.write(f"{filepath}\n")
676-
677-
return filepaths
678-
679655
def __init__(self) -> None:
680656
self._name: Optional[str] = None
681657

@@ -707,7 +683,7 @@ def prepare_structure(self, input_dir: Optional[str]) -> List[T]:
707683
"""
708684

709685
@abstractmethod
710-
def prepare_item(self, item_metadata: T) -> Any: # type: ignore
686+
def prepare_item(self, item_metadata: T) -> Any:
711687
"""The return of this `prepare_item` method is persisted in chunked binary files."""
712688

713689
def _done(self, size: int, delete_cached_files: bool, output_dir: Dir) -> _Result:
@@ -798,7 +774,7 @@ def prepare_structure(self, input_dir: Optional[str]) -> List[T]:
798774
"""
799775

800776
@abstractmethod
801-
def prepare_item(self, item_metadata: T, output_dir: str) -> None: # type: ignore
777+
def prepare_item(self, item_metadata: T, output_dir: str, is_last: bool) -> None:
802778
"""Use your item metadata to process your files and save the file outputs into `output_dir`."""
803779

804780

src/lightning/data/streaming/functions.py

Lines changed: 13 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -78,25 +78,28 @@ def __init__(self, fn: Callable[[str, Any], None], inputs: Sequence[Any]):
7878
_fn = self._fn if isinstance(self._fn, FunctionType) else self._fn.__call__ # type: ignore
7979
params = inspect.signature(_fn).parameters
8080
self._contains_device = "device" in params
81+
self._contains_is_last = "is_last" in params
8182

8283
def prepare_structure(self, _: Optional[str]) -> Any:
8384
return self._inputs
8485

85-
def prepare_item(self, item_metadata: Any, output_dir: str) -> None: # type: ignore
86+
def prepare_item(self, item_metadata: Any, output_dir: str, is_last: bool) -> None:
8687
if self._contains_device and self._device is None:
8788
self._find_device()
8889

90+
kwargs: Dict[str, Any] = {}
91+
92+
if self._contains_device:
93+
kwargs["device"] = self._device
94+
95+
if self._contains_is_last:
96+
kwargs["is_last"] = is_last
97+
8998
if isinstance(self._fn, (FunctionType, partial)):
90-
if self._contains_device:
91-
self._fn(item_metadata, output_dir, self._device)
92-
else:
93-
self._fn(item_metadata, output_dir)
99+
self._fn(item_metadata, output_dir, **kwargs)
94100

95101
elif callable(self._fn):
96-
if self._contains_device:
97-
self._fn.__call__(item_metadata, output_dir, self._device) # type: ignore
98-
else:
99-
self._fn.__call__(item_metadata, output_dir) # type: ignore
102+
self._fn.__call__(item_metadata, output_dir, **kwargs) # type: ignore
100103
else:
101104
raise ValueError(f"The provided {self._fn} isn't supported.")
102105

@@ -124,7 +127,7 @@ def __init__(
124127
def prepare_structure(self, input_dir: Optional[str]) -> Any:
125128
return self._inputs
126129

127-
def prepare_item(self, item_metadata: Any) -> Any: # type: ignore
130+
def prepare_item(self, item_metadata: Any) -> Any:
128131
if isinstance(self._fn, partial):
129132
yield from self._fn(item_metadata)
130133

tests/tests_data/streaming/test_data_processor.py

Lines changed: 29 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -350,7 +350,7 @@ def test_map_items_to_workers_sequentially(monkeypatch):
350350

351351
class CustomDataChunkRecipe(DataChunkRecipe):
352352
def prepare_structure(self, input_dir: str) -> List[Any]:
353-
filepaths = self.listdir(input_dir)
353+
filepaths = [os.path.join(input_dir, f) for f in os.listdir(input_dir)]
354354
assert len(filepaths) == 30
355355
return filepaths
356356

@@ -567,7 +567,7 @@ def prepare_structure(self, input_dir: str):
567567
filepaths = [os.path.join(input_dir, filename) for filename in os.listdir(input_dir)]
568568
return [filepath for filepath in filepaths if os.path.isfile(filepath)]
569569

570-
def prepare_item(self, filepath: Any, output_dir: str) -> None:
570+
def prepare_item(self, filepath: Any, output_dir: str, is_last) -> None:
571571
from PIL import Image
572572

573573
img = Image.open(filepath)
@@ -819,7 +819,7 @@ def fn(output_dir, item, device):
819819

820820
data_recipe = LambdaDataTransformRecipe(fn, range(1))
821821

822-
data_recipe.prepare_item(1, "")
822+
data_recipe.prepare_item(1, "", False)
823823
assert called
824824

825825

@@ -839,7 +839,7 @@ def __call__(self, item, output_dir, device):
839839
called = True
840840

841841
data_recipe = LambdaDataTransformRecipe(Transform(), range(1))
842-
data_recipe.prepare_item(1, "")
842+
data_recipe.prepare_item(1, "", False)
843843
assert called
844844

845845

@@ -968,7 +968,7 @@ def test_data_processing_map_non_absolute_path(monkeypatch, tmpdir):
968968

969969

970970
@pytest.mark.skipif(condition=sys.platform == "win32", reason="Not supported on windows")
971-
def test_map_error_when_not_empty(monkeypatch, tmpdir):
971+
def test_map_error_when_not_empty(monkeypatch):
972972
boto3 = mock.MagicMock()
973973
client_s3_mock = mock.MagicMock()
974974
client_s3_mock.list_objects_v2.return_value = {"KeyCount": 1, "Contents": []}
@@ -992,3 +992,27 @@ def test_map_error_when_not_empty(monkeypatch, tmpdir):
992992
output_dir=Dir(path=None, url="s3://bucket"),
993993
error_when_not_empty=False,
994994
)
995+
996+
def map_fn_is_last(index, output_dir, is_last):
997+
with open(os.path.join(output_dir, f"{index}_{is_last}.txt"), "w") as f:
998+
f.write("here")
999+
1000+
1001+
@pytest.mark.skipif(condition=sys.platform == "win32", reason="Not supported on windows")
1002+
@pytest.mark.parametrize(
1003+
("num_workers", "expected"),
1004+
[
1005+
(1, ['0_False.txt', '1_False.txt', '2_False.txt', '3_False.txt', '4_True.txt']),
1006+
(2, ['0_False.txt', '1_True.txt', '2_False.txt', '3_False.txt', '4_True.txt']),
1007+
],
1008+
)
1009+
def test_map_is_last(num_workers, expected, tmpdir):
1010+
map(
1011+
map_fn_is_last,
1012+
list(range(5)),
1013+
output_dir=str(tmpdir),
1014+
error_when_not_empty=False,
1015+
num_workers=num_workers,
1016+
)
1017+
1018+
assert sorted(os.listdir(tmpdir)) == expected

0 commit comments

Comments
 (0)