Skip to content

Commit 72c4f9d

Browse files
authored
Merge pull request #659 from DagsHub/feature/adaptive-upload
Add adaptive metadata upload batch sizing
2 parents 65e9f6a + e07343c commit 72c4f9d

File tree

7 files changed

+875
-36
lines changed

7 files changed

+875
-36
lines changed
Lines changed: 318 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,318 @@
1+
import itertools
2+
import logging
3+
import math
4+
import time
5+
from dataclasses import dataclass
6+
from typing import Callable, Iterable, List, Optional, Sized, Tuple, TypeVar
7+
8+
import rich.progress
9+
10+
import dagshub.common.config as dgs_config
11+
from dagshub.common.rich_util import get_rich_progress
12+
13+
logger = logging.getLogger(__name__)
14+
15+
T = TypeVar("T")

# Floor applied to the configured target batch time so it can never be zero or negative.
MIN_TARGET_BATCH_TIME_SECONDS = 0.01
# A probe step smaller than this fraction of the soft upper limit is below useful
# search resolution — we hold the current size instead of micro-searching.
SOFT_UPPER_LIMIT_MIN_STEP_FRACTION = 0.05
# Number of consecutive fast successes near the soft upper limit required before
# re-probing that limit (the earlier failure may have been transient).
SOFT_UPPER_LIMIT_RETRY_AFTER_SUCCESSES = 3

# Overall strategy:
# - Grow aggressively on fast successes until we hit a slow or failing batch.
# - A slow or failing batch becomes last_bad_batch_size, and a fast batch becomes
#   last_fast_batch_size.
# - last_bad_batch_size acts as a soft_upper_limit while the gap to
#   last_fast_batch_size is still meaningful, and we probe within that range.
# - Once that gap is below the search resolution, hold the current fast batch size
#   instead of micro-searching.
# - Several consecutive fast batches near last_bad_batch_size trigger one more
#   probe at soft_upper_limit, since the earlier failure may have been transient.
31+
32+
33+
@dataclass
class AdaptiveBatchConfig:
    """Tuning knobs for :class:`AdaptiveBatcher`'s batch-size search and retry backoff."""

    max_batch_size: int
    min_batch_size: int
    initial_batch_size: int
    target_batch_time_seconds: float
    batch_growth_factor: int
    retry_backoff_base_seconds: float
    retry_backoff_max_seconds: float

    @classmethod
    def from_values(
        cls,
        max_batch_size: Optional[int] = None,
        min_batch_size: Optional[int] = None,
        initial_batch_size: Optional[int] = None,
        target_batch_time_seconds: Optional[float] = None,
        batch_growth_factor: Optional[int] = None,
        retry_backoff_base_seconds: Optional[float] = None,
        retry_backoff_max_seconds: Optional[float] = None,
    ) -> "AdaptiveBatchConfig":
        """Build a config, filling omitted values from ``dagshub.common.config``
        and normalizing everything into a mutually consistent range.

        Guarantees after normalization: ``max_batch_size >= 1``,
        ``1 <= min_batch_size <= max_batch_size``,
        ``min_batch_size <= initial_batch_size <= max_batch_size``,
        ``target_batch_time_seconds >= MIN_TARGET_BATCH_TIME_SECONDS``,
        ``batch_growth_factor >= 2``, and non-negative backoff settings.
        """
        # Each global default is read only when the caller left the value unset.
        max_batch_size = (
            dgs_config.dataengine_metadata_upload_batch_size_max if max_batch_size is None else max_batch_size
        )
        min_batch_size = (
            dgs_config.dataengine_metadata_upload_batch_size_min if min_batch_size is None else min_batch_size
        )
        initial_batch_size = (
            dgs_config.dataengine_metadata_upload_batch_size_initial
            if initial_batch_size is None
            else initial_batch_size
        )
        target_batch_time_seconds = (
            dgs_config.dataengine_metadata_upload_target_batch_time_seconds
            if target_batch_time_seconds is None
            else target_batch_time_seconds
        )
        batch_growth_factor = (
            dgs_config.adaptive_batch_growth_factor if batch_growth_factor is None else batch_growth_factor
        )
        retry_backoff_base_seconds = (
            dgs_config.adaptive_batch_retry_backoff_base_seconds
            if retry_backoff_base_seconds is None
            else retry_backoff_base_seconds
        )
        retry_backoff_max_seconds = (
            dgs_config.adaptive_batch_retry_backoff_max_seconds
            if retry_backoff_max_seconds is None
            else retry_backoff_max_seconds
        )

        # Clamp into a consistent ordering: 1 <= hard_min <= start_size <= hard_max.
        hard_max = max(1, max_batch_size)
        hard_min = max(1, min(min_batch_size, hard_max))
        start_size = max(hard_min, min(initial_batch_size, hard_max))
        target_time = max(target_batch_time_seconds, MIN_TARGET_BATCH_TIME_SECONDS)
        return cls(
            max_batch_size=hard_max,
            min_batch_size=hard_min,
            initial_batch_size=start_size,
            target_batch_time_seconds=target_time,
            batch_growth_factor=max(2, batch_growth_factor),
            retry_backoff_base_seconds=max(0.0, retry_backoff_base_seconds),
            retry_backoff_max_seconds=max(0.0, retry_backoff_max_seconds),
        )
