Skip to content

Commit 6675bc7

Browse files
Grain Teamcopybara-github
authored andcommitted
mp_prefetch uses multithreading when the GIL is disabled.
PiperOrigin-RevId: 833406009
1 parent 20b4532 commit 6675bc7

File tree

4 files changed

+39
-0
lines changed

4 files changed

+39
-0
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ changes. Best viewed [here](https://google-grain.readthedocs.io/en/latest/change
1313
hiding first batch processing behind model checkpoint recovery.
1414
* Introduces `grain.experimental.multithread_prefetch` as an
1515
alternative to multiprocessing prefetch in free-threading Python.
16+
* Switches to multithreading instead of multiprocessing in
17+
`IterDataset.mp_prefetch` when free-threaded Python is detected.
1618

1719
* Breaking changes:
1820

grain/_src/python/dataset/dataset.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
from grain._src.core import monitoring as grain_monitoring
5858
from grain._src.core import transforms
5959
from grain._src.core import usage_logging
60+
import multiprocessing as mp
6061
from grain._src.python import checkpointing
6162
from grain._src.python import options as grain_options
6263
from grain._src.python.dataset import base
@@ -1312,6 +1313,18 @@ def mp_prefetch(
13121313
# pylint: disable=g-import-not-at-top
13131314
from grain._src.python.dataset.transformations import prefetch
13141315
# pylint: enable=g-import-not-at-top
1316+
if mp.is_in_free_threaded_python():
1317+
if worker_init_fn is not None:
1318+
warnings.warn(
1319+
"Free-threaded Python is used: `mp_prefetch` falls back to"
1320+
" thread-based implementation and `worker_init_fn` is ignored."
1321+
)
1322+
return prefetch.multithread_prefetch(
1323+
self,
1324+
num_threads=options.num_workers,
1325+
buffer_size=options.per_worker_buffer_size,
1326+
sequential_slice=sequential_slice,
1327+
)
13151328
return prefetch.MultiprocessPrefetchIterDataset(
13161329
self,
13171330
multiprocessing_options=options,

grain/_src/python/dataset/dataset_test.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from grain._src.python.dataset import base
3232
from grain._src.python.dataset import dataset
3333
from grain._src.python.dataset import stats as dataset_stats
34+
from grain._src.python.dataset.transformations import prefetch
3435
import grain._src.python.testing.experimental as test_util
3536
from grain.proto import execution_summary_pb2
3637
import numpy as np
@@ -1182,6 +1183,21 @@ def test_apply(self, ds):
11821183
],
11831184
)
11841185

1186+
def test_mp_prefetch_switches_to_threads_for_free_threaded_python(self):
1187+
ds = dataset.MapDataset.range(15).to_iter_dataset()
1188+
prefetched_ds = ds.mp_prefetch()
1189+
is_free_threaded = (
1190+
hasattr(sys, "_is_gil_enabled") and not sys._is_gil_enabled()
1191+
)
1192+
if is_free_threaded:
1193+
self.assertNotIsInstance(
1194+
prefetched_ds, prefetch.MultiprocessPrefetchIterDataset
1195+
)
1196+
else:
1197+
self.assertIsInstance(
1198+
prefetched_ds, prefetch.MultiprocessPrefetchIterDataset
1199+
)
1200+
11851201

11861202
class TfRandomMapAlwaysAddingOne(transforms.TfRandomMapTransform):
11871203

grain/_src/python/dataset/transformations/prefetch_test.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -679,6 +679,13 @@ def map(self, x):
679679
)
680680
iter_ds = iter_ds.mp_prefetch()
681681

682+
print(
683+
'nithintati:'
684+
' test_reports_first_unpicklable_dataset_when_with_multiple_parents'
685+
)
686+
print('nithintati: ', iter_ds)
687+
print('nithintati: ', mp.is_in_free_threaded_python())
688+
print('nithintati: ', sys.version_info)
682689
with self.assertRaisesRegex(
683690
ValueError,
684691
r'UnpicklableObject is not picklable',
@@ -691,6 +698,7 @@ def map(self, x):
691698
r'Dataset: MapMapDataset\(transform=LeftTransform\) cannot be'
692699
r' pickled!',
693700
)
701+
raise ValueError('nithintati: ' + str(sys.version_info))
694702

695703
def test_reports_unpicklable_issue_when_only_one_parent_unpicklable(self):
696704
class UnpicklableObject:

0 commit comments

Comments
 (0)