1 change: 1 addition & 0 deletions grain/_src/python/dataset/BUILD
@@ -52,6 +52,7 @@ py_library(
"//grain/_src/core:transforms",
"//grain/_src/core:tree_lib",
"//grain/_src/python:checkpointing",
"//grain/_src/python:grain_logging",
"//grain/_src/python:grain_pool",
"//grain/_src/python:options",
"//grain/_src/python:shared_memory_array",
1 change: 1 addition & 0 deletions grain/_src/python/dataset/transformations/BUILD
@@ -257,6 +257,7 @@ py_test(
deps = [
"//grain/_src/python:options",
"//grain/_src/python/dataset",
"//grain/_src/python/dataset:base",
"@abseil-py//absl/testing:absltest",
"@abseil-py//absl/testing:parameterized",
],
61 changes: 46 additions & 15 deletions grain/_src/python/dataset/transformations/interleave.py
@@ -16,6 +16,7 @@
from collections.abc import Sequence
import functools
from typing import TypeVar
import weakref

from grain._src.python import options as grain_options
from grain._src.python.dataset import dataset
@@ -26,20 +27,6 @@
T = TypeVar("T")


def _add_prefetch_and_make_iterator(
ds: dataset.IterDataset[T] | dataset.MapDataset[T],
prefetch_buffer_size: int,
) -> dataset.DatasetIterator[T]:
if isinstance(ds, dataset.MapDataset):
# Prefetch is automatically added in `MapDataset.__iter__`.
return ds.__iter__()
iterator = prefetch.ThreadPrefetchIterDataset(
ds, prefetch_buffer_size=prefetch_buffer_size
).__iter__()
iterator.start_prefetch()
return iterator


class _InterleaveDatasetIterator(dataset.DatasetIterator[T]):
"""Iterates over the interleaved datasets."""

@@ -63,7 +50,14 @@ def __init__(
.map(
functools.partial(
_add_prefetch_and_make_iterator,
prefetch_buffer_size=self._iter_buffer_size,
# We use weakref to avoid a circular reference. The
# _InterleaveDatasetIterator holds a reference to the
# prefetch iterator in `self._prefetch_ds_iter`.
# The call to `_add_prefetch_and_make_iterator` (and the
# partial object) would hold a reference to the
# _InterleaveDatasetIterator. This would prolong its lifetime,
# leading to increased resource usage.
interleave_iterator=weakref.ref(self),
)
)
.to_iter_dataset(
@@ -157,6 +151,43 @@ def __str__(self) -> str:
)


def _add_prefetch_and_make_iterator(
ds: dataset.IterDataset[T] | dataset.MapDataset[T],
interleave_iterator: weakref.ref[_InterleaveDatasetIterator[T]],
) -> dataset.DatasetIterator[T]:
"""Adds prefetching to an IterDataset and returns an iterator.

If the input is a MapDataset, prefetching is handled by `MapDataset.__iter__`.
If the input is an IterDataset, a `ThreadPrefetchIterDataset` is used to
add prefetching.

Args:
ds: The dataset to create an iterator from.
interleave_iterator: A weak reference to the enclosing
`_InterleaveDatasetIterator`.

Returns:
A `dataset.DatasetIterator` for the given dataset, with prefetching
enabled if applicable.

Raises:
RuntimeError: If the interleave_iterator has been garbage collected.
"""
interleave_iterator_obj = interleave_iterator()
if interleave_iterator_obj is None:
raise RuntimeError("InterleaveDatasetIterator has been garbage collected.")
if isinstance(ds, dataset.MapDataset):
# Prefetch is automatically added in `MapDataset.__iter__`.
return ds.__iter__()
iterator = prefetch.ThreadPrefetchIterDataset(
ds, prefetch_buffer_size=interleave_iterator_obj._iter_buffer_size # pylint: disable=protected-access
).__iter__()
# Propagate options applied after InterleaveIterDataset to the iterators that
# are being interleaved.
iterator._ctx.dataset_options = interleave_iterator_obj._ctx.dataset_options.merge(iterator._ctx.dataset_options) # pylint: disable=protected-access
iterator.start_prefetch()
return iterator


class InterleaveIterDataset(dataset.IterDataset[T]):
"""Interleaves the given sequence of datasets.

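Aside: the weakref pattern introduced above generalizes beyond this iterator. Below is a minimal standalone sketch (the `Owner` class and `_do_work` helper are hypothetical, not part of this diff) of breaking an owner-to-callback reference cycle with `functools.partial` and `weakref.ref`:

import functools
import weakref


def _do_work(owner_ref: weakref.ref) -> int:
  # Dereference the weakref; it returns None once the Owner is collected.
  owner = owner_ref()
  if owner is None:
    # Mirrors the RuntimeError guard in _add_prefetch_and_make_iterator.
    raise RuntimeError("Owner has been garbage collected.")
  return owner.buffer_size


class Owner:

  def __init__(self):
    self.buffer_size = 8
    # Passing weakref.ref(self) instead of self keeps the partial from
    # holding a strong reference back to this object, so dropping the
    # Owner actually frees it (and its resources) promptly.
    self.callback = functools.partial(_do_work, owner_ref=weakref.ref(self))


owner = Owner()
assert owner.callback() == 8  # Resolves while the Owner is alive.

Failing loudly when the referent is gone, rather than silently returning a default, matches the guard in `_add_prefetch_and_make_iterator` above.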
11 changes: 11 additions & 0 deletions grain/_src/python/dataset/transformations/interleave_test.py
@@ -16,6 +16,7 @@
from absl.testing import parameterized
import multiprocessing as mp
from grain._src.python import options
from grain._src.python.dataset import base
from grain._src.python.dataset import dataset
from grain._src.python.dataset.transformations import interleave

@@ -141,6 +142,16 @@ def test_with_mp_prefetch(self):
ds = ds.mp_prefetch(options.MultiprocessingOptions(num_workers=3))
self.assertEqual(list(ds), [1, 2, 3, 4, 5, 3, 4, 2, 3, 4, 5, 4, 5, 5, 5])

def test_options_propagated(self):
ds1 = dataset.MapDataset.source([1]).repeat(1000).to_iter_dataset()
ds1 = ds1.filter(lambda x: False)
ds2 = dataset.MapDataset.source([2]).repeat(1000).to_iter_dataset()
ds = interleave.InterleaveIterDataset([ds1, ds2], cycle_length=1)
ds_options = base.DatasetOptions(filter_raise_threshold_ratio=0.1)
ds = dataset.WithOptionsIterDataset(ds, ds_options)
with self.assertRaisesRegex(ValueError, r"skipped 100\.00 %"):
list(ds)


if __name__ == "__main__":
absltest.main()
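For context on the new test: `filter_raise_threshold_ratio` makes iteration fail once a filter drops more than the given fraction of elements, and the test passes only if that option, set after interleaving, reaches the inner iterators. A minimal sketch of the option's behavior on its own (same public API as the test above; values are illustrative):

from grain._src.python.dataset import base
from grain._src.python.dataset import dataset

ds = dataset.MapDataset.source([1]).repeat(1000).to_iter_dataset()
ds = ds.filter(lambda x: False)  # Drops every element: 100% skip ratio.
ds = dataset.WithOptionsIterDataset(
    ds, base.DatasetOptions(filter_raise_threshold_ratio=0.1)
)
# Exhausting the iterator raises, since the 100% skip ratio exceeds the
# 10% threshold, e.g.: ValueError: ... skipped 100.00 % ...
# list(ds)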
100 changes: 81 additions & 19 deletions grain/_src/python/dataset/transformations/process_prefetch.py
@@ -15,23 +15,25 @@

from __future__ import annotations

from collections.abc import Callable
from collections.abc import Callable, Sequence
import copy
import functools
from multiprocessing import queues
from multiprocessing import synchronize
import queue
import time
from typing import Any, TypeVar

from absl import flags
import cloudpickle
from grain._src.core import monitoring as grain_monitoring
from grain._src.core.config import config
import multiprocessing as mp
from grain._src.python import grain_logging
from grain._src.python import shared_memory_array
from grain._src.python.dataset import base
from grain._src.python.dataset import dataset
from grain._src.python.dataset import stats as dataset_stats
from grain._src.python.dataset.transformations import interleave
from grain._src.python.dataset.transformations import prefetch

T = TypeVar("T")
Expand Down Expand Up @@ -61,6 +63,11 @@
_QUEUE_WAIT_TIMEOUT_S = 1


def _run_all(fns: Sequence[Callable[[], None]]):
for fn in fns:
fn()


def _parse_debug_flags(debug_flags: dict[str, Any]):
"""Parses debug flags."""
flags.FLAGS["grain_py_debug_mode"].present = True
@@ -150,7 +157,6 @@ def __init__(
parent: dataset.IterDataset[T],
buffer_size: int,
worker_init_fn: Callable[[], None] | None = None,
always_report_worker_state: bool = False,
):
if buffer_size <= 0:
raise ValueError(
@@ -159,7 +165,6 @@
super().__init__(parent)
self._buffer_size = buffer_size
self._worker_init_fn = worker_init_fn
self._always_report_worker_state = always_report_worker_state
_validate_no_nested_process_prefetch(self._parent)

def __str__(self) -> str:
@@ -170,7 +175,6 @@ def __iter__(self) -> dataset.DatasetIterator[T]:
self._parent,
self._buffer_size,
self._worker_init_fn,
self._always_report_worker_state,
)


@@ -185,7 +189,6 @@ def _put_dataset_elements_in_buffer(
stats_out_queue: queues.Queue[Any] | None,
start_profiling_event: synchronize.Event | None,
stop_profiling_event: synchronize.Event | None,
always_report_worker_state: bool,
debug_flags: dict[str, Any],
):
"""Prefetches elements in a separate process."""
@@ -200,7 +203,6 @@
min_shm_size = it._ctx.dataset_options.min_shm_size # pylint: disable=protected-access
# Set the stats queue in worker process to send stats to the main process.
it._stats._config.stats_out_queue = stats_out_queue # pylint: disable=protected-access
last_recorded_state_time = time.time()
parent_exhausted = False
while not should_stop.is_set():
if set_state_event.is_set():
@@ -217,7 +219,6 @@
# error. Wait until set_state_event or should_stop is set.
set_state_event.wait(_PARENT_EXHAUSTED_WAIT_S)
continue
now = time.time()
try:
element = it.__next__()
except Exception as e: # pylint: disable=broad-except
@@ -229,14 +230,7 @@
# __next__ method.
if not it._stats._config.is_prefetch: # pylint: disable=protected-access
it._stats.record_bytes_produced(element) # pylint: disable=protected-access
if (
always_report_worker_state
or now - last_recorded_state_time >= _RECORD_STATE_INTERVAL_S
):
last_recorded_state_time = now
buffer.put((element, it.get_state(), None))
else:
buffer.put((element, None, None))
buffer.put((element, it.get_state(), None))
except Exception as e: # pylint: disable=broad-except
buffer.put((None, None, e))

@@ -253,13 +247,11 @@
parent: dataset.IterDataset[T],
buffer_size: int,
worker_init_fn: Callable[[], None] | None = None,
always_report_worker_state: bool = False,
):
super().__init__()
self._iter_parent = parent
self._buffer_size = buffer_size
self._worker_init_fn = worker_init_fn
self._always_report_worker_state = always_report_worker_state
# Since the parent iterator is going to be created in each subprocess, and
# the options are propagated during iterator creation, we need to manually
# propagate them.
@@ -282,6 +274,7 @@ def __init__(
self._iterations_to_skip = 0
self._set_state_count = 0
self._exhausted = False
self._prefetch_ds_iter = None

# pytype: disable=attribute-error
# pylint: disable=protected-access
@@ -348,7 +341,6 @@ def start_prefetch(self) -> None:
stats_out_queue=self._stats_in_queue,
start_profiling_event=self._start_profiling_event,
stop_profiling_event=self._stop_profiling_event,
always_report_worker_state=self._always_report_worker_state,
debug_flags=dict(
grain_py_debug_mode=config.get_or_default("py_debug_mode"),
grain_py_dataset_visualization_output_dir=(
@@ -475,3 +467,73 @@ def set_state(self, state: StateT):

def __str__(self) -> str:
return f"ProcessPrefetchDatasetIterator(buffer_size={self._buffer_size})"


def multiprocess_prefetch(
ds: dataset.IterDataset[T],
num_workers: int = 0,
buffer_size: int = 1,
worker_init_fn: Callable[[int, int], None] | None = None,
sequential_slice: bool = False,
) -> dataset.IterDataset[T]:
"""Uses a multiple processes to prefetch elements ahead of time.

It works by sharding the input dataset into `num_workers` shards, and
interleaving them. Each shard is read by a separate process inside
`InterleaveIterDataset`.

Args:
ds: The parent dataset to prefetch from.
num_workers: The number of processes to use for prefetching. If 0,
prefetching is disabled and this is a no-op.
buffer_size: The size of the prefetch buffer for each process.
worker_init_fn: A function called in each worker process with the worker's
`(process_index, process_count)`.
sequential_slice: Whether to use sequential slicing.

Returns:
`IterDataset` that prefetches elements from `ds` using multiple processes.
"""
if num_workers == 0:
return ds

dataset_options = _get_dataset_options(ds)

shards = []
for i in range(num_workers):
if num_workers == 1:
worker_ds = ds
else:
worker_ds = copy.deepcopy(ds)
prefetch._set_slice_iter_dataset( # pylint: disable=protected-access
worker_ds, slice(i, None, num_workers), sequential_slice
)
worker_ds = prefetch._MpContextIterDataset( # pylint: disable=protected-access
worker_ds,
base.MultiprocessingContext(
process_index=i,
process_count=num_workers,
),
)
worker_index_suffix = "" if num_workers == 1 else f" {i}"

worker_init_fns = [
functools.partial(
grain_logging.set_process_identifier_prefix, worker_index_suffix
)
]
if worker_init_fn is not None:
worker_init_fns.append(functools.partial(worker_init_fn, i, num_workers))
worker_ds = ProcessPrefetchIterDataset(
worker_ds,
buffer_size=buffer_size,
worker_init_fn=functools.partial(_run_all, worker_init_fns),
)
shards.append(worker_ds)

ds = interleave.InterleaveIterDataset(
shards, cycle_length=num_workers, iter_buffer_size=buffer_size
)
# Apply options from parent dataset because interleave dataset does not
# propagate options.
ds = dataset.WithOptionsIterDataset(ds, dataset_options)
return ds
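A usage sketch of the new helper (the `multiprocess_prefetch` signature comes from this diff; the source pipeline and the `_announce` init function are illustrative):

from grain._src.python.dataset import dataset
from grain._src.python.dataset.transformations import process_prefetch


def _announce(process_index: int, process_count: int) -> None:
  # Runs once in each worker process, after grain logging is configured.
  print(f"worker {process_index} of {process_count} started")


ds = dataset.MapDataset.source(list(range(100))).to_iter_dataset()
# Shards `ds` into 4 strided slices (slice(i, None, 4)), reads each slice
# in its own process, and interleaves the results; every worker buffers up
# to 2 elements ahead of the consumer.
ds = process_prefetch.multiprocess_prefetch(
    ds,
    num_workers=4,
    buffer_size=2,
    worker_init_fn=_announce,
)

for element in ds:
  ...

With `num_workers=1` the input is neither copied nor sliced, and with `num_workers=0` the call returns `ds` unchanged.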