
Commit 672e730

tchaton authored and thomas committed
Add GPU support for map (#18947)
Co-authored-by: thomas <[email protected]>
(cherry picked from commit 97c730e)
1 parent 6df6162 · commit 672e730

File tree: 2 files changed, +70 -2 lines changed

src/lightning/data/streaming/functions.py

Lines changed: 28 additions & 1 deletion
@@ -18,6 +18,8 @@
 from types import FunctionType
 from typing import Any, Callable, Optional, Sequence, Union
 
+import torch
+
 from lightning.data.streaming.constants import _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_50, _TORCH_GREATER_EQUAL_2_1_0
 from lightning.data.streaming.data_processor import DataChunkRecipe, DataProcessor, DataTransformRecipe
 
@@ -53,12 +55,37 @@ def __init__(self, fn: Callable[[str, Any], None], inputs: Sequence[Any]):
         super().__init__()
         self._fn = fn
         self._inputs = inputs
+        self._device: Optional[str] = None
+
+        _fn = self._fn if isinstance(self._fn, FunctionType) else self._fn.__call__  # type: ignore
+        params = inspect.signature(_fn).parameters
+        self._contains_device = "device" in params
 
     def prepare_structure(self, input_dir: Optional[str]) -> Any:
         return self._inputs
 
     def prepare_item(self, output_dir: str, item_metadata: Any) -> None:  # type: ignore
-        self._fn(output_dir, item_metadata)
+        if self._contains_device and self._device is None:
+            self._find_device()
+        if isinstance(self._fn, FunctionType):
+            if self._contains_device:
+                self._fn(output_dir, item_metadata, self._device)
+            else:
+                self._fn(output_dir, item_metadata)
+        elif callable(self._fn):
+            if self._contains_device:
+                self._fn.__call__(output_dir, item_metadata, self._device)  # type: ignore
+            else:
+                self._fn.__call__(output_dir, item_metadata)  # type: ignore
+        else:
+            raise ValueError(f"The provided {self._fn} isn't supported.")
+
+    def _find_device(self) -> None:
+        global_rank = os.getenv("DATA_OPTIMIZER_GLOBAL_RANK", None)
+        if torch.cuda.is_available() and global_rank:
+            num_gpus = torch.cuda.device_count()
+            device = int(global_rank) % num_gpus
+            self._device = f"cuda:{device}"
 
 
 class LambdaDataChunkRecipe(DataChunkRecipe):
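
In user terms, this change means a transform handed to map (and wrapped in LambdaDataTransformRecipe) can opt in to GPU execution simply by declaring a third `device` parameter: the recipe inspects the function signature and, when CUDA is available, resolves a device from DATA_OPTIMIZER_GLOBAL_RANK and passes it in. Below is a minimal sketch of such a transform; the embed_file function is illustrative only, and the exact map() keyword names are an assumption based on this module, not part of the commit.

import os

import torch

from lightning.data.streaming.functions import map as ldata_map


def embed_file(output_dir: str, filepath: str, device: str) -> None:
    # Because `device` appears in the signature, LambdaDataTransformRecipe fills it
    # (e.g. "cuda:0") when CUDA is available; otherwise it stays None and .to(None) is a no-op.
    model = torch.nn.Linear(8, 8).to(device)   # stand-in for a real model
    batch = torch.rand(1, 8, device=device)    # stand-in for the file's actual contents
    out = model(batch).detach().cpu()
    torch.save(out, os.path.join(output_dir, os.path.basename(filepath) + ".pt"))


if __name__ == "__main__":
    # Assumed keyword names, mirroring this module's map(fn, inputs, output_dir, ...) signature.
    ldata_map(embed_file, inputs=["a.bin", "b.bin"], output_dir="/tmp/embeddings")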

tests/tests_data/streaming/test_data_processor.py

Lines changed: 42 additions & 1 deletion
@@ -20,7 +20,7 @@
     _upload_fn,
     _wait_for_file_to_exist,
 )
-from lightning.data.streaming.functions import map, optimize
+from lightning.data.streaming.functions import LambdaDataTransformRecipe, map, optimize
 from lightning_utilities.core.imports import RequirementCache
 
 _PIL_AVAILABLE = RequirementCache("PIL")
@@ -766,3 +766,44 @@ def test_data_processing_optimize_class_yield(monkeypatch, tmpdir):
 
     cache = Cache(output_dir, chunk_size=1)
     assert len(cache) == 5
+
+
+def test_lambda_transform_recipe(monkeypatch):
+    torch_mock = mock.MagicMock()
+    torch_mock.cuda.device_count.return_value = 3
+
+    monkeypatch.setattr(functions, "torch", torch_mock)
+    monkeypatch.setenv("DATA_OPTIMIZER_GLOBAL_RANK", 2)
+
+    called = False
+
+    def fn(output_dir, item, device):
+        nonlocal called
+        assert device == "cuda:2"
+        called = True
+
+    data_recipe = LambdaDataTransformRecipe(fn, range(1))
+
+    data_recipe.prepare_item("", 1)
+    assert called
+
+
+def test_lambda_transform_recipe_class(monkeypatch):
+    torch_mock = mock.MagicMock()
+    torch_mock.cuda.device_count.return_value = 3
+
+    monkeypatch.setattr(functions, "torch", torch_mock)
+    monkeypatch.setenv("DATA_OPTIMIZER_GLOBAL_RANK", 2)
+
+    called = False
+
+    class Transform:
+        def __call__(self, output_dir, item, device):
+            nonlocal called
+            assert device == "cuda:2"
+            called = True
+
+    data_recipe = LambdaDataTransformRecipe(Transform(), range(1))
+
+    data_recipe.prepare_item("", 1)
+    assert called
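
The device index both tests assert follows directly from _find_device(): the worker's DATA_OPTIMIZER_GLOBAL_RANK modulo the number of visible GPUs. A stand-alone restatement of that rule (expected_device is a hypothetical helper, not part of the codebase):

def expected_device(global_rank: int, num_gpus: int) -> str:
    # Mirrors LambdaDataTransformRecipe._find_device(): ranks wrap around the GPUs.
    return f"cuda:{global_rank % num_gpus}"


assert expected_device(2, 3) == "cuda:2"  # the case asserted in both tests above
assert expected_device(5, 3) == "cuda:2"  # higher ranks wrap around the 3 GPUs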
