5050from collections .abc import Awaitable , Callable , Iterable , Iterator , Mapping , Sequence
5151import functools
5252import json
53+ import sys
5354from typing import Any , Generic , TypeVar , Union , cast , overload
5455import warnings
5556
@@ -1334,6 +1335,9 @@ def mp_prefetch(
13341335 multiprocessing resources. We will by default run the cleanup on garbage
13351336 collection, but GC and its sequence is not guaranteed in CPython.
13361337
1338+ NOTE: In free-threaded Python builds, this implementation switches to
1339+ multithreading, ignoring ``worker_init_fn``.
1340+
13371341 Args:
13381342 options: options for the prefetching processes. ``options.num_workers``
13391343 must be greater than or equal to 0. If ``options.num_workers`` is 0,
@@ -1353,10 +1357,24 @@ def mp_prefetch(
13531357 """
13541358
13551359 options = options or grain_options .MultiprocessingOptions (num_workers = 10 )
1356- # Loaded lazily due to a circular dependency (dataset <-> process_prefetch).
1360+ # Loaded lazily due to a circular dependency (dataset <-> process_prefetch)
1361+ # and (dataset <-> prefetch).
13571362 # pylint: disable=g-import-not-at-top
1363+ from grain ._src .python .dataset .transformations import prefetch
13581364 from grain ._src .python .dataset .transformations import process_prefetch
13591365 # pylint: enable=g-import-not-at-top
1366+ if is_in_free_threaded_python ():
1367+ if worker_init_fn is not None :
1368+ warnings .warn (
1369+ "Free-threaded Python is used: `mp_prefetch` falls back to"
1370+ " thread-based implementation and `worker_init_fn` is ignored."
1371+ )
1372+ return prefetch .multithread_prefetch (
1373+ self ,
1374+ num_threads = options .num_workers ,
1375+ buffer_size = options .per_worker_buffer_size ,
1376+ sequential_slice = sequential_slice ,
1377+ )
13601378 return process_prefetch .multiprocess_prefetch (
13611379 self ,
13621380 num_workers = options .num_workers ,
@@ -1835,3 +1853,8 @@ def set_next_index(ds_iter: DatasetIterator, index: int) -> None:
def get_next_index(ds_iter: DatasetIterator) -> int:
  """Returns the next index for the given dataset iterator.

  Module-level accessor that delegates to the iterator's non-public
  `_get_next_index` method.

  Args:
    ds_iter: the dataset iterator to query.

  Returns:
    The value reported by `ds_iter._get_next_index()`.
  """
  next_index = ds_iter._get_next_index()  # pylint: disable=protected-access
  return next_index
1856+
1857+
def is_in_free_threaded_python() -> bool:
  """Returns whether Python is running in free-threaded mode.

  The answer is `True` only when the interpreter exposes
  `sys._is_gil_enabled` and that function reports the GIL as disabled.
  Interpreters that do not provide the function are treated as running with
  the GIL, so the result is `False`.
  """
  # Probe by name: the attribute is absent on interpreters without this API.
  gil_probe = getattr(sys, "_is_gil_enabled", None)
  if gil_probe is None:
    return False
  return not gil_probe()
0 commit comments