Commit c529f85

Grain Team authored and copybara-github committed
Internal
PiperOrigin-RevId: 837169778
1 parent 52ee8c2 commit c529f85

6 files changed (+671, -49 lines)

grain/_src/python/dataset/BUILD

Lines changed: 1 addition & 0 deletions

@@ -52,6 +52,7 @@ py_library(
         "//grain/_src/core:transforms",
         "//grain/_src/core:tree_lib",
         "//grain/_src/python:checkpointing",
+        "//grain/_src/python:grain_logging",
         "//grain/_src/python:grain_pool",
         "//grain/_src/python:options",
         "//grain/_src/python:shared_memory_array",

grain/_src/python/dataset/transformations/BUILD

Lines changed: 1 addition & 0 deletions

@@ -257,6 +257,7 @@ py_test(
     deps = [
         "//grain/_src/python:options",
         "//grain/_src/python/dataset",
+        "//grain/_src/python/dataset:base",
         "@abseil-py//absl/testing:absltest",
         "@abseil-py//absl/testing:parameterized",
     ],

grain/_src/python/dataset/transformations/interleave.py

Lines changed: 46 additions & 15 deletions

@@ -16,6 +16,7 @@
 from collections.abc import Sequence
 import functools
 from typing import TypeVar
+import weakref
 
 from grain._src.python import options as grain_options
 from grain._src.python.dataset import dataset
@@ -26,20 +27,6 @@
 T = TypeVar("T")
 
 
-def _add_prefetch_and_make_iterator(
-    ds: dataset.IterDataset[T] | dataset.MapDataset[T],
-    prefetch_buffer_size: int,
-) -> dataset.DatasetIterator[T]:
-  if isinstance(ds, dataset.MapDataset):
-    # Prefetch is automatically added in `MapDataset.__iter__`.
-    return ds.__iter__()
-  iterator = prefetch.ThreadPrefetchIterDataset(
-      ds, prefetch_buffer_size=prefetch_buffer_size
-  ).__iter__()
-  iterator.start_prefetch()
-  return iterator
-
-
 class _InterleaveDatasetIterator(dataset.DatasetIterator[T]):
   """Iterates over the interleaved datasets."""
 
@@ -63,7 +50,14 @@ def __init__(
         .map(
            functools.partial(
                _add_prefetch_and_make_iterator,
-                prefetch_buffer_size=self._iter_buffer_size,
+                # We use weakref to avoid a circular reference. The
+                # _InterleaveDatasetIterator holds a reference to the
+                # prefetch iterator in `self._prefetch_ds_iter`.
+                # The call to `_add_prefetch_and_make_iterator` (and the
+                # partial object) would hold a reference to the
+                # _InterleaveDatasetIterator. This would prolong its lifetime
+                # leading to increased resource usage.
+                interleave_iterator=weakref.ref(self),
            )
         )
         .to_iter_dataset(
@@ -157,6 +151,43 @@ def __str__(self) -> str:
     )
 
 
+def _add_prefetch_and_make_iterator(
+    ds: dataset.IterDataset[T] | dataset.MapDataset[T],
+    interleave_iterator: weakref.ref[_InterleaveDatasetIterator[T]],
+) -> dataset.DatasetIterator[T]:
+  """Adds prefetching to an IterDataset and returns an iterator.
+
+  If the input is a MapDataset, prefetching is handled by `MapDataset.__iter__`.
+  If the input is an IterDataset, a `ThreadPrefetchIterDataset` is used to
+  add prefetching.
+
+  Args:
+    ds: The dataset to create an iterator from.
+    interleave_iterator: The `InterleaveDatasetIterator` instance.
+
+  Returns:
+    A `dataset.DatasetIterator` for the given dataset, with prefetching
+    enabled if applicable.
+
+  Raises:
+    RuntimeError: If the interleave_iterator has been garbage collected.
+  """
+  interleave_iterator_obj = interleave_iterator()
+  if interleave_iterator_obj is None:
+    raise RuntimeError("InterleaveDatasetIterator has been garbage collected.")
+  if isinstance(ds, dataset.MapDataset):
+    # Prefetch is automatically added in `MapDataset.__iter__`.
+    return ds.__iter__()
+  iterator = prefetch.ThreadPrefetchIterDataset(
+      ds, prefetch_buffer_size=interleave_iterator_obj._iter_buffer_size  # pylint: disable=protected-access
+  ).__iter__()
+  # Propagate options applied after InterleaveIterDataset to the iterators that
+  # are being interleaved.
+  iterator._ctx.dataset_options = interleave_iterator_obj._ctx.dataset_options.merge(iterator._ctx.dataset_options)  # pylint: disable=protected-access
+  iterator.start_prefetch()
+  return iterator
+
+
 class InterleaveIterDataset(dataset.IterDataset[T]):
   """Interleaves the given sequence of datasets.
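
Note on the weakref change above: the partial passed to `.map(...)` now stores `weakref.ref(self)` instead of a bound buffer size, so the mapped callable no longer keeps the interleave iterator alive. A minimal standalone sketch of the same pattern, using hypothetical Owner/_use_owner names rather than Grain code:

import functools
import weakref


class Owner:
  """Stands in for _InterleaveDatasetIterator: it owns the mapped callable."""

  def __init__(self):
    # Passing `weakref.ref(self)` instead of `self` keeps the partial from
    # holding a strong reference back to the owner.
    self.callback = functools.partial(_use_owner, owner_ref=weakref.ref(self))


def _use_owner(owner_ref: weakref.ref) -> str:
  owner = owner_ref()  # Dereference; returns None once the owner is collected.
  if owner is None:
    raise RuntimeError("Owner has been garbage collected.")
  return "owner is alive"


owner = Owner()
print(owner.callback())  # -> owner is alive
probe = weakref.ref(owner)
del owner  # The partial holds only a weak reference, so the object is freed.
print(probe() is None)  # -> True

Passing `self` directly into the partial would keep the iterator reachable for as long as the mapped dataset holds the callable, which is exactly the prolonged lifetime the commit comment warns about.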

grain/_src/python/dataset/transformations/interleave_test.py

Lines changed: 11 additions & 0 deletions

@@ -16,6 +16,7 @@
 from absl.testing import parameterized
 import multiprocessing as mp
 from grain._src.python import options
+from grain._src.python.dataset import base
 from grain._src.python.dataset import dataset
 from grain._src.python.dataset.transformations import interleave
 
@@ -141,6 +142,16 @@ def test_with_mp_prefetch(self):
     ds = ds.mp_prefetch(options.MultiprocessingOptions(num_workers=3))
     self.assertEqual(list(ds), [1, 2, 3, 4, 5, 3, 4, 2, 3, 4, 5, 4, 5, 5, 5])
 
+  def test_options_propagated(self):
+    ds1 = dataset.MapDataset.source([1]).repeat(1000).to_iter_dataset()
+    ds1 = ds1.filter(lambda x: False)
+    ds2 = dataset.MapDataset.source([2]).repeat(1000).to_iter_dataset()
+    ds = interleave.InterleaveIterDataset([ds1, ds2], cycle_length=1)
+    ds_options = base.DatasetOptions(filter_raise_threshold_ratio=0.1)
+    ds = dataset.WithOptionsIterDataset(ds, ds_options)
+    with self.assertRaisesRegex(ValueError, r"skipped 100\.00 %"):
+      list(ds)
+
 
 if __name__ == "__main__":
   absltest.main()

grain/_src/python/dataset/transformations/process_prefetch.py

Lines changed: 81 additions & 19 deletions

@@ -15,23 +15,25 @@
 
 from __future__ import annotations
 
-from collections.abc import Callable
+from collections.abc import Callable, Sequence
+import copy
 import functools
 from multiprocessing import queues
 from multiprocessing import synchronize
 import queue
-import time
 from typing import Any, TypeVar
 
 from absl import flags
 import cloudpickle
 from grain._src.core import monitoring as grain_monitoring
 from grain._src.core.config import config
 import multiprocessing as mp
+from grain._src.python import grain_logging
 from grain._src.python import shared_memory_array
 from grain._src.python.dataset import base
 from grain._src.python.dataset import dataset
 from grain._src.python.dataset import stats as dataset_stats
+from grain._src.python.dataset.transformations import interleave
 from grain._src.python.dataset.transformations import prefetch
 
 T = TypeVar("T")
@@ -61,6 +63,11 @@
 _QUEUE_WAIT_TIMEOUT_S = 1
 
 
+def _run_all(fns: Sequence[Callable[[], None]]):
+  for fn in fns:
+    fn()
+
+
 def _parse_debug_flags(debug_flags: dict[str, Any]):
   """Parses debug flags."""
   flags.FLAGS["grain_py_debug_mode"].present = True
@@ -150,7 +157,6 @@ def __init__(
       parent: dataset.IterDataset[T],
       buffer_size: int,
       worker_init_fn: Callable[[], None] | None = None,
-      always_report_worker_state: bool = False,
   ):
     if buffer_size <= 0:
       raise ValueError(
@@ -159,7 +165,6 @@ def __init__(
     super().__init__(parent)
     self._buffer_size = buffer_size
     self._worker_init_fn = worker_init_fn
-    self._always_report_worker_state = always_report_worker_state
     _validate_no_nested_process_prefetch(self._parent)
 
   def __str__(self) -> str:
@@ -170,7 +175,6 @@ def __iter__(self) -> dataset.DatasetIterator[T]:
         self._parent,
         self._buffer_size,
         self._worker_init_fn,
-        self._always_report_worker_state,
     )
 
 
@@ -185,7 +189,6 @@ def _put_dataset_elements_in_buffer(
     stats_out_queue: queues.Queue[Any] | None,
     start_profiling_event: synchronize.Event | None,
     stop_profiling_event: synchronize.Event | None,
-    always_report_worker_state: bool,
     debug_flags: dict[str, Any],
 ):
   """Prefetches elements in a separate process."""
@@ -200,7 +203,6 @@ def _put_dataset_elements_in_buffer(
   min_shm_size = it._ctx.dataset_options.min_shm_size  # pylint: disable=protected-access
   # Set the stats queue in worker process to send stats to the main process.
   it._stats._config.stats_out_queue = stats_out_queue  # pylint: disable=protected-access
-  last_recorded_state_time = time.time()
   parent_exhausted = False
   while not should_stop.is_set():
     if set_state_event.is_set():
@@ -217,7 +219,6 @@ def _put_dataset_elements_in_buffer(
       # error. Wait until set_state_event or should_stop is set.
       set_state_event.wait(_PARENT_EXHAUSTED_WAIT_S)
       continue
-    now = time.time()
     try:
       element = it.__next__()
     except Exception as e:  # pylint: disable=broad-except
@@ -229,14 +230,7 @@ def _put_dataset_elements_in_buffer(
       # __next__ method.
       if not it._stats._config.is_prefetch:  # pylint: disable=protected-access
         it._stats.record_bytes_produced(element)  # pylint: disable=protected-access
-      if (
-          always_report_worker_state
-          or now - last_recorded_state_time >= _RECORD_STATE_INTERVAL_S
-      ):
-        last_recorded_state_time = now
-        buffer.put((element, it.get_state(), None))
-      else:
-        buffer.put((element, None, None))
+      buffer.put((element, it.get_state(), None))
     except Exception as e:  # pylint: disable=broad-except
       buffer.put((None, None, e))
 
@@ -253,13 +247,11 @@ def __init__(
       parent: dataset.IterDataset[T],
      buffer_size: int,
      worker_init_fn: Callable[[], None] | None = None,
-      always_report_worker_state: bool = False,
   ):
     super().__init__()
     self._iter_parent = parent
     self._buffer_size = buffer_size
     self._worker_init_fn = worker_init_fn
-    self._always_report_worker_state = always_report_worker_state
     # Since the parent iterator is going to be created in each subprocess, and
     # the options are propagated during iterator creation, we need to manually
     # propagate them.
@@ -282,6 +274,7 @@ def __init__(
     self._iterations_to_skip = 0
     self._set_state_count = 0
     self._exhausted = False
+    self._prefetch_ds_iter = None
 
     # pytype: disable=attribute-error
     # pylint: disable=protected-access
@@ -348,7 +341,6 @@ def start_prefetch(self) -> None:
         stats_out_queue=self._stats_in_queue,
         start_profiling_event=self._start_profiling_event,
         stop_profiling_event=self._stop_profiling_event,
-        always_report_worker_state=self._always_report_worker_state,
         debug_flags=dict(
             grain_py_debug_mode=config.get_or_default("py_debug_mode"),
             grain_py_dataset_visualization_output_dir=(
@@ -475,3 +467,73 @@ def set_state(self, state: StateT):
 
   def __str__(self) -> str:
     return f"ProcessPrefetchDatasetIterator(buffer_size={self._buffer_size})"
+
+
+def multiprocess_prefetch(
+    ds: dataset.IterDataset[T],
+    num_workers: int = 0,
+    buffer_size: int = 1,
+    worker_init_fn: Callable[[int, int], None] | None = None,
+    sequential_slice: bool = False,
+) -> dataset.IterDataset[T]:
+  """Uses a multiple processes to prefetch elements ahead of time.
+
+  It works by sharding the input dataset into `num_workers` shards, and
+  interleaving them. Each shard is read by a separate process inside
+  `InterleaveIterDataset`.
+
+  Args:
+    ds: The parent dataset to prefetch from.
+    num_workers: The number of processes to use for prefetching. If 0,
+      prefetching is disabled and this is a no-op.
+    buffer_size: The size of the prefetch buffer for each process.
+    worker_init_fn: A function that is called in each worker process.
+    sequential_slice: Whether to use sequential slicing.
+
+  Returns:
+    `IterDataset` that prefetches elements from `ds` using multiple processes.
+  """
+  if num_workers == 0:
+    return ds
+
+  dataset_options = _get_dataset_options(ds)
+
+  shards = []
+  for i in range(num_workers):
+    if num_workers == 1:
+      worker_ds = ds
+    else:
+      worker_ds = copy.deepcopy(ds)
+      prefetch._set_slice_iter_dataset(  # pylint: disable=protected-access
+          worker_ds, slice(i, None, num_workers), sequential_slice
+      )
+    worker_ds = prefetch._MpContextIterDataset(  # pylint: disable=protected-access
+        worker_ds,
+        base.MultiprocessingContext(
+            process_index=i,
+            process_count=num_workers,
+        ),
+    )
+    worker_index_suffix = "" if num_workers == 1 else f" {i}"
+
+    worker_init_fns = [
+        functools.partial(
+            grain_logging.set_process_identifier_prefix, worker_index_suffix
+        )
+    ]
+    if worker_init_fn is not None:
+      worker_init_fns.append(functools.partial(worker_init_fn, i, num_workers))
+    worker_ds = ProcessPrefetchIterDataset(
+        worker_ds,
+        buffer_size=buffer_size,
+        worker_init_fn=functools.partial(_run_all, worker_init_fns),
+    )
+    shards.append(worker_ds)
+
+  ds = interleave.InterleaveIterDataset(
+      shards, cycle_length=num_workers, iter_buffer_size=buffer_size
+  )
+  # Apply options from parent dataset because interleave dataset does not
+  # propagate options.
+  ds = dataset.WithOptionsIterDataset(ds, dataset_options)
+  return ds
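
For orientation, a hypothetical usage sketch of the new `multiprocess_prefetch` helper, based only on the signature added above; it calls the internal module directly (a public entry point is not shown in this commit), and since interleaving does not guarantee element order the results are sorted:

from grain._src.python.dataset import dataset
from grain._src.python.dataset.transformations import process_prefetch


def main():
  # A small in-memory pipeline; any IterDataset should work here.
  ds = dataset.MapDataset.source([0, 1, 4, 9, 16, 25, 36, 49]).to_iter_dataset()
  # Shard across 2 worker processes; worker i reads slice(i, None, 2) of the
  # source and prefetches up to `buffer_size` elements ahead.
  ds = process_prefetch.multiprocess_prefetch(ds, num_workers=2, buffer_size=4)
  print(sorted(ds))  # -> [0, 1, 4, 9, 16, 25, 36, 49]


if __name__ == "__main__":
  # Worker processes are started when iteration begins, so guard the entry
  # point as usual for multiprocessing code.
  main()

With num_workers=0 (the default) the call is a no-op and simply returns `ds`, which makes the helper easy to gate behind a flag.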
