Skip to content

Commit bb35e8e

Browse files
authored
Add batch_size to map, optimize (#19489)
1 parent bbc5488 commit bb35e8e

File tree

2 files changed

+37
-2
lines changed

2 files changed

+37
-2
lines changed

src/lightning/data/processing/functions.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,12 @@ def _get_input_dir(inputs: Sequence[Any]) -> Optional[str]:
7070
return "/" + os.path.join(*str(absolute_path).split("/")[:4])
7171

7272

73+
def _get_default_num_workers() -> int:
74+
if torch.cuda.is_available():
75+
return torch.cuda.device_count()
76+
return os.cpu_count() or 1
77+
78+
7379
class LambdaDataTransformRecipe(DataTransformRecipe):
7480
def __init__(self, fn: Callable[[str, Any], None], inputs: Sequence[Any]):
7581
super().__init__()
@@ -161,6 +167,7 @@ def map(
161167
reorder_files: bool = True,
162168
error_when_not_empty: bool = False,
163169
reader: Optional[BaseReader] = None,
170+
batch_size: Optional[int] = None,
164171
) -> None:
165172
"""This function map a callbable over a collection of files possibly in a distributed way.
166173
@@ -178,6 +185,7 @@ def map(
178185
reorder_files: By default, reorders the files by file size to distribute work equally among all workers.
179186
Set this to ``False`` if the order in which samples are processed should be preserved.
180187
error_when_not_empty: Whether we should error if the output folder isn't empty.
188+
batch_size: Group the inputs into batches of batch_size length.
181189
182190
"""
183191
if not isinstance(inputs, Sequence):
@@ -212,10 +220,13 @@ def map(
212220

213221
input_dir = _resolve_dir(_get_input_dir(inputs))
214222

223+
if isinstance(batch_size, int) and batch_size > 1:
224+
inputs = [inputs[pos : pos + batch_size] for pos in range(0, len(inputs), batch_size)]
225+
215226
data_processor = DataProcessor(
216227
input_dir=input_dir,
217228
output_dir=_output_dir,
218-
num_workers=num_workers or os.cpu_count(),
229+
num_workers=num_workers or _get_default_num_workers(),
219230
fast_dev_run=fast_dev_run,
220231
num_downloaders=num_downloaders,
221232
num_uploaders=num_uploaders,
@@ -247,6 +258,7 @@ def optimize(
247258
num_uploaders: Optional[int] = None,
248259
reorder_files: bool = True,
249260
reader: Optional[BaseReader] = None,
261+
batch_size: Optional[int] = None,
250262
) -> None:
251263
"""This function converts a dataset into chunks possibly in a distributed way.
252264
@@ -266,6 +278,7 @@ def optimize(
266278
num_uploaders: The numbers of uploaders per worker.
267279
reorder_files: By default, reorders the files by file size to distribute work equally among all workers.
268280
Set this to ``False`` if the order in which samples are processed should be preserved.
281+
batch_size: Group the inputs into batches of batch_size length.
269282
270283
"""
271284
if not isinstance(inputs, Sequence):
@@ -302,10 +315,13 @@ def optimize(
302315

303316
input_dir = _resolve_dir(_get_input_dir(inputs))
304317

318+
if isinstance(batch_size, int) and batch_size > 1:
319+
inputs = [inputs[pos : pos + batch_size] for pos in range(0, len(inputs), batch_size)]
320+
305321
data_processor = DataProcessor(
306322
input_dir=input_dir,
307323
output_dir=_output_dir,
308-
num_workers=num_workers or os.cpu_count(),
324+
num_workers=num_workers or _get_default_num_workers(),
309325
fast_dev_run=fast_dev_run,
310326
num_downloaders=num_downloaders,
311327
num_uploaders=num_uploaders,

tests/tests_data/processing/test_data_processor.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1025,6 +1025,25 @@ def test_map_is_last(num_workers, expected, tmpdir):
10251025
assert sorted(os.listdir(tmpdir)) == expected
10261026

10271027

1028+
def map_batch_size_fn(indexes, output_dir):
    """Write a marker file named after the received batch of indexes.

    Used to verify that ``map(..., batch_size=...)`` hands each worker a list
    of inputs: the filename is the ``str()`` of the batch (e.g. ``"[0, 1]"``).
    """
    target = os.path.join(output_dir, str(indexes))
    with open(target, "w") as stream:
        stream.write("hello world")
1032+
1033+
1034+
def test_map_batch_size(tmpdir):
    """With batch_size=2, map() should group range(5) into [0,1], [2,3], [4]
    and invoke the callable once per batch (one output file per batch)."""
    out_dir = str(tmpdir)
    map(
        map_batch_size_fn,
        list(range(5)),
        output_dir=out_dir,
        error_when_not_empty=False,
        num_workers=1,
        batch_size=2,
    )
    produced = sorted(os.listdir(out_dir))
    assert produced == ["[0, 1]", "[2, 3]", "[4]"]
1045+
1046+
10281047
def no_op(index):
10291048
pass
10301049

0 commit comments

Comments
 (0)