
Commit eb0bbde

Add support for using the streaming dataloader in map or optimize for large scale inference (#19510)
1 parent 4175e1a commit eb0bbde

File tree

6 files changed: +108 lines, -19 lines
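
Before the per-file diffs, a quick usage sketch of what this commit enables: `map` (and `optimize`) now accept a StreamingDataLoader as `inputs`, so large-scale inference can stream batches from an optimized dataset instead of materializing a list of items. This is a minimal sketch closely modeled on the test added in this commit; the dataset location, output folder, and `predict` function are placeholders, and the samples are assumed to be integer indices as in that test.

import os

from lightning.data.processing.functions import map
from lightning.data.streaming import StreamingDataLoader, StreamingDataset


def predict(batch, output_dir):
    # `batch` is whatever the StreamingDataLoader yields; write one file per sample.
    for index in batch.numpy().tolist():
        with open(os.path.join(output_dir, f"{index}.txt"), "w") as f:
            f.write("prediction")


dataset = StreamingDataset(input_dir="s3://my-bucket/optimized-dataset")  # hypothetical location
dataloader = StreamingDataLoader(dataset, batch_size=2, num_workers=1)    # batching is configured here

map(
    fn=predict,
    inputs=dataloader,               # a StreamingDataLoader is now a valid `inputs`
    output_dir="/data/predictions",  # hypothetical output folder
    num_workers=2,
)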

src/lightning/data/processing/data_processor.py

Lines changed: 6 additions & 7 deletions
@@ -28,11 +28,12 @@
     _LIGHTNING_CLOUD_LATEST,
     _TORCH_GREATER_EQUAL_2_1_0,
 )
-from lightning.data.processing.readers import BaseReader
+from lightning.data.processing.readers import BaseReader, StreamingDataLoaderReader
 from lightning.data.processing.utilities import _create_dataset
 from lightning.data.streaming import Cache
 from lightning.data.streaming.cache import Dir
 from lightning.data.streaming.client import S3Client
+from lightning.data.streaming.dataloader import StreamingDataLoader
 from lightning.data.streaming.resolver import _resolve_dir
 from lightning.data.utilities.broadcast import broadcast_object
 from lightning.data.utilities.packing import _pack_greedily
@@ -65,11 +66,6 @@ def _get_fast_dev_run() -> int:
     return bool(int(os.getenv("DATA_OPTIMIZER_FAST_DEV_RUN", 1)))
 
 
-def _get_home_folder() -> str:
-    """Returns whether cache folder for the filepaths."""
-    return os.getenv("DATA_OPTIMIZER_HOME_FOLDER", os.path.expanduser("~"))
-
-
 def _get_default_cache() -> str:
     return "/cache" if _IS_IN_STUDIO else tempfile.gettempdir()
 
@@ -892,9 +888,12 @@ def run(self, data_recipe: DataRecipe) -> None:
         # Call the setup method of the user
         user_items: List[Any] = data_recipe.prepare_structure(self.input_dir.path if self.input_dir else None)
 
-        if not isinstance(user_items, list):
+        if not isinstance(user_items, (list, StreamingDataLoader)):
             raise ValueError("The `prepare_structure` should return a list of item metadata.")
 
+        if isinstance(user_items, StreamingDataLoader):
+            self.reader = StreamingDataLoaderReader(user_items)
+
         if self.reader:
             user_items = self.reader.remap_items(user_items, self.num_workers)
 
src/lightning/data/processing/functions.py

Lines changed: 32 additions & 10 deletions
@@ -26,6 +26,7 @@
 from lightning.data.processing.data_processor import DataChunkRecipe, DataProcessor, DataTransformRecipe
 from lightning.data.processing.readers import BaseReader
 from lightning.data.processing.utilities import optimize_dns_context
+from lightning.data.streaming.dataloader import StreamingDataLoader
 from lightning.data.streaming.resolver import (
     Dir,
     _assert_dir_has_index_file,
@@ -176,6 +177,7 @@ def map(
         inputs: A sequence of input to be processed by the `fn` function.
             Each input should contain at least a valid filepath.
         output_dir: The folder where the processed data should be stored.
+        weights: Provide an associated weight to each input. This is used to balance work among workers.
         num_workers: The number of workers to use during processing
         fast_dev_run: Whether to use process only a sub part of the inputs
         num_nodes: When doing remote execution, the number of nodes to use. Only supported on https://lightning.ai/.
@@ -188,8 +190,14 @@ def map(
         batch_size: Group the inputs into batches of batch_size length.
 
     """
-    if not isinstance(inputs, Sequence):
-        raise ValueError(f"The provided inputs should be non empty sequence. Found {inputs}.")
+    if isinstance(inputs, StreamingDataLoader) and batch_size is not None:
+        raise ValueError("When providing a streaming dataloader, pass the batch_size to the dataloader directly.")
+
+    if isinstance(inputs, StreamingDataLoader) and weights is not None:
+        raise ValueError("When providing a streaming dataloader, weights isn't supported.")
+
+    if not isinstance(inputs, (Sequence, StreamingDataLoader)):
+        raise ValueError(f"The provided inputs should be non empty sequence or a streaming dataloader. Found {inputs}.")
 
     if len(inputs) == 0:
         raise ValueError(f"The provided inputs should be non empty. Found {inputs}.")
@@ -218,10 +226,13 @@ def map(
         if error_when_not_empty:
             _assert_dir_is_empty(_output_dir)
 
-        input_dir = _resolve_dir(_get_input_dir(inputs))
+        if not isinstance(inputs, StreamingDataLoader):
+            input_dir = _resolve_dir(_get_input_dir(inputs))
 
-        if isinstance(batch_size, int) and batch_size > 1:
-            inputs = [inputs[pos : pos + batch_size] for pos in range(0, len(inputs), batch_size)]
+            if isinstance(batch_size, int) and batch_size > 1:
+                inputs = [inputs[pos : pos + batch_size] for pos in range(0, len(inputs), batch_size)]
+        else:
+            input_dir = Dir()
 
         data_processor = DataProcessor(
             input_dir=input_dir,
@@ -247,6 +258,7 @@ def optimize(
     fn: Callable[[Any], Any],
     inputs: Sequence[Any],
     output_dir: str,
+    weights: Optional[List[int]] = None,
     chunk_size: Optional[int] = None,
     chunk_bytes: Optional[Union[int, str]] = None,
     compression: Optional[str] = None,
@@ -267,6 +279,7 @@ def optimize(
         inputs: A sequence of input to be processed by the `fn` function.
             Each input should contain at least a valid filepath.
         output_dir: The folder where the processed data should be stored.
+        weights: Provide an associated weight to each input. This is used to balance work among workers.
        chunk_size: The maximum number of elements to hold within a chunk.
         chunk_bytes: The maximum number of bytes to hold within a chunk.
         compression: The compression algorithm to use over the chunks.
@@ -281,8 +294,14 @@ def optimize(
         batch_size: Group the inputs into batches of batch_size length.
 
     """
-    if not isinstance(inputs, Sequence):
-        raise ValueError(f"The provided inputs should be non empty sequence. Found {inputs}.")
+    if isinstance(inputs, StreamingDataLoader) and batch_size is not None:
+        raise ValueError("When providing a streaming dataloader, pass the batch_size to the dataloader directly.")
+
+    if isinstance(inputs, StreamingDataLoader) and weights is not None:
+        raise ValueError("When providing a streaming dataloader, weights isn't supported.")
+
+    if not isinstance(inputs, (Sequence, StreamingDataLoader)):
+        raise ValueError(f"The provided inputs should be non empty sequence or a streaming dataloader. Found {inputs}.")
 
     if len(inputs) == 0:
         raise ValueError(f"The provided inputs should be non empty. Found {inputs}.")
@@ -313,10 +332,13 @@ def optimize(
 
         _assert_dir_has_index_file(_output_dir)
 
-        input_dir = _resolve_dir(_get_input_dir(inputs))
+        if not isinstance(inputs, StreamingDataLoader):
+            input_dir = _resolve_dir(_get_input_dir(inputs))
 
-        if isinstance(batch_size, int) and batch_size > 1:
-            inputs = [inputs[pos : pos + batch_size] for pos in range(0, len(inputs), batch_size)]
+            if isinstance(batch_size, int) and batch_size > 1:
+                inputs = [inputs[pos : pos + batch_size] for pos in range(0, len(inputs), batch_size)]
+        else:
+            input_dir = Dir()
 
         data_processor = DataProcessor(
             input_dir=input_dir,
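
A behavioral note on the validation added above: when `inputs` is a StreamingDataLoader, neither `batch_size` nor `weights` may also be passed to `map`/`optimize`, since batching is owned by the dataloader and work is split per batch. A minimal sketch with hypothetical paths and a placeholder processing function:

from lightning.data.processing.functions import map
from lightning.data.streaming import StreamingDataLoader, StreamingDataset


def process_batch(batch, output_dir):  # placeholder user function
    ...


dataloader = StreamingDataLoader(StreamingDataset(input_dir="/data/optimized"), batch_size=4)

# This would raise a ValueError: pass the batch size to the dataloader, not to map().
# map(fn=process_batch, inputs=dataloader, output_dir="/data/out", batch_size=4)

# Correct: the dataloader already carries its batch size; likewise, `weights` is rejected.
map(fn=process_batch, inputs=dataloader, output_dir="/data/out")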

src/lightning/data/processing/readers.py

Lines changed: 18 additions & 1 deletion
@@ -6,6 +6,8 @@
 from lightning_utilities.core.imports import RequirementCache
 from tqdm import tqdm
 
+from lightning.data.streaming.dataloader import StreamingDataLoader
+
 _PYARROW_AVAILABLE = RequirementCache("pyarrow")
 
 
@@ -17,7 +19,7 @@ def get_node_rank(self) -> int:
         return int(os.getenv("DATA_OPTIMIZER_NODE_RANK", 0))
 
     @abstractmethod
-    def remap_items(self, items: List[Any], num_workers: int) -> List[Any]:
+    def remap_items(self, items: Any, num_workers: int) -> List[Any]:
         """This method is meant to remap the items provided by the users into items more adapted to be distributed."""
         pass
 
@@ -93,3 +95,18 @@ def remap_items(self, filepaths: List[str], _: int) -> List[str]:
         print("Finished resharding the parquet files for optimized processing.")
 
         return new_items
+
+
+class StreamingDataLoaderReader(BaseReader):
+    def __init__(self, dataloader: StreamingDataLoader) -> None:
+        super().__init__()
+        self.dataloader = dataloader
+        self.dataloader_iter: Any = None
+
+    def read(self, _: int) -> Any:
+        if self.dataloader_iter is None:
+            self.dataloader_iter = iter(self.dataloader)
+        return next(self.dataloader_iter)
+
+    def remap_items(self, dataloader: StreamingDataLoader, _: int) -> List[Any]:
+        return list(range(len(dataloader)))
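
To make the reader's contract concrete: `remap_items` replaces the dataloader with one pseudo-item per batch index, and `read` lazily builds the iterator and returns the next batch regardless of the index it is handed. A toy illustration of those semantics, using a plain PyTorch DataLoader over integers as a stand-in for a StreamingDataLoader:

from torch.utils.data import DataLoader

dataloader = DataLoader(list(range(10)), batch_size=2)

# remap_items(dataloader, num_workers) -> one pseudo-item per batch index.
items = list(range(len(dataloader)))   # [0, 1, 2, 3, 4]

# read(item) lazily creates the iterator and returns the next batch,
# whatever index value it receives.
it = iter(dataloader)
batches = [next(it) for _ in items]    # five tensors, two elements each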

src/lightning/data/streaming/dataset.py

Lines changed: 4 additions & 0 deletions
@@ -162,6 +162,10 @@ def __len__(self) -> int:
         return self.shuffler.get_len(self.distributed_env, self.current_epoch)
 
     def __iter__(self) -> "StreamingDataset":
+        # When the StreamingDataset is used within map or optimize, let's refetch the distributed env.
+        if os.getenv("DATA_OPTIMIZER_GLOBAL_RANK"):
+            self.distributed_env = _DistributedEnv.detect()
+
         self.worker_env = _WorkerEnv.detect()
         self.cache = self._create_cache(worker_env=self.worker_env)
         self.shuffler = self._create_shuffler(self.cache)

src/lightning/data/utilities/env.py

Lines changed: 15 additions & 0 deletions
@@ -1,3 +1,4 @@
+import os
 from typing import Callable, Optional
 
 import torch
@@ -28,6 +29,9 @@ def detect(cls) -> "_DistributedEnv":
         It will default to 1 distributed process in this case.
 
         """
+        if _is_in_map_or_optimize():
+            return cls._instantiate_in_map_or_optimize()
+
         if torch.distributed.is_available() and torch.distributed.is_initialized():
             world_size = torch.distributed.get_world_size()
             global_rank = torch.distributed.get_rank()
@@ -45,6 +49,13 @@ def detect(cls) -> "_DistributedEnv":
 
         return cls(world_size=world_size, global_rank=global_rank, num_nodes=num_nodes)
 
+    @classmethod
+    def _instantiate_in_map_or_optimize(cls) -> "_DistributedEnv":
+        global_rank = int(os.getenv("DATA_OPTIMIZER_GLOBAL_RANK", "0"))
+        num_workers = int(os.getenv("DATA_OPTIMIZER_NUM_WORKERS", "0"))
+        num_nodes = int(os.getenv("DATA_OPTIMIZER_NUM_NODES", 1))
+        return cls(world_size=num_workers * num_nodes, global_rank=int(global_rank), num_nodes=num_nodes)
+
     def __repr__(self) -> str:
         return f"{self.__class__.__name__}(world_size: {self.world_size}, global_rank: {self.global_rank}\n)"
 
@@ -165,3 +176,7 @@ def __str__(self) -> str:
 
 def _is_in_dataloader_worker() -> bool:
     return torch_get_worker_info() is not None
+
+
+def _is_in_map_or_optimize() -> bool:
+    return os.getenv("DATA_OPTIMIZER_GLOBAL_RANK") is not None
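
The effect of this change is that a StreamingDataset iterated inside a map/optimize worker shards its data across num_workers × num_nodes processes instead of falling back to a single-process view. A small sketch of the arithmetic with illustrative values; in practice the DATA_OPTIMIZER_* variables are expected to be set for each worker by the data processor, not by hand:

import os

from lightning.data.utilities.env import _DistributedEnv

# Illustrative values only.
os.environ["DATA_OPTIMIZER_GLOBAL_RANK"] = "3"
os.environ["DATA_OPTIMIZER_NUM_WORKERS"] = "4"   # workers per node
os.environ["DATA_OPTIMIZER_NUM_NODES"] = "2"

env = _DistributedEnv.detect()
print(env)  # world_size = 4 * 2 = 8, global_rank = 3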

tests/tests_data/processing/test_data_processor.py

Lines changed: 33 additions & 1 deletion
@@ -29,7 +29,7 @@
     _wait_for_file_to_exist,
 )
 from lightning.data.processing.functions import LambdaDataTransformRecipe, map, optimize
-from lightning.data.streaming import StreamingDataset, resolver
+from lightning.data.streaming import StreamingDataLoader, StreamingDataset, resolver
 from lightning.data.streaming.cache import Cache, Dir
 from lightning_utilities.core.imports import RequirementCache
 
@@ -1158,3 +1158,35 @@ def test_to_path(tmpdir):
 
     assert _to_path("/teamspace/studios/this_studio/a.png") == "/teamspace/studios/this_studio/a.png"
     assert _to_path(filepath) == filepath
+
+
+def fetch_from_dataset(batch, output_dir):
+    for index in batch.numpy().tolist():
+        filepath = os.path.join(output_dir, f"{index}.txt")
+        with open(filepath, "w") as f:
+            f.write("Hello World!")
+
+
+@pytest.mark.skipif(sys.platform == "win32", reason="skip windows")
+def test_streaming_dataset_in_map(tmpdir):
+    seed_everything(42)
+
+    output_dir = os.path.join(tmpdir, "output_dir")
+
+    cache = Cache(input_dir=str(tmpdir), chunk_size=10)
+    for i in range(107):
+        cache[i] = i
+
+    cache.done()
+    cache.merge()
+
+    dataset = StreamingDataset(input_dir=str(tmpdir))
+
+    map(
+        fn=fetch_from_dataset,
+        inputs=StreamingDataLoader(dataset, num_workers=1, batch_size=2),
+        output_dir=output_dir,
+        num_workers=2,
+    )
+
+    assert sorted(os.listdir(output_dir)) == sorted([f"{i}.txt" for i in range(107)])
