File tree Expand file tree Collapse file tree 4 files changed +39
-0
lines changed
grain/_src/python/dataset Expand file tree Collapse file tree 4 files changed +39
-0
lines changed Original file line number Diff line number Diff line change @@ -13,6 +13,8 @@ changes. Best viewed [here](https://google-grain.readthedocs.io/en/latest/change
1313 hiding first batch processing behind model checkpoint recovery.
1414 * Introduces ` grain.experimental.multithread_prefetch ` as an
1515 alternative to multiprocessing prefetch in free-threading Python.
16+ * Switches to multithreading instead of multiprocessing in
17+ ` IterDataset.mp_prefetch ` when free-threaded Python is detected.
1618
1719* Breaking changes:
1820
Original file line number Diff line number Diff line change 5757from grain ._src .core import monitoring as grain_monitoring
5858from grain ._src .core import transforms
5959from grain ._src .core import usage_logging
60+ import multiprocessing as mp
6061from grain ._src .python import checkpointing
6162from grain ._src .python import options as grain_options
6263from grain ._src .python .dataset import base
@@ -1312,6 +1313,18 @@ def mp_prefetch(
13121313 # pylint: disable=g-import-not-at-top
13131314 from grain ._src .python .dataset .transformations import prefetch
13141315 # pylint: enable=g-import-not-at-top
1316+ if mp .is_in_free_threaded_python ():
1317+ if worker_init_fn is not None :
1318+ warnings .warn (
1319+ "Free-threaded Python is used: `mp_prefetch` falls back to"
1320+ " thread-based implementation and `worker_init_fn` is ignored."
1321+ )
1322+ return prefetch .multithread_prefetch (
1323+ self ,
1324+ num_threads = options .num_workers ,
1325+ buffer_size = options .per_worker_buffer_size ,
1326+ sequential_slice = sequential_slice ,
1327+ )
13151328 return prefetch .MultiprocessPrefetchIterDataset (
13161329 self ,
13171330 multiprocessing_options = options ,
Original file line number Diff line number Diff line change 3131from grain ._src .python .dataset import base
3232from grain ._src .python .dataset import dataset
3333from grain ._src .python .dataset import stats as dataset_stats
34+ from grain ._src .python .dataset .transformations import prefetch
3435import grain ._src .python .testing .experimental as test_util
3536from grain .proto import execution_summary_pb2
3637import numpy as np
@@ -1182,6 +1183,21 @@ def test_apply(self, ds):
11821183 ],
11831184 )
11841185
1186+ def test_mp_prefetch_switches_to_threads_for_free_threaded_python (self ):
1187+ ds = dataset .MapDataset .range (15 ).to_iter_dataset ()
1188+ prefetched_ds = ds .mp_prefetch ()
1189+ is_free_threaded = (
1190+ hasattr (sys , "_is_gil_enabled" ) and not sys ._is_gil_enabled ()
1191+ )
1192+ if is_free_threaded :
1193+ self .assertNotIsInstance (
1194+ prefetched_ds , prefetch .MultiprocessPrefetchIterDataset
1195+ )
1196+ else :
1197+ self .assertIsInstance (
1198+ prefetched_ds , prefetch .MultiprocessPrefetchIterDataset
1199+ )
1200+
11851201
11861202class TfRandomMapAlwaysAddingOne (transforms .TfRandomMapTransform ):
11871203
Original file line number Diff line number Diff line change @@ -679,6 +679,13 @@ def map(self, x):
679679 )
680680 iter_ds = iter_ds .mp_prefetch ()
681681
682+ print (
683+ 'nithintati:'
684+ ' test_reports_first_unpicklable_dataset_when_with_multiple_parents'
685+ )
686+ print ('nithintati: ' , iter_ds )
687+ print ('nithintati: ' , mp .is_in_free_threaded_python ())
688+ print ('nithintati: ' , sys .version_info )
682689 with self .assertRaisesRegex (
683690 ValueError ,
684691 r'UnpicklableObject is not picklable' ,
@@ -691,6 +698,7 @@ def map(self, x):
691698 r'Dataset: MapMapDataset\(transform=LeftTransform\) cannot be'
692699 r' pickled!' ,
693700 )
701+ raise ValueError ('nithintati: ' + str (sys .version_info ))
694702
695703 def test_reports_unpicklable_issue_when_only_one_parent_unpicklable (self ):
696704 class UnpicklableObject :
You can’t perform that action at this time.
0 commit comments