Skip to content

Commit b8a7bbe

Browse files
Nithin Tatikonda authored and copybara-github committed
Internal
PiperOrigin-RevId: 838846430
1 parent 4ab3eff commit b8a7bbe

File tree

2 files changed

+323
-54
lines changed

2 files changed

+323
-54
lines changed

grain/_src/python/dataset/transformations/prefetch.py

Lines changed: 14 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -692,12 +692,17 @@ def multithread_prefetch(
692692
if num_threads == 0:
693693
return ds
694694

695+
dataset_options = _get_dataset_options(ds)
696+
695697
shards = []
696698
for i in range(num_threads):
697-
worker_ds = copy.deepcopy(ds)
698-
_set_slice_iter_dataset(
699-
worker_ds, slice(i, None, num_threads), sequential_slice
700-
)
699+
if num_threads == 1:
700+
worker_ds = ds
701+
else:
702+
worker_ds = copy.deepcopy(ds)
703+
_set_slice_iter_dataset(
704+
worker_ds, slice(i, None, num_threads), sequential_slice
705+
)
701706
shards.append(
702707
_MpContextIterDataset(
703708
worker_ds,
@@ -708,9 +713,13 @@ def multithread_prefetch(
708713
)
709714
)
710715

711-
return interleave.InterleaveIterDataset(
716+
ds = interleave.InterleaveIterDataset(
712717
shards, cycle_length=num_threads, iter_buffer_size=buffer_size
713718
)
719+
# Apply options from parent dataset because interleave dataset does not
720+
# propagate options.
721+
ds = dataset.WithOptionsIterDataset(ds, dataset_options)
722+
return ds
714723

715724

716725
def is_prefetch_iterator(it: dataset.DatasetIterator) -> bool:

0 commit comments

Comments
 (0)