@@ -70,6 +70,12 @@ def _get_input_dir(inputs: Sequence[Any]) -> Optional[str]:
     return "/" + os.path.join(*str(absolute_path).split("/")[:4])
 
 
+def _get_default_num_workers() -> int:
+    if torch.cuda.is_available():
+        return torch.cuda.device_count()
+    return os.cpu_count() or 1
+
+
 class LambdaDataTransformRecipe(DataTransformRecipe):
     def __init__(self, fn: Callable[[str, Any], None], inputs: Sequence[Any]):
         super().__init__()
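Note on the new default: the call sites below resolve the worker count with ``num_workers or _get_default_num_workers()``, so an explicit ``num_workers=0`` also falls through to the default. A minimal standalone sketch of the resolution logic, assuming ``torch`` is installed:

```python
import os

import torch


def _get_default_num_workers() -> int:
    # One worker per visible GPU when CUDA is available,
    # otherwise one worker per CPU core (never fewer than 1).
    if torch.cuda.is_available():
        return torch.cuda.device_count()
    return os.cpu_count() or 1


# Because the call sites use ``num_workers or _get_default_num_workers()``,
# both None and 0 fall back to the computed default.
for requested in (None, 0, 4):
    resolved = requested or _get_default_num_workers()
    print(f"requested={requested!r} -> resolved={resolved}")
```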
@@ -161,6 +167,7 @@ def map(
     reorder_files: bool = True,
     error_when_not_empty: bool = False,
     reader: Optional[BaseReader] = None,
+    batch_size: Optional[int] = None,
 ) -> None:
     """This function maps a callable over a collection of files, possibly in a distributed way.
 
@@ -178,6 +185,7 @@ def map(
         reorder_files: By default, reorders the files by file size to distribute work equally among all workers.
             Set this to ``False`` if the order in which samples are processed should be preserved.
         error_when_not_empty: Whether we should error if the output folder isn't empty.
+        batch_size: Group the inputs into batches of length ``batch_size``.
 
     """
     if not isinstance(inputs, Sequence):
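For illustration, a hypothetical ``map`` call using the new parameter. The import path, the ``resize`` function, and the file names are assumptions, not part of this diff; with ``batch_size=4``, ``fn`` receives a list of up to four inputs per invocation instead of a single input:

```python
from lightning.data import map  # import path is an assumption

def resize(output_dir: str, batch: list) -> None:
    # With batch_size set, ``batch`` is a list of input paths.
    for filepath in batch:
        ...  # process each file and write results under output_dir

map(
    fn=resize,
    inputs=[f"image_{i}.jpeg" for i in range(100)],  # hypothetical inputs
    output_dir="./resized",
    batch_size=4,  # 100 inputs -> 25 batches of 4
)
```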
@@ -212,10 +220,13 @@ def map(
 
     input_dir = _resolve_dir(_get_input_dir(inputs))
 
+    if isinstance(batch_size, int) and batch_size > 1:
+        inputs = [inputs[pos : pos + batch_size] for pos in range(0, len(inputs), batch_size)]
+
     data_processor = DataProcessor(
         input_dir=input_dir,
         output_dir=_output_dir,
-        num_workers=num_workers or os.cpu_count(),
+        num_workers=num_workers or _get_default_num_workers(),
         fast_dev_run=fast_dev_run,
         num_downloaders=num_downloaders,
         num_uploaders=num_uploaders,
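The batching guard above only rewrites ``inputs`` when ``batch_size`` is an int greater than 1, so the default ``None`` and ``batch_size=1`` leave the inputs untouched. The slicing produces fixed-size chunks with a possibly shorter final batch:

```python
inputs = list(range(7))
batch_size = 3

# Same list comprehension as in the diff above.
batches = [inputs[pos : pos + batch_size] for pos in range(0, len(inputs), batch_size)]
assert batches == [[0, 1, 2], [3, 4, 5], [6]]  # the last batch may be shorter
```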
@@ -247,6 +258,7 @@ def optimize(
     num_uploaders: Optional[int] = None,
     reorder_files: bool = True,
     reader: Optional[BaseReader] = None,
+    batch_size: Optional[int] = None,
 ) -> None:
     """This function converts a dataset into chunks, possibly in a distributed way.
 
@@ -266,6 +278,7 @@ def optimize(
         num_uploaders: The number of uploaders per worker.
         reorder_files: By default, reorders the files by file size to distribute work equally among all workers.
             Set this to ``False`` if the order in which samples are processed should be preserved.
+        batch_size: Group the inputs into batches of length ``batch_size``.
 
     """
     if not isinstance(inputs, Sequence):
@@ -302,10 +315,13 @@ def optimize(
 
     input_dir = _resolve_dir(_get_input_dir(inputs))
 
+    if isinstance(batch_size, int) and batch_size > 1:
+        inputs = [inputs[pos : pos + batch_size] for pos in range(0, len(inputs), batch_size)]
+
     data_processor = DataProcessor(
         input_dir=input_dir,
         output_dir=_output_dir,
-        num_workers=num_workers or os.cpu_count(),
+        num_workers=num_workers or _get_default_num_workers(),
         fast_dev_run=fast_dev_run,
         num_downloaders=num_downloaders,
         num_uploaders=num_uploaders,
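``optimize`` mirrors the same two changes as ``map``: the GPU-aware worker default and input batching. A hypothetical usage sketch; the import path, the ``tokenize`` function, and the chunking arguments are assumptions:

```python
from lightning.data import optimize  # import path is an assumption

def tokenize(batch: list) -> list:
    # With batch_size set, each call receives a list of raw inputs
    # and can return one processed sample per element.
    return [text.upper() for text in batch]

optimize(
    fn=tokenize,
    inputs=[f"sample_{i}" for i in range(10)],  # hypothetical inputs
    output_dir="./chunks",
    chunk_bytes="64MB",  # chunking arguments are assumptions
    batch_size=5,  # 10 inputs -> 2 batches of 5
)
```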