Skip to content

Commit 41643b4

Browse files
Internal
PiperOrigin-RevId: 869789772
1 parent 0d302a4 commit 41643b4

File tree

4 files changed

+83
-62
lines changed

4 files changed

+83
-62
lines changed

grain/_src/python/BUILD

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,9 @@ py_library(
222222
name = "options",
223223
srcs = ["options.py"],
224224
srcs_version = "PY3",
225-
deps = ["@abseil-py//absl/logging"],
225+
deps = [
226+
"@abseil-py//absl/logging",
227+
],
226228
)
227229

228230
py_test(

grain/_src/python/dataset/transformations/prefetch.py

Lines changed: 33 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import typing
2727
from typing import Any, Optional, Protocol, TypeVar
2828

29+
from absl import logging
2930
from concurrent import futures
3031
from grain._src.core import monitoring as grain_monitoring
3132
from grain._src.python import options as grain_options
@@ -144,14 +145,16 @@ def __init__(
144145
self._next_buffered_index = 0
145146
self._buffer = collections.deque()
146147
self._lock = threading.Lock()
147-
self._prefetch_buffer_size = (
148-
read_options.prefetch_buffer_size if read_options.num_threads > 0 else 0
149-
)
150-
self._num_threads = read_options.num_threads
148+
149+
assert isinstance(read_options.num_threads, int)
150+
assert isinstance(read_options.prefetch_buffer_size, int)
151+
self._target_num_threads = read_options.num_threads
152+
self._target_prefetch_buffer_size = read_options.prefetch_buffer_size
153+
151154
self._allow_nones = allow_nones
152-
if self._prefetch_buffer_size > 0:
155+
if self._target_prefetch_buffer_size > 0 and self._target_num_threads > 0:
153156
self._executor = futures.ThreadPoolExecutor(
154-
self._num_threads, thread_name_prefix="grain-prefetch"
157+
self._target_num_threads, thread_name_prefix="grain-prefetch"
155158
)
156159

157160
def _initialize_stats(
@@ -195,7 +198,10 @@ def __next__(self) -> T:
195198
if self._next_returned_index == self._dataset_length:
196199
break
197200
with self._lock, timer:
198-
if self._prefetch_buffer_size > 0:
201+
if (
202+
self._target_prefetch_buffer_size > 0
203+
and self._target_num_threads > 0
204+
):
199205
if not self._buffer:
200206
# Fill the buffer on the first iteration.
201207
self._fill_buffer()
@@ -237,11 +243,11 @@ def set_state(self, state):
237243
f"Checkpoint `next_index` {self._next_returned_index} is out of"
238244
f" range for dataset of length {self._dataset_length}."
239245
)
240-
if self._prefetch_buffer_size > 0:
241-
# Cancel all pending futures in the buffer.
242-
while self._buffer:
243-
future = self._buffer.popleft()
244-
future.cancel()
246+
247+
# Cancel all pending futures in the buffer.
248+
while self._buffer:
249+
future = self._buffer.popleft()
250+
future.cancel()
245251

246252
def _get_next_index(self) -> int:
247253
return self._next_returned_index
@@ -255,34 +261,33 @@ def __str__(self) -> str:
255261
f" allow_nones={self._allow_nones})"
256262
)
257263

258-
def set_prefetch_buffer_size(self, buffer_size: int):
259-
self._prefetch_buffer_size = buffer_size
264+
def _set_prefetch_buffer_size(self, buffer_size: int):
265+
self._target_prefetch_buffer_size = buffer_size
260266
# The executor is created in the constructor only if both the prefetch buffer
261267
# size and the number of threads are greater than 0. If the user changes the
262268
# prefetch buffer size, we need to create or destroy the executor accordingly.
263-
if self._prefetch_buffer_size > 0 and not hasattr(self, "_executor"):
264-
if self._num_threads == 0:
265-
raise ValueError(
266-
"num_threads must be greater than 0 when prefetch buffer size is"
267-
" greater than 0."
268-
)
269+
if (
270+
self._target_prefetch_buffer_size > 0
271+
and self._target_num_threads > 0
272+
and not hasattr(self, "_executor")
273+
):
269274
self._executor = futures.ThreadPoolExecutor(
270-
self._num_threads, thread_name_prefix="grain-prefetch"
275+
self._target_num_threads, thread_name_prefix="grain-prefetch"
271276
)
272-
elif self._prefetch_buffer_size == 0 and hasattr(self, "_executor"):
277+
elif self._target_prefetch_buffer_size == 0 and hasattr(self, "_executor"):
273278
self._executor.shutdown()
274279
delattr(self, "_executor")
275280

276-
def set_num_threads(self, num_threads: int) -> None:
277-
self._num_threads = num_threads
281+
def _set_num_threads(self, num_threads: int) -> None:
282+
self._target_num_threads = num_threads
278283
old_executor = None
279284
# Accounts for the case where the executor does not exist. This can
280285
# happen if the prefetch buffer size is set to 0.
281286
if hasattr(self, "_executor"):
282287
old_executor = self._executor
283-
if self._num_threads > 0:
288+
if self._target_num_threads > 0 and self._target_prefetch_buffer_size > 0:
284289
self._executor = futures.ThreadPoolExecutor(
285-
self._num_threads, thread_name_prefix="grain-prefetch"
290+
self._target_num_threads, thread_name_prefix="grain-prefetch"
286291
)
287292
else:
288293
delattr(self, "_executor")
@@ -293,7 +298,7 @@ def set_num_threads(self, num_threads: int) -> None:
293298

294299
def _fill_buffer(self):
295300
while (
296-
len(self._buffer) < self._prefetch_buffer_size
301+
len(self._buffer) < self._target_prefetch_buffer_size
297302
and self._next_buffered_index < self._dataset_length
298303
):
299304
# Note that we trigger creation of `_stats` in this (single) thread, it is
@@ -307,7 +312,7 @@ def _fill_buffer(self):
307312
self._next_buffered_index += 1
308313

309314
def start_prefetch(self):
310-
if self._prefetch_buffer_size > 0:
315+
if self._target_prefetch_buffer_size > 0 and self._target_num_threads > 0:
311316
self._fill_buffer()
312317

313318
def close(self) -> None:

grain/_src/python/dataset/transformations/prefetch_test.py

Lines changed: 29 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -158,12 +158,12 @@ def test_set_prefetch_buffer_size_0_to_positive(self):
158158

159159
# With prefetch_buffer_size=0, executor is not created.
160160
self.assertFalse(hasattr(ds_iter, '_executor'))
161-
self.assertEqual(ds_iter._prefetch_buffer_size, 0)
161+
self.assertEqual(ds_iter._target_prefetch_buffer_size, 0)
162162
self.assertEqual(next(ds_iter), 0)
163163

164164
# Setting prefetch_buffer_size to 2.
165-
ds_iter.set_prefetch_buffer_size(2)
166-
self.assertEqual(ds_iter._prefetch_buffer_size, 2)
165+
ds_iter._set_prefetch_buffer_size(2)
166+
self.assertEqual(ds_iter._target_prefetch_buffer_size, 2)
167167
self.assertEqual(next(ds_iter), 1)
168168
self.assertTrue(hasattr(ds_iter, '_executor'))
169169
self.assertLen(ds_iter._buffer, 2)
@@ -178,13 +178,13 @@ def test_set_prefetch_buffer_size_positive_to_0(self):
178178
self.assertIsInstance(ds_iter, prefetch.PrefetchDatasetIterator)
179179
ds_iter = cast(prefetch.PrefetchDatasetIterator, ds_iter)
180180

181-
self.assertEqual(ds_iter._prefetch_buffer_size, 2)
181+
self.assertEqual(ds_iter._target_prefetch_buffer_size, 2)
182182
self.assertEqual(next(ds_iter), 0)
183183
self.assertLen(ds_iter._buffer, 2)
184184

185185
# Setting prefetch_buffer_size to 0.
186-
ds_iter.set_prefetch_buffer_size(0)
187-
self.assertEqual(ds_iter._prefetch_buffer_size, 0)
186+
ds_iter._set_prefetch_buffer_size(0)
187+
self.assertEqual(ds_iter._target_prefetch_buffer_size, 0)
188188
# Should consume buffer first.
189189
self.assertEqual(next(ds_iter), 1)
190190
self.assertLen(ds_iter._buffer, 1)
@@ -202,13 +202,13 @@ def test_set_prefetch_buffer_size_increase(self):
202202
self.assertIsInstance(ds_iter, prefetch.PrefetchDatasetIterator)
203203
ds_iter = cast(prefetch.PrefetchDatasetIterator, ds_iter)
204204

205-
self.assertEqual(ds_iter._prefetch_buffer_size, 1)
205+
self.assertEqual(ds_iter._target_prefetch_buffer_size, 1)
206206
self.assertEqual(next(ds_iter), 0)
207207
self.assertLen(ds_iter._buffer, 1)
208208

209209
# Setting prefetch_buffer_size to 2.
210-
ds_iter.set_prefetch_buffer_size(2)
211-
self.assertEqual(ds_iter._prefetch_buffer_size, 2)
210+
ds_iter._set_prefetch_buffer_size(2)
211+
self.assertEqual(ds_iter._target_prefetch_buffer_size, 2)
212212
self.assertEqual(next(ds_iter), 1)
213213
self.assertLen(ds_iter._buffer, 2)
214214
self.assertEqual(next(ds_iter), 2)
@@ -222,13 +222,13 @@ def test_set_prefetch_buffer_size_decrease(self):
222222
self.assertIsInstance(ds_iter, prefetch.PrefetchDatasetIterator)
223223
ds_iter = cast(prefetch.PrefetchDatasetIterator, ds_iter)
224224

225-
self.assertEqual(ds_iter._prefetch_buffer_size, 2)
225+
self.assertEqual(ds_iter._target_prefetch_buffer_size, 2)
226226
self.assertEqual(next(ds_iter), 0)
227227
self.assertLen(ds_iter._buffer, 2)
228228

229229
# Setting prefetch_buffer_size to 1.
230-
ds_iter.set_prefetch_buffer_size(1)
231-
self.assertEqual(ds_iter._prefetch_buffer_size, 1)
230+
ds_iter._set_prefetch_buffer_size(1)
231+
self.assertEqual(ds_iter._target_prefetch_buffer_size, 1)
232232
self.assertEqual(next(ds_iter), 1)
233233
self.assertLen(ds_iter._buffer, 1)
234234
self.assertEqual(next(ds_iter), 2)
@@ -321,15 +321,17 @@ def test_set_num_threads_decrease_threads(self):
321321
ds_iter = iter(self.prefetch_lazy_iter_ds)
322322
self.assertIsInstance(ds_iter, prefetch.PrefetchDatasetIterator)
323323
ds_iter = cast(prefetch.PrefetchDatasetIterator, ds_iter)
324-
self.assertEqual(ds_iter._num_threads, options.ReadOptions().num_threads)
324+
self.assertEqual(
325+
ds_iter._target_num_threads, options.ReadOptions().num_threads
326+
)
325327
self.assertEqual(
326328
ds_iter._executor._max_workers, options.ReadOptions().num_threads
327329
)
328330
self.assertEqual([next(ds_iter) for _ in range(5)], list(range(5)))
329331

330332
# Decrease threads
331-
ds_iter.set_num_threads(5)
332-
self.assertEqual(ds_iter._num_threads, 5)
333+
ds_iter._set_num_threads(5)
334+
self.assertEqual(ds_iter._target_num_threads, 5)
333335
self.assertEqual(ds_iter._executor._max_workers, 5)
334336
self.assertEqual([next(ds_iter) for _ in range(15)], list(range(5, 20)))
335337

@@ -340,28 +342,30 @@ def test_set_num_threads_increase_threads(self):
340342
ds_iter = iter(ds)
341343
self.assertIsInstance(ds_iter, prefetch.PrefetchDatasetIterator)
342344
ds_iter = cast(prefetch.PrefetchDatasetIterator, ds_iter)
343-
self.assertEqual(ds_iter._num_threads, 5)
345+
self.assertEqual(ds_iter._target_num_threads, 5)
344346
self.assertEqual(ds_iter._executor._max_workers, 5)
345347
self.assertEqual([next(ds_iter) for _ in range(5)], list(range(5)))
346348

347349
# Increase threads
348-
ds_iter.set_num_threads(10)
349-
self.assertEqual(ds_iter._num_threads, 10)
350+
ds_iter._set_num_threads(10)
351+
self.assertEqual(ds_iter._target_num_threads, 10)
350352
self.assertEqual(ds_iter._executor._max_workers, 10)
351353
self.assertEqual([next(ds_iter) for _ in range(15)], list(range(5, 20)))
352354

353355
def test_set_num_threads_decrease_to_zero(self):
354356
ds_iter = iter(self.prefetch_lazy_iter_ds)
355357
self.assertIsInstance(ds_iter, prefetch.PrefetchDatasetIterator)
356358
ds_iter = cast(prefetch.PrefetchDatasetIterator, ds_iter)
357-
self.assertEqual(ds_iter._num_threads, options.ReadOptions().num_threads)
359+
self.assertEqual(
360+
ds_iter._target_num_threads, options.ReadOptions().num_threads
361+
)
358362
self.assertEqual(
359363
ds_iter._executor._max_workers, options.ReadOptions().num_threads
360364
)
361365
self.assertEqual([next(ds_iter) for _ in range(5)], list(range(5)))
362366
# Decrease threads to 0
363-
ds_iter.set_num_threads(0)
364-
self.assertEqual(ds_iter._num_threads, 0)
367+
ds_iter._set_num_threads(0)
368+
self.assertEqual(ds_iter._target_num_threads, 0)
365369
self.assertFalse(hasattr(ds_iter, '_executor'))
366370
self.assertEqual([next(ds_iter) for _ in range(15)], list(range(5, 20)))
367371

@@ -370,14 +374,14 @@ def test_set_num_threads_increase_from_zero(self):
370374
self.assertIsInstance(ds_iter, prefetch.PrefetchDatasetIterator)
371375
ds_iter = cast(prefetch.PrefetchDatasetIterator, ds_iter)
372376
self.assertEqual([next(ds_iter) for _ in range(5)], list(range(5)))
373-
ds_iter.set_num_threads(0)
374-
self.assertEqual(ds_iter._num_threads, 0)
377+
ds_iter._set_num_threads(0)
378+
self.assertEqual(ds_iter._target_num_threads, 0)
375379
self.assertFalse(hasattr(ds_iter, '_executor'))
376380
self.assertEqual([next(ds_iter) for _ in range(5)], list(range(5, 10)))
377381

378382
# Increase threads from 0
379-
ds_iter.set_num_threads(5)
380-
self.assertEqual(ds_iter._num_threads, 5)
383+
ds_iter._set_num_threads(5)
384+
self.assertEqual(ds_iter._target_num_threads, 5)
381385
self.assertEqual(ds_iter._executor._max_workers, 5)
382386
self.assertEqual([next(ds_iter) for _ in range(10)], list(range(10, 20)))
383387

grain/_src/python/options.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,15 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
"""Dataclasses for holding options."""
15+
from __future__ import annotations
16+
1517
import dataclasses
1618

1719
from absl import logging
20+
class AutotuneParameter:
21+
22+
def __init__(self, *args, **kwargs):
23+
raise NotImplementedError
1824

1925

2026
@dataclasses.dataclass(slots=True)
@@ -41,25 +47,29 @@ class ReadOptions:
4147
# benchmarks reading from remote hard drives.
4248
# These values should work well for datasets with elements between 1 and
4349
# 10 KiB on disk.
44-
num_threads: int = 16
45-
prefetch_buffer_size: int = 500
50+
num_threads: int | AutotuneParameter = 16
51+
prefetch_buffer_size: int | AutotuneParameter = 500
4652

4753
def __post_init__(self):
48-
if self.num_threads < 0:
54+
if isinstance(self.num_threads, int) and self.num_threads < 0:
4955
raise ValueError(
5056
f'num_threads must be non-negative, got {self.num_threads}'
5157
)
52-
if self.prefetch_buffer_size < 0:
58+
59+
if (
60+
isinstance(self.prefetch_buffer_size, int)
61+
and self.prefetch_buffer_size < 0
62+
):
5363
raise ValueError(
5464
'prefetch_buffer_size must be non-negative, got'
5565
f' {self.prefetch_buffer_size}'
5666
)
67+
5768
# Avoid warning when setting prefetch_buffer_size=0, since this is commonly
5869
# used to disable prefetching.
59-
if (
60-
self.prefetch_buffer_size < self.num_threads
61-
and self.prefetch_buffer_size != 0
62-
):
70+
buffer_size = int(self.prefetch_buffer_size)
71+
num_threads = int(self.num_threads)
72+
if buffer_size < num_threads and buffer_size != 0:
6373
logging.warning(
6474
'prefetch_buffer_size=%s is smaller than num_threads=%s. This will'
6575
' limit the number of threads that can actually be used in parallel'

0 commit comments

Comments
 (0)