85+
86+
87+
def _clamp(value: int, lo: int, hi: int) -> int:
88+
return max(lo, min(hi, value))
89+
90+
91+
def _next_batch_after_success(
    batch_size: int,
    config: AdaptiveBatchConfig,
    soft_upper_limit: Optional[int],
) -> int:
    """Pick the next batch size after a fast successful batch.

    With a known slow/failing size above us, binary-search toward it (midpoint);
    otherwise grow by the configured multiplicative factor. The result is always
    clamped into [config.min_batch_size, config.max_batch_size].
    """
    candidate = (
        # Binary search: midpoint between the current size and the soft upper hint.
        (batch_size + soft_upper_limit) // 2
        if soft_upper_limit is not None and batch_size < soft_upper_limit
        # No upper hint (or already at/above it): grow aggressively.
        else batch_size * config.batch_growth_factor
    )
    return _clamp(candidate, config.min_batch_size, config.max_batch_size)
110+
111+
112+
def _next_batch_after_retryable_failure(
    batch_size: int,
    config: AdaptiveBatchConfig,
    last_fast_batch_size: Optional[int],
    soft_upper_limit: Optional[int],
) -> int:
    """Pick the next batch size after a failed or slow batch.

    Binary-searches between the best known-good size and the failing size when a
    good lower bound exists; otherwise probes the midpoint of the remaining valid
    range. The result is always strictly below *batch_size* (so the search
    converges downward), except when we are already at the configured minimum.
    """
    floor = config.min_batch_size
    if batch_size <= floor:
        return floor

    # The next attempt must be strictly smaller than what just failed,
    # and strictly below any known soft upper limit.
    ceiling = batch_size - 1
    if soft_upper_limit is not None:
        ceiling = min(soft_upper_limit - 1, ceiling)

    has_good_lower_bound = last_fast_batch_size is not None and last_fast_batch_size < ceiling
    lower = last_fast_batch_size if has_good_lower_bound else floor
    # Midpoint between the lower bound (known-good or minimum) and the ceiling.
    candidate = (lower + ceiling) // 2
    return _clamp(candidate, floor, ceiling)
142+
143+
144+
def _get_retry_delay_seconds(consecutive_retryable_failures: int, config: AdaptiveBatchConfig) -> float:
    """Exponential-backoff delay for the given failure streak, capped at the configured max.

    Returns 0.0 when backoff is disabled (non-positive base or cap).
    """
    base = config.retry_backoff_base_seconds
    cap = config.retry_backoff_max_seconds
    if base <= 0.0 or cap <= 0.0:
        # Backoff disabled by configuration.
        return 0.0

    # Treat any streak below 1 as the first attempt.
    exponent = max(1, consecutive_retryable_failures) - 1
    return min(base * (2 ** exponent), cap)
151+
152+
153+
def _min_step_size(soft_upper_limit: int) -> int:
    """Smallest probe step worth taking near *soft_upper_limit* (never below 1)."""
    fractional_step = math.ceil(soft_upper_limit * SOFT_UPPER_LIMIT_MIN_STEP_FRACTION)
    return fractional_step if fractional_step > 1 else 1
