Skip to content

Commit d47a7f0

Browse files
Nithin Tatikondacopybara-github
authored andcommitted
mp_prefetch uses multithreading when the GIL is disabled.
PiperOrigin-RevId: 868188468
1 parent a4f73e0 commit d47a7f0

File tree

3 files changed

+44
-1
lines changed

3 files changed

+44
-1
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ changes. Best viewed [here](https://google-grain.readthedocs.io/en/latest/change
1010
errors.
1111
* Adds experimental support for `get_next_index` and `set_next_index` to fetch
1212
and advance a `grain.DatasetIterator` to the given produced element index.
13+
* Switches to multithreading instead of multiprocessing in
14+
`IterDataset.mp_prefetch` when free-threaded Python is detected.
1315

1416
* Breaking changes:
1517
* Custom implementations of `RandomAccessDataSource` should accept `int`

grain/_src/python/dataset/dataset.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
from collections.abc import Awaitable, Callable, Iterable, Iterator, Mapping, Sequence
5151
import functools
5252
import json
53+
import sys
5354
from typing import Any, Generic, TypeVar, Union, cast, overload
5455
import warnings
5556

@@ -1334,6 +1335,9 @@ def mp_prefetch(
13341335
multiprocessing resources. We will by default run the cleanup on garbage
13351336
collection, but GC and its sequence is not guaranteed in CPython.
13361337
1338+
NOTE: In free-threaded Python builds, this implementation switches to
1339+
multithreading, ignoring ``worker_init_fn``.
1340+
13371341
Args:
13381342
options: options for the prefetching processes. ``options.num_workers``
13391343
must be greater than or equal to 0. If ``options.num_workers`` is 0,
@@ -1353,10 +1357,24 @@ def mp_prefetch(
13531357
"""
13541358

13551359
options = options or grain_options.MultiprocessingOptions(num_workers=10)
1356-
# Loaded lazily due to a circular dependency (dataset <-> process_prefetch).
1360+
# Loaded lazily due to a circular dependency (dataset <-> process_prefetch)
1361+
# and (dataset <-> prefetch).
13571362
# pylint: disable=g-import-not-at-top
1363+
from grain._src.python.dataset.transformations import prefetch
13581364
from grain._src.python.dataset.transformations import process_prefetch
13591365
# pylint: enable=g-import-not-at-top
1366+
if is_in_free_threaded_python():
1367+
if worker_init_fn is not None:
1368+
warnings.warn(
1369+
"Free-threaded Python is used: `mp_prefetch` falls back to"
1370+
" thread-based implementation and `worker_init_fn` is ignored."
1371+
)
1372+
return prefetch.multithread_prefetch(
1373+
self,
1374+
num_threads=options.num_workers,
1375+
buffer_size=options.per_worker_buffer_size,
1376+
sequential_slice=sequential_slice,
1377+
)
13601378
return process_prefetch.multiprocess_prefetch(
13611379
self,
13621380
num_workers=options.num_workers,
@@ -1835,3 +1853,8 @@ def set_next_index(ds_iter: DatasetIterator, index: int) -> None:
18351853
def get_next_index(ds_iter: DatasetIterator) -> int:
18361854
"""Returns the next index for the dataset iterator."""
18371855
return ds_iter._get_next_index() # pylint: disable=protected-access
1856+
1857+
1858+
def is_in_free_threaded_python() -> bool:
1859+
"""Returns whether Python is running in free-threaded mode."""
1860+
return hasattr(sys, "_is_gil_enabled") and not sys._is_gil_enabled() # pylint: disable=protected-access

grain/_src/python/dataset/dataset_test.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from grain._src.python.dataset import dataset
3333
from grain._src.python.dataset import stats as dataset_stats
3434
from grain._src.python.dataset.transformations import prefetch
35+
from grain._src.python.dataset.transformations import process_prefetch
3536
import grain._src.python.testing.experimental as test_util
3637
from grain.proto import execution_summary_pb2
3738
import numpy as np
@@ -1183,6 +1184,23 @@ def test_apply(self, ds):
11831184
],
11841185
)
11851186

1187+
@mock.patch.object(process_prefetch, "multiprocess_prefetch", autospec=True)
1188+
@mock.patch.object(prefetch, "multithread_prefetch", autospec=True)
1189+
def test_mp_prefetch_switches_to_threads_for_free_threaded_python(
1190+
self, mock_multithread_prefetch, mock_multiprocess_prefetch
1191+
):
1192+
ds = dataset.MapDataset.range(15).to_iter_dataset()
1193+
ds.mp_prefetch()
1194+
is_free_threaded = (
1195+
hasattr(sys, "_is_gil_enabled") and not sys._is_gil_enabled()
1196+
)
1197+
if is_free_threaded:
1198+
mock_multithread_prefetch.assert_called_once()
1199+
mock_multiprocess_prefetch.assert_not_called()
1200+
else:
1201+
mock_multiprocess_prefetch.assert_called_once()
1202+
mock_multithread_prefetch.assert_not_called()
1203+
11861204

11871205
class TfRandomMapAlwaysAddingOne(transforms.TfRandomMap):
11881206

0 commit comments

Comments
 (0)