 from lightning.data.processing.data_processor import DataChunkRecipe, DataProcessor, DataTransformRecipe
 from lightning.data.processing.readers import BaseReader
 from lightning.data.processing.utilities import optimize_dns_context
+from lightning.data.streaming.dataloader import StreamingDataLoader
 from lightning.data.streaming.resolver import (
     Dir,
     _assert_dir_has_index_file,
@@ -176,6 +177,7 @@ def map(
         inputs: A sequence of input to be processed by the `fn` function.
             Each input should contain at least a valid filepath.
         output_dir: The folder where the processed data should be stored.
+        weights: Provide an associated weight to each input. This is used to balance work among workers.
         num_workers: The number of workers to use during processing
         fast_dev_run: Whether to use process only a sub part of the inputs
         num_nodes: When doing remote execution, the number of nodes to use. Only supported on https://lightning.ai/.
@@ -188,8 +190,14 @@ def map(
         batch_size: Group the inputs into batches of batch_size length.
 
     """
-    if not isinstance(inputs, Sequence):
-        raise ValueError(f"The provided inputs should be non empty sequence. Found {inputs}.")
+    if isinstance(inputs, StreamingDataLoader) and batch_size is not None:
+        raise ValueError("When providing a streaming dataloader, pass the batch_size to the dataloader directly.")
+
+    if isinstance(inputs, StreamingDataLoader) and weights is not None:
+        raise ValueError("When providing a streaming dataloader, weights isn't supported.")
+
+    if not isinstance(inputs, (Sequence, StreamingDataLoader)):
+        raise ValueError(f"The provided inputs should be non empty sequence or a streaming dataloader. Found {inputs}.")
 
     if len(inputs) == 0:
         raise ValueError(f"The provided inputs should be non empty. Found {inputs}.")
@@ -218,10 +226,13 @@ def map(
         if error_when_not_empty:
             _assert_dir_is_empty(_output_dir)
 
-        input_dir = _resolve_dir(_get_input_dir(inputs))
+        if not isinstance(inputs, StreamingDataLoader):
+            input_dir = _resolve_dir(_get_input_dir(inputs))
 
-        if isinstance(batch_size, int) and batch_size > 1:
-            inputs = [inputs[pos : pos + batch_size] for pos in range(0, len(inputs), batch_size)]
+            if isinstance(batch_size, int) and batch_size > 1:
+                inputs = [inputs[pos : pos + batch_size] for pos in range(0, len(inputs), batch_size)]
+        else:
+            input_dir = Dir()
 
         data_processor = DataProcessor(
             input_dir=input_dir,
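With these hunks, `map` accepts a `StreamingDataLoader` directly: input-directory resolution is skipped (an empty `Dir()` is used) and batching is delegated to the dataloader itself. A minimal usage sketch, assuming the import paths touched in this diff; the dataset URIs, `fn` body, and batch size are hypothetical:

```python
# Sketch only: the S3 URIs, `fn`, and batch size below are hypothetical.
from lightning.data.processing.functions import map as data_map  # alias to avoid shadowing the builtin
from lightning.data.streaming import StreamingDataset
from lightning.data.streaming.dataloader import StreamingDataLoader

def fn(batch):
    ...  # process one batch yielded by the dataloader

dataset = StreamingDataset("s3://my-bucket/my-dataset")  # hypothetical input location
loader = StreamingDataLoader(dataset, batch_size=32)     # batch_size lives on the dataloader now

data_map(
    fn=fn,
    inputs=loader,                          # passes the widened isinstance check
    output_dir="s3://my-bucket/processed",  # hypothetical output location
    # batch_size=8  -> ValueError: pass it to the dataloader instead
    # weights=[...] -> ValueError: weights isn't supported with a dataloader
)
```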
@@ -247,6 +258,7 @@ def optimize(
     fn: Callable[[Any], Any],
     inputs: Sequence[Any],
     output_dir: str,
+    weights: Optional[List[int]] = None,
     chunk_size: Optional[int] = None,
     chunk_bytes: Optional[Union[int, str]] = None,
     compression: Optional[str] = None,
@@ -267,6 +279,7 @@ def optimize(
         inputs: A sequence of input to be processed by the `fn` function.
             Each input should contain at least a valid filepath.
         output_dir: The folder where the processed data should be stored.
+        weights: Provide an associated weight to each input. This is used to balance work among workers.
         chunk_size: The maximum number of elements to hold within a chunk.
         chunk_bytes: The maximum number of bytes to hold within a chunk.
         compression: The compression algorithm to use over the chunks.
@@ -281,8 +294,14 @@ def optimize(
         batch_size: Group the inputs into batches of batch_size length.
 
     """
-    if not isinstance(inputs, Sequence):
-        raise ValueError(f"The provided inputs should be non empty sequence. Found {inputs}.")
+    if isinstance(inputs, StreamingDataLoader) and batch_size is not None:
+        raise ValueError("When providing a streaming dataloader, pass the batch_size to the dataloader directly.")
+
+    if isinstance(inputs, StreamingDataLoader) and weights is not None:
+        raise ValueError("When providing a streaming dataloader, weights isn't supported.")
+
+    if not isinstance(inputs, (Sequence, StreamingDataLoader)):
+        raise ValueError(f"The provided inputs should be non empty sequence or a streaming dataloader. Found {inputs}.")
 
     if len(inputs) == 0:
         raise ValueError(f"The provided inputs should be non empty. Found {inputs}.")
@@ -313,10 +332,13 @@ def optimize(
 
         _assert_dir_has_index_file(_output_dir)
 
-        input_dir = _resolve_dir(_get_input_dir(inputs))
+        if not isinstance(inputs, StreamingDataLoader):
+            input_dir = _resolve_dir(_get_input_dir(inputs))
 
-        if isinstance(batch_size, int) and batch_size > 1:
-            inputs = [inputs[pos : pos + batch_size] for pos in range(0, len(inputs), batch_size)]
+            if isinstance(batch_size, int) and batch_size > 1:
+                inputs = [inputs[pos : pos + batch_size] for pos in range(0, len(inputs), batch_size)]
+        else:
+            input_dir = Dir()
 
         data_processor = DataProcessor(
             input_dir=input_dir,
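The new `weights` argument gives `optimize` a per-input cost hint so work can be balanced across workers, and it must be left unset when the inputs are a `StreamingDataLoader`. A minimal sketch, with hypothetical file paths, using file size on disk as the weight:

```python
# Sketch only: the file paths and the weighting scheme are hypothetical.
import os
from lightning.data.processing.functions import optimize

def fn(filepath):
    ...  # yield items to be written into chunks for this file

inputs = ["data/a.jsonl", "data/b.jsonl", "data/c.jsonl"]
weights = [os.path.getsize(p) for p in inputs]  # one int per input, e.g. bytes on disk

optimize(
    fn=fn,
    inputs=inputs,
    output_dir="optimized/",
    weights=weights,      # used to balance work among workers
    chunk_bytes="64MB",   # chunking bound; chunk_size would also work
)
```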