Skip to content

Commit ef644bb

Browse files
Internal
PiperOrigin-RevId: 866025320
1 parent 0d302a4 commit ef644bb

File tree

4 files changed

+52
-3
lines changed

4 files changed

+52
-3
lines changed

grain/_src/python/dataset/transformations/interleave.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ def set_state(self, state):
218218
iterator = _add_prefetch_and_make_iterator(
219219
self._datasets[index_in_datasets],
220220
interleave_iterator=weakref.ref(self),
221-
start_prefetch=False,
221+
start_prefetch=self._started,
222222
)
223223
iterator.set_state(it_state)
224224
self._iterators_in_use[index_in_cycle] = iterator
@@ -266,6 +266,13 @@ def set_keep_iterators_after_stop_iteration(
266266
# continuing iteration without recreating the iterators.
267267
self._keep_iterators_after_stop_iteration = keep_iterators
268268

269+
def start_prefetch(self) -> None:
  """Starts prefetching on the underlying iterators.

  Calls `start_prefetch()` on `self._prefetch_ds_iter` and then on every
  non-`None` entry of `self._iterators_in_use`. Finally sets
  `self._started`, which `set_state` passes as `start_prefetch=` when it
  recreates iterators.
  """
  self._prefetch_ds_iter.start_prefetch()
  for iterator in self._iterators_in_use:
    # Slots without an iterator are stored as `None`; skip them.
    if iterator is not None:
      iterator.start_prefetch()
  # Set only after the calls above have returned, matching the original order.
  self._started = True
275+
269276
def close(self) -> None:
270277
"""Closes the iterator and shuts down the iterator prefetching."""
271278
if self._closed:
@@ -275,6 +282,9 @@ def close(self) -> None:
275282
for iterator in self._iterators_in_use:
276283
if iterator is not None:
277284
iterator.close()
285+
for index_iterator_pair in self._exhausted_iterators:
286+
if index_iterator_pair is not None:
287+
index_iterator_pair[1].close()
278288

279289
def _initialize_stats(
280290
self, execution_tracking_mode: base.ExecutionTrackingMode

grain/_src/python/dataset/transformations/interleave_test.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import time
16+
1517
from absl.testing import absltest
1618
from absl.testing import flagsaver
1719
from absl.testing import parameterized
@@ -291,6 +293,22 @@ def test_set_next_index_with_multiple_datasets(self):
291293
):
292294
dataset.set_next_index(ds_iter, 0)
293295

296+
def test_start_prefetch(self):
  """Checks that `start_prefetch()` triggers processing without `next()`.

  The map function counts its invocations; the test passes once at least one
  element has been mapped even though the iterator was never advanced.
  """
  count = 0

  def map_fn(x):
    nonlocal count
    count += 1
    return x

  ds = dataset.MapDataset.range(10).to_iter_dataset()
  ds = ds.map(map_fn)
  ds = interleave.InterleaveIterDataset([ds], cycle_length=1)
  ds_iter = ds.__iter__()
  ds_iter.start_prefetch()
  # Poll with a deadline so that a regression fails the test with a clear
  # assertion instead of spinning until the test runner's global timeout.
  deadline = time.monotonic() + 30
  while count == 0 and time.monotonic() < deadline:
    time.sleep(0.1)
  self.assertGreater(count, 0)
311+
294312

295313
if __name__ == "__main__":
296314
absltest.main()

grain/_src/python/dataset/transformations/prefetch_test.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -980,7 +980,6 @@ def map(self, features):
980980
@parameterized.parameters(0, 0.5, 30)
981981
def test_prefetch_but_no_read(self, sleep_s):
982982
ds = dataset.MapDataset.source([1, 2, 3]).repeat()
983-
ds = ds.filter(lambda x: x > 3)
984983
ds = ds.to_iter_dataset()
985984
ds = prefetch.multithread_prefetch(ds, num_threads=1, buffer_size=1)
986985
it = ds.__iter__()

grain/_src/python/dataset/transformations/process_prefetch_test.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,16 @@ def filter(self, element: int) -> bool:
4848
return bool(element % 2)
4949

5050

51+
@dataclasses.dataclass(frozen=True)
class WriteMarker(transforms.Map):
  """Map transform that writes each element to a file and returns it unchanged.

  The file at `path` is opened in write mode on every call, so it always holds
  the string form of the most recently mapped element.
  """

  # Location of the file that `map` (re)writes on each call.
  path: str

  def map(self, element: int) -> int:
    """Writes `str(element)` to `self.path` and returns `element`."""
    with open(self.path, 'w') as f:
      f.write(str(element))
    return element
59+
60+
5161
class ProcessPrefetchIterDatasetTest(parameterized.TestCase):
5262

5363
def setUp(self):
@@ -851,10 +861,22 @@ def map(self, features):
851861
if not start_prefetch_calls:
852862
self.assertGreater(time_to_fetch, 1)
853863

864+
def test_start_prefetch_prefetches_without_next_call(self):
  """Checks that `start_prefetch()` makes the worker process elements.

  `WriteMarker` creates the marker file when it maps an element, so the file
  appearing without any `next()` call shows that prefetching ran.
  """
  marker_file = os.path.join(self.create_tempdir().full_path, 'marker')
  ds = dataset.MapDataset.range(10)
  ds = ds.map(WriteMarker(marker_file))
  ds = ds.to_iter_dataset()
  ds = process_prefetch.multiprocess_prefetch(ds, num_workers=1)
  it = ds.__iter__()
  it.start_prefetch()

  # Wait for prefetch to happen, but give up after a deadline so a regression
  # fails the test instead of hanging until the runner's global timeout.
  deadline = time.monotonic() + 60
  while not os.path.exists(marker_file) and time.monotonic() < deadline:
    time.sleep(0.5)
  self.assertTrue(os.path.exists(marker_file))
876+
854877
@parameterized.parameters(0, 0.5, 30)
855878
def test_prefetch_but_no_read(self, sleep_s):
856879
ds = dataset.MapDataset.source([1, 2, 3]).repeat()
857-
ds = ds.filter(lambda x: x > 3)
858880
ds = ds.to_iter_dataset()
859881
ds = process_prefetch.multiprocess_prefetch(ds, num_workers=1)
860882
it = ds.__iter__()

0 commit comments

Comments
 (0)