@@ -15,23 +15,25 @@
 
 from __future__ import annotations
 
-from collections.abc import Callable
+from collections.abc import Callable, Sequence
+import copy
 import functools
 from multiprocessing import queues
 from multiprocessing import synchronize
 import queue
-import time
 from typing import Any, TypeVar
 
 from absl import flags
 import cloudpickle
 from grain._src.core import monitoring as grain_monitoring
 from grain._src.core.config import config
 import multiprocessing as mp
+from grain._src.python import grain_logging
 from grain._src.python import shared_memory_array
 from grain._src.python.dataset import base
 from grain._src.python.dataset import dataset
 from grain._src.python.dataset import stats as dataset_stats
+from grain._src.python.dataset.transformations import interleave
 from grain._src.python.dataset.transformations import prefetch
 
 T = TypeVar("T")
@@ -61,6 +63,11 @@
 _QUEUE_WAIT_TIMEOUT_S = 1
 
 
+def _run_all(fns: Sequence[Callable[[], None]]):
+  for fn in fns:
+    fn()
+
+
 def _parse_debug_flags(debug_flags: dict[str, Any]):
   """Parses debug flags."""
   flags.FLAGS["grain_py_debug_mode"].present = True
@@ -150,7 +157,6 @@ def __init__(
       parent: dataset.IterDataset[T],
       buffer_size: int,
       worker_init_fn: Callable[[], None] | None = None,
-      always_report_worker_state: bool = False,
   ):
     if buffer_size <= 0:
       raise ValueError(
@@ -159,7 +165,6 @@ def __init__(
     super().__init__(parent)
     self._buffer_size = buffer_size
     self._worker_init_fn = worker_init_fn
-    self._always_report_worker_state = always_report_worker_state
     _validate_no_nested_process_prefetch(self._parent)
 
   def __str__(self) -> str:
@@ -170,7 +175,6 @@ def __iter__(self) -> dataset.DatasetIterator[T]:
         self._parent,
         self._buffer_size,
         self._worker_init_fn,
-        self._always_report_worker_state,
     )
 
 
@@ -185,7 +189,6 @@ def _put_dataset_elements_in_buffer(
     stats_out_queue: queues.Queue[Any] | None,
     start_profiling_event: synchronize.Event | None,
     stop_profiling_event: synchronize.Event | None,
-    always_report_worker_state: bool,
     debug_flags: dict[str, Any],
 ):
   """Prefetches elements in a separate process."""
@@ -200,7 +203,6 @@ def _put_dataset_elements_in_buffer(
   min_shm_size = it._ctx.dataset_options.min_shm_size  # pylint: disable=protected-access
   # Set the stats queue in worker process to send stats to the main process.
   it._stats._config.stats_out_queue = stats_out_queue  # pylint: disable=protected-access
-  last_recorded_state_time = time.time()
   parent_exhausted = False
   while not should_stop.is_set():
     if set_state_event.is_set():
@@ -217,7 +219,6 @@ def _put_dataset_elements_in_buffer(
       # error. Wait until set_state_event or should_stop is set.
       set_state_event.wait(_PARENT_EXHAUSTED_WAIT_S)
       continue
-    now = time.time()
     try:
       element = it.__next__()
     except Exception as e:  # pylint: disable=broad-except
@@ -229,14 +230,7 @@ def _put_dataset_elements_in_buffer(
       # __next__ method.
       if not it._stats._config.is_prefetch:  # pylint: disable=protected-access
         it._stats.record_bytes_produced(element)  # pylint: disable=protected-access
-      if (
-          always_report_worker_state
-          or now - last_recorded_state_time >= _RECORD_STATE_INTERVAL_S
-      ):
-        last_recorded_state_time = now
-        buffer.put((element, it.get_state(), None))
-      else:
-        buffer.put((element, None, None))
+      buffer.put((element, it.get_state(), None))
     except Exception as e:  # pylint: disable=broad-except
       buffer.put((None, None, e))
 
@@ -253,13 +247,11 @@ def __init__(
       parent: dataset.IterDataset[T],
       buffer_size: int,
       worker_init_fn: Callable[[], None] | None = None,
-      always_report_worker_state: bool = False,
   ):
     super().__init__()
     self._iter_parent = parent
     self._buffer_size = buffer_size
     self._worker_init_fn = worker_init_fn
-    self._always_report_worker_state = always_report_worker_state
     # Since the parent iterator is going to be created in each subprocess, and
     # the options are propagated during iterator creation, we need to manually
     # propagate them.
@@ -282,6 +274,7 @@ def __init__(
     self._iterations_to_skip = 0
     self._set_state_count = 0
     self._exhausted = False
+    self._prefetch_ds_iter = None
 
     # pytype: disable=attribute-error
     # pylint: disable=protected-access
@@ -348,7 +341,6 @@ def start_prefetch(self) -> None:
           stats_out_queue=self._stats_in_queue,
           start_profiling_event=self._start_profiling_event,
           stop_profiling_event=self._stop_profiling_event,
-          always_report_worker_state=self._always_report_worker_state,
          debug_flags=dict(
              grain_py_debug_mode=config.get_or_default("py_debug_mode"),
              grain_py_dataset_visualization_output_dir=(
@@ -475,3 +467,73 @@ def set_state(self, state: StateT):
 
   def __str__(self) -> str:
     return f"ProcessPrefetchDatasetIterator(buffer_size={self._buffer_size})"
+
+
+def multiprocess_prefetch(
+    ds: dataset.IterDataset[T],
+    num_workers: int = 0,
+    buffer_size: int = 1,
+    worker_init_fn: Callable[[int, int], None] | None = None,
+    sequential_slice: bool = False,
+) -> dataset.IterDataset[T]:
+  """Uses multiple processes to prefetch elements ahead of time.
+
+  It works by sharding the input dataset into `num_workers` shards and
+  interleaving them. Each shard is read by a separate process inside
+  `InterleaveIterDataset`.
+
+  Args:
+    ds: The parent dataset to prefetch from.
+    num_workers: The number of processes to use for prefetching. If 0,
+      prefetching is disabled and this is a no-op.
+    buffer_size: The size of the prefetch buffer for each process.
+    worker_init_fn: A function that is called in each worker process with the
+      process index and process count.
+    sequential_slice: Whether to use sequential slicing.
+
+  Returns:
+    `IterDataset` that prefetches elements from `ds` using multiple processes.
+  """
+  if num_workers == 0:
+    return ds
+
+  dataset_options = _get_dataset_options(ds)
+
+  shards = []
+  for i in range(num_workers):
+    if num_workers == 1:
+      worker_ds = ds
+    else:
+      worker_ds = copy.deepcopy(ds)
+      prefetch._set_slice_iter_dataset(  # pylint: disable=protected-access
+          worker_ds, slice(i, None, num_workers), sequential_slice
+      )
+    worker_ds = prefetch._MpContextIterDataset(  # pylint: disable=protected-access
+        worker_ds,
+        base.MultiprocessingContext(
+            process_index=i,
+            process_count=num_workers,
+        ),
+    )
+    worker_index_suffix = "" if num_workers == 1 else f" {i}"
+
+    worker_init_fns = [
+        functools.partial(
+            grain_logging.set_process_identifier_prefix, worker_index_suffix
+        )
+    ]
+    if worker_init_fn is not None:
+      worker_init_fns.append(functools.partial(worker_init_fn, i, num_workers))
+    worker_ds = ProcessPrefetchIterDataset(
+        worker_ds,
+        buffer_size=buffer_size,
+        worker_init_fn=functools.partial(_run_all, worker_init_fns),
+    )
+    shards.append(worker_ds)
+
+  ds = interleave.InterleaveIterDataset(
+      shards, cycle_length=num_workers, iter_buffer_size=buffer_size
+  )
+  # Apply options from the parent dataset because the interleave dataset does
+  # not propagate options.
+  ds = dataset.WithOptionsIterDataset(ds, dataset_options)
+  return ds
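
For context, a minimal usage sketch of the new `multiprocess_prefetch` helper, based only on the signature and docstring added above. The source dataset `ds`, the `_init` callback, and the `prefetched` name are illustrative placeholders, not part of this change; `ds` stands for any existing `dataset.IterDataset[T]`.

    # Hypothetical init callback: multiprocess_prefetch binds the worker's
    # process index and the total process count via functools.partial.
    def _init(process_index: int, process_count: int) -> None:
      print(f"worker {process_index} of {process_count} starting")

    prefetched = multiprocess_prefetch(
        ds,                    # any dataset.IterDataset[T]
        num_workers=4,         # ds is deep-copied and sliced into 4 shards
        buffer_size=2,         # each worker process keeps 2 elements buffered
        worker_init_fn=_init,  # called in each worker as _init(i, 4)
    )
    for element in prefetched:
      ...  # elements are interleaved across the 4 worker processes

Note that with num_workers=1 the dataset is used directly without copying or slicing, and with num_workers=0 the call returns `ds` unchanged, so the same code path can disable prefetching via configuration.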