Skip to content

Commit 7258616

Browse files
Grain Teamcopybara-github
authored andcommitted
Introduces multiprocess_prefetch as a simpler, more efficient version of the existing MultiprocessPrefetchIterDataset that supports setting state without restarting workers.
PiperOrigin-RevId: 832523331
1 parent 11bd30a commit 7258616

File tree

7 files changed

+1552
-0
lines changed

7 files changed

+1552
-0
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@ changes. Best viewed [here](https://google-grain.readthedocs.io/en/latest/change
1313
hiding first batch processing behind model checkpoint recovery.
1414
* Introduces `grain.experimental.multithread_prefetch` as an
1515
alternative to multiprocessing prefetch in free-threading Python.
16+
* Introduces `grain.experimental.multiprocess_prefetch` as an simpler version
17+
of `IterDataset.mp_prefetch` that supports setting state without restarting
18+
workers.
1619

1720
* Breaking changes:
1821

grain/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ py_library(
6464
"//grain/_src/python/dataset/transformations:cache",
6565
"//grain/_src/python/dataset/transformations:interleave",
6666
"//grain/_src/python/dataset/transformations:prefetch_autotune",
67+
"//grain/_src/python/dataset/transformations:process_prefetch",
6768
"//grain/_src/python/dataset/transformations:limit",
6869
"//grain/_src/python/dataset/transformations:packing",
6970
"//grain/_src/python/dataset/transformations:packing_concat_then_split",

grain/_src/python/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,7 @@ py_library(
338338
srcs = ["shared_memory_array.py"],
339339
srcs_version = "PY3",
340340
deps = [
341+
"//grain/_src/core:tree_lib",
341342
"@pypi//numpy:pkg",
342343
],
343344
)

grain/_src/python/dataset/transformations/BUILD

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,3 +360,40 @@ py_test(
360360
"@pypi//numpy:pkg",
361361
],
362362
)
363+
364+
py_library(
365+
name = "process_prefetch",
366+
srcs = ["process_prefetch.py"],
367+
srcs_version = "PY3",
368+
deps = [
369+
"//grain/_src/core:config",
370+
"//grain/_src/core:tree_lib",
371+
"//grain/_src/python:shared_memory_array",
372+
"//grain/_src/python/dataset",
373+
"//grain/_src/python/dataset:base",
374+
"//grain/_src/python/dataset:stats",
375+
"//grain/_src/python/dataset/transformations:interleave",
376+
"@abseil-py//absl/flags",
377+
"@pypi//cloudpickle:pkg",
378+
"@pypi//numpy:pkg",
379+
],
380+
)
381+
382+
py_test(
383+
name = "process_prefetch_test",
384+
timeout = "long",
385+
srcs = ["process_prefetch_test.py"],
386+
shard_count = 50,
387+
srcs_version = "PY3",
388+
deps = [
389+
":process_prefetch",
390+
"//grain/_src/core:transforms",
391+
"//grain/_src/python:options",
392+
"//grain/_src/python/dataset",
393+
"//grain/_src/python/dataset:base",
394+
"@abseil-py//absl/logging",
395+
"@abseil-py//absl/testing:absltest",
396+
"@abseil-py//absl/testing:parameterized",
397+
"@pypi//numpy:pkg",
398+
],
399+
)

0 commit comments

Comments
 (0)