155+
156+
157+
def _is_next_step_above_limit(batch_size: int, soft_upper_limit: Optional[int]) -> bool:
    """True when *batch_size* sits within one minimal probe step below the soft limit.

    False when there is no limit or the size is already at/above it.
    """
    if soft_upper_limit is None:
        return False
    gap = soft_upper_limit - batch_size
    return 0 < gap <= _min_step_size(soft_upper_limit)
162+
163+
164+
def _update_bounds_after_bad_batch(
165+
batch_size: int,
166+
last_fast_batch_size: Optional[int],
167+
last_bad_batch_size: Optional[int],
168+
) -> Tuple[Optional[int], int]:
169+
updated_last_bad_batch_size = batch_size if last_bad_batch_size is None else min(last_bad_batch_size, batch_size)
170+
if last_fast_batch_size is not None and last_fast_batch_size >= updated_last_bad_batch_size:
171+
last_fast_batch_size = None
172+
return last_fast_batch_size, updated_last_bad_batch_size
173+
174+
175+
class AdaptiveBatcher:
    """Sends items in adaptively-sized batches, growing on success and shrinking on failure."""

    def __init__(
        self,
        is_retryable: Callable[[Exception], bool],
        config: Optional[AdaptiveBatchConfig] = None,
        progress_label: str = "Uploading",
    ):
        """Create a batcher.

        Args:
            is_retryable: Predicate deciding whether an exception raised by the
                batch operation should trigger a retry (with a smaller batch) or
                be re-raised immediately.
            config: Batch-size/backoff tuning; defaults to values resolved from
                ``dagshub.common.config`` via ``AdaptiveBatchConfig.from_values()``.
            progress_label: Label used in the rich progress bar and log messages.
        """
        self._config = config if config is not None else AdaptiveBatchConfig.from_values()
        self._is_retryable = is_retryable
        self._progress_label = progress_label

    def run(self, items: Iterable[T], operation: Callable[[List[T]], None]) -> None:
        """Feed *items* to *operation* in adaptively-sized batches until exhausted.

        Batch sizing follows the module-level strategy comment: grow on fast
        successes, binary-search downward on retryable failures or slow batches,
        and hold/re-probe near a known bad size. Retryable failures re-queue the
        failed batch and retry after exponential backoff; non-retryable failures
        (and retryable failures already at the minimum batch size) re-raise.

        Args:
            items: Source items; if it is ``Sized``, its length drives the
                progress bar total, otherwise progress is indeterminate.
            operation: Callable invoked with each batch (a list of items). Any
                exception it raises is classified via the ``is_retryable``
                predicate supplied at construction time.
        """
        total: Optional[int] = len(items) if isinstance(items, Sized) else None
        if total == 0:
            return

        config = self._config
        desired_batch_size = config.initial_batch_size
        # Consume the source iterable incrementally across retries and successes.
        it = iter(items)
        # Items from a failed batch waiting to be retried; always served before `it`.
        pending: List[T] = []

        progress = get_rich_progress(rich.progress.MofNCompleteColumn())
        total_task = progress.add_task(f"{self._progress_label}...", total=total)

        # Search bounds: largest known fast size and smallest known slow/failing size.
        last_fast_batch_size: Optional[int] = None
        last_bad_batch_size: Optional[int] = None
        consecutive_retryable_failures = 0
        consecutive_fast_successes_near_upper_limit = 0
        processed = 0

        with progress:
            while True:
                # Draw from pending (failed-batch leftovers) first, then the source iterator
                batch = pending[:desired_batch_size]
                pending = pending[desired_batch_size:]
                if len(batch) < desired_batch_size:
                    batch.extend(itertools.islice(it, desired_batch_size - len(batch)))
                if not batch:
                    break
                actual_batch_size = len(batch)

                progress.update(total_task, description=f"{self._progress_label} (batch size: {actual_batch_size})...")
                logger.debug(f"{self._progress_label}: {actual_batch_size} entries...")

                start_time = time.monotonic()
                try:
                    operation(batch)
                except Exception as exc:
                    if not self._is_retryable(exc):
                        logger.error(
                            f"{self._progress_label} failed with a non-retryable error; aborting.",
                            exc_info=True,
                        )
                        raise

                    # A batch can be smaller than desired only when the source ran dry
                    # (natural tail); a failure there must not abort as "minimum size".
                    is_short_tail_batch = (
                        actual_batch_size <= config.min_batch_size and actual_batch_size < desired_batch_size
                    )
                    if not is_short_tail_batch and actual_batch_size <= config.min_batch_size:
                        logger.error(
                            f"{self._progress_label} failed at minimum batch size ({actual_batch_size}); aborting.",
                            exc_info=True,
                        )
                        raise

                    consecutive_fast_successes_near_upper_limit = 0

                    # Exponential backoff
                    consecutive_retryable_failures += 1
                    time.sleep(_get_retry_delay_seconds(consecutive_retryable_failures, config))

                    last_fast_batch_size, last_bad_batch_size = _update_bounds_after_bad_batch(
                        actual_batch_size, last_fast_batch_size, last_bad_batch_size
                    )
                    if is_short_tail_batch:
                        # A naturally short tail batch cannot be shrunk further in a useful way.
                        # Retry that exact size once before treating it as exhausted.
                        desired_batch_size = actual_batch_size
                    else:
                        # Binary search downwards
                        desired_batch_size = _next_batch_after_retryable_failure(
                            actual_batch_size, config, last_fast_batch_size, last_bad_batch_size
                        )
                    logger.warning(
                        f"{self._progress_label} failed for batch size {actual_batch_size} "
                        f"({exc.__class__.__name__}: {exc}). Retrying with batch size {desired_batch_size}."
                    )
                    # Re-queue the failed batch items for retry with smaller batch size
                    pending = batch + pending
                    continue

                # On success.
                elapsed = time.monotonic() - start_time
                consecutive_retryable_failures = 0
                processed += actual_batch_size
                progress.update(total_task, advance=actual_batch_size)

                if elapsed <= config.target_batch_time_seconds:
                    if last_fast_batch_size is None or actual_batch_size > last_fast_batch_size:
                        last_fast_batch_size = actual_batch_size
                    if last_bad_batch_size is not None and actual_batch_size >= last_bad_batch_size:
                        # A fast success at the upper limit means the last_bad_batch_size is stale.
                        # We can resume unconstrained growth.
                        last_bad_batch_size = None
                        consecutive_fast_successes_near_upper_limit = 0
                        desired_batch_size = _next_batch_after_success(
                            actual_batch_size, config, last_bad_batch_size
                        )
                    elif _is_next_step_above_limit(actual_batch_size, last_bad_batch_size):
                        # Once the gap is smaller than our useful search resolution,
                        # hold the current known-good size and only re-probe the hint
                        # after a few stable fast successes.
                        consecutive_fast_successes_near_upper_limit += 1
                        if consecutive_fast_successes_near_upper_limit >= SOFT_UPPER_LIMIT_RETRY_AFTER_SUCCESSES:
                            # We've had enough stable fast successes to re-probe the last_bad_batch_size.
                            desired_batch_size = last_bad_batch_size
                            consecutive_fast_successes_near_upper_limit = 0
                        else:
                            # Hold current size for one more iteration
                            desired_batch_size = actual_batch_size
                    else:
                        # Binary search or unconstrained growth upwards
                        consecutive_fast_successes_near_upper_limit = 0
                        desired_batch_size = _next_batch_after_success(
                            actual_batch_size, config, last_bad_batch_size
                        )
                else:
                    # Binary search downwards due to a slow batch
                    consecutive_fast_successes_near_upper_limit = 0
                    logger.debug(
                        f"{self._progress_label} batch size {actual_batch_size} took {elapsed:.2f}s "
                        f"(target {config.target_batch_time_seconds:.2f}s); shrinking."
                    )
                    last_fast_batch_size, last_bad_batch_size = _update_bounds_after_bad_batch(
                        actual_batch_size, last_fast_batch_size, last_bad_batch_size
                    )
                    desired_batch_size = _next_batch_after_retryable_failure(
                        actual_batch_size, config, last_fast_batch_size, last_bad_batch_size
                    )

        # Snap the bar to the true processed count (total may have been unknown/indeterminate).
        progress.update(total_task, completed=processed, total=processed, refresh=True)

dagshub/common/config.py

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import logging
2-
3-
import appdirs
42
import os
53
from urllib.parse import urlparse
6-
from dagshub import __version__
4+
5+
import appdirs
76
from httpx._client import USER_AGENT
87

8+
from dagshub import __version__
9+
910
logger = logging.getLogger(__name__)
1011

1112
HOST_KEY = "DAGSHUB_CLIENT_HOST"
@@ -58,7 +59,39 @@ def set_host(new_host: str):
5859
recommended_annotate_limit = int(os.environ.get(RECOMMENDED_ANNOTATE_LIMIT_KEY, 1e5))
5960

6061
DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_KEY = "DAGSHUB_DE_METADATA_UPLOAD_BATCH_SIZE"
61-
dataengine_metadata_upload_batch_size = int(os.environ.get(DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_KEY, 15000))
62+
DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_MAX_KEY = "DAGSHUB_DE_METADATA_UPLOAD_BATCH_SIZE_MAX"
63+
DEFAULT_DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_MAX = 50000
64+
# Fall back to the old `DAGSHUB_DE_METADATA_UPLOAD_BATCH_SIZE` env var for backwards compatibility.
65+
dataengine_metadata_upload_batch_size_max = int(
66+
os.environ.get(
67+
DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_MAX_KEY,
68+
os.environ.get(DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_KEY, DEFAULT_DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_MAX),
69+
)
70+
)
71+
72+
DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_MIN_KEY = "DAGSHUB_DE_METADATA_UPLOAD_BATCH_SIZE_MIN"
73+
dataengine_metadata_upload_batch_size_min = int(os.environ.get(DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_MIN_KEY, 1))
74+
75+
DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_INITIAL_KEY = "DAGSHUB_DE_METADATA_UPLOAD_BATCH_SIZE_INITIAL"
76+
dataengine_metadata_upload_batch_size_initial = int(
77+
os.environ.get(DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_INITIAL_KEY, dataengine_metadata_upload_batch_size_min)
78+
)
79+
80+
DATAENGINE_METADATA_UPLOAD_TARGET_BATCH_TIME_SECONDS_KEY = "DAGSHUB_DE_METADATA_UPLOAD_TARGET_BATCH_TIME_SECONDS"
81+
dataengine_metadata_upload_target_batch_time_seconds = float(
82+
os.environ.get(DATAENGINE_METADATA_UPLOAD_TARGET_BATCH_TIME_SECONDS_KEY, 5.0)
83+
)
84+
85+
DATAENGINE_METADATA_UPLOAD_BATCH_GROWTH_FACTOR_KEY = "DAGSHUB_DE_METADATA_UPLOAD_BATCH_GROWTH_FACTOR"
86+
adaptive_batch_growth_factor = int(os.environ.get(DATAENGINE_METADATA_UPLOAD_BATCH_GROWTH_FACTOR_KEY, 10))
87+
88+
DATAENGINE_METADATA_UPLOAD_RETRY_BACKOFF_BASE_KEY = "DAGSHUB_DE_METADATA_UPLOAD_RETRY_BACKOFF_BASE"
89+
adaptive_batch_retry_backoff_base_seconds = float(
90+
os.environ.get(DATAENGINE_METADATA_UPLOAD_RETRY_BACKOFF_BASE_KEY, 0.25)
91+
)
92+
93+
DATAENGINE_METADATA_UPLOAD_RETRY_BACKOFF_MAX_KEY = "DAGSHUB_DE_METADATA_UPLOAD_RETRY_BACKOFF_MAX"
94+
adaptive_batch_retry_backoff_max_seconds = float(os.environ.get(DATAENGINE_METADATA_UPLOAD_RETRY_BACKOFF_MAX_KEY, 60.0))
6295

6396
DISABLE_ANALYTICS_KEY = "DAGSHUB_DISABLE_ANALYTICS"
6497
disable_analytics = "DAGSHUB_DISABLE_ANALYTICS" in os.environ

0 commit comments

Comments
 (0)