|
18 | 18 | from types import FunctionType
|
19 | 19 | from typing import Any, Callable, Optional, Sequence, Union
|
20 | 20 |
|
| 21 | +import torch |
| 22 | + |
21 | 23 | from lightning.data.streaming.constants import _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_50, _TORCH_GREATER_EQUAL_2_1_0
|
22 | 24 | from lightning.data.streaming.data_processor import DataChunkRecipe, DataProcessor, DataTransformRecipe
|
23 | 25 |
|
    def __init__(self, fn: Callable[[str, Any], None], inputs: Sequence[Any]):
        """Store the user callable and its inputs, and detect device-awareness.

        Args:
            fn: User-provided callable invoked per item. It may optionally accept
                a third ``device`` argument (detected below by signature inspection).
            inputs: Sequence of items the callable will be applied to.
        """
        super().__init__()
        self._fn = fn
        self._inputs = inputs
        # Resolved lazily on the first `prepare_item` call (see `_find_device`);
        # stays None when no GPU / global rank is available.
        self._device: Optional[str] = None

        # Inspect the callable's signature to learn whether it accepts a
        # `device` parameter. Plain functions are inspected directly; callable
        # objects are inspected through their `__call__` method.
        _fn = self._fn if isinstance(self._fn, FunctionType) else self._fn.__call__  # type: ignore
        params = inspect.signature(_fn).parameters
        self._contains_device = "device" in params
|
57 | 64 | def prepare_structure(self, input_dir: Optional[str]) -> Any:
|
58 | 65 | return self._inputs
|
59 | 66 |
|
60 | 67 | def prepare_item(self, output_dir: str, item_metadata: Any) -> None: # type: ignore
|
61 |
| - self._fn(output_dir, item_metadata) |
| 68 | + if self._contains_device and self._device is None: |
| 69 | + self._find_device() |
| 70 | + if isinstance(self._fn, FunctionType): |
| 71 | + if self._contains_device: |
| 72 | + self._fn(output_dir, item_metadata, self._device) |
| 73 | + else: |
| 74 | + self._fn(output_dir, item_metadata) |
| 75 | + elif callable(self._fn): |
| 76 | + if self._contains_device: |
| 77 | + self._fn.__call__(output_dir, item_metadata, self._device) # type: ignore |
| 78 | + else: |
| 79 | + self._fn.__call__(output_dir, item_metadata) # type: ignore |
| 80 | + else: |
| 81 | + raise ValueError(f"The provided {self._fn} isn't supported.") |
| 82 | + |
| 83 | + def _find_device(self) -> None: |
| 84 | + global_rank = os.getenv("DATA_OPTIMIZER_GLOBAL_RANK", None) |
| 85 | + if torch.cuda.is_available() and global_rank: |
| 86 | + num_gpus = torch.cuda.device_count() |
| 87 | + device = int(global_rank) % num_gpus |
| 88 | + self._device = f"cuda:{device}" |
62 | 89 |
|
63 | 90 |
|
64 | 91 | class LambdaDataChunkRecipe(DataChunkRecipe):
|
|
0 commit comments