Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ changes. Best viewed [here](https://google-grain.readthedocs.io/en/latest/change
errors.
* Adds experimental support for `get_next_index` and `set_next_index` to fetch
and advance a `grain.DatasetIterator` to the given produced element index.
* Switches to multithreading instead of multiprocessing in
`IterDataset.mp_prefetch` when free-threaded Python is detected.

* Breaking changes:
* Custom implementations of `RandomAccessDataSource` should accept `int`
Expand Down
25 changes: 24 additions & 1 deletion grain/_src/python/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from collections.abc import Awaitable, Callable, Iterable, Iterator, Mapping, Sequence
import functools
import json
import sys
from typing import Any, Generic, TypeVar, Union, cast, overload
import warnings

Expand Down Expand Up @@ -1334,6 +1335,9 @@ def mp_prefetch(
multiprocessing resources. We will by default run the cleanup on garbage
collection, but GC and its sequence is not guaranteed in CPython.
NOTE: In free-threaded Python builds, this implementation switches to
multithreading, ignoring ``worker_init_fn``.
Args:
options: options for the prefetching processes. ``options.num_workers``
must be greater than or equal to 0. If ``options.num_workers`` is 0,
Expand All @@ -1353,10 +1357,24 @@ def mp_prefetch(
"""

options = options or grain_options.MultiprocessingOptions(num_workers=10)
# Loaded lazily due to a circular dependency (dataset <-> process_prefetch).
# Loaded lazily due to a circular dependency (dataset <-> process_prefetch)
# and (dataset <-> prefetch).
# pylint: disable=g-import-not-at-top
from grain._src.python.dataset.transformations import prefetch
from grain._src.python.dataset.transformations import process_prefetch
# pylint: enable=g-import-not-at-top
if is_in_free_threaded_python():
if worker_init_fn is not None:
warnings.warn(
"Free-threaded Python is used: `mp_prefetch` falls back to"
" thread-based implementation and `worker_init_fn` is ignored."
)
return prefetch.multithread_prefetch(
self,
num_threads=options.num_workers,
buffer_size=options.per_worker_buffer_size,
sequential_slice=sequential_slice,
)
return process_prefetch.multiprocess_prefetch(
self,
num_workers=options.num_workers,
Expand Down Expand Up @@ -1835,3 +1853,8 @@ def set_next_index(ds_iter: DatasetIterator, index: int) -> None:
def get_next_index(ds_iter: DatasetIterator) -> int:
"""Returns the next index for the dataset iterator."""
return ds_iter._get_next_index() # pylint: disable=protected-access


def is_in_free_threaded_python() -> bool:
"""Returns whether Python is running in free-threaded mode."""
return hasattr(sys, "_is_gil_enabled") and not sys._is_gil_enabled() # pylint: disable=protected-access
18 changes: 18 additions & 0 deletions grain/_src/python/dataset/dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from grain._src.python.dataset import dataset
from grain._src.python.dataset import stats as dataset_stats
from grain._src.python.dataset.transformations import prefetch
from grain._src.python.dataset.transformations import process_prefetch
import grain._src.python.testing.experimental as test_util
from grain.proto import execution_summary_pb2
import numpy as np
Expand Down Expand Up @@ -1183,6 +1184,23 @@ def test_apply(self, ds):
],
)

@mock.patch.object(process_prefetch, "multiprocess_prefetch", autospec=True)
@mock.patch.object(prefetch, "multithread_prefetch", autospec=True)
def test_mp_prefetch_switches_to_threads_for_free_threaded_python(
self, mock_multithread_prefetch, mock_multiprocess_prefetch
):
ds = dataset.MapDataset.range(15).to_iter_dataset()
ds.mp_prefetch()
is_free_threaded = (
hasattr(sys, "_is_gil_enabled") and not sys._is_gil_enabled()
)
if is_free_threaded:
mock_multithread_prefetch.assert_called_once()
mock_multiprocess_prefetch.assert_not_called()
else:
mock_multiprocess_prefetch.assert_called_once()
mock_multithread_prefetch.assert_not_called()


class TfRandomMapAlwaysAddingOne(transforms.TfRandomMap):

Expand Down
Loading