-
Notifications
You must be signed in to change notification settings - Fork 27
Add adaptive metadata upload batch sizing #659
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 23 commits
84632f2
ded5007
567c830
4a767ab
0f5e9d7
82e4503
996b7fc
2338c94
b884e9e
4c1132c
6ae0a58
1b6356b
3a944ae
6cea20f
a52a0aa
f6dfec9
349c25c
3e0ae68
3c1c817
232ebbf
2c44d83
0453a33
00294fe
5cd02a7
dbf7a68
53ca65a
946134f
841972c
ebde20c
e07343c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,266 @@ | ||
| import logging | ||
| import time | ||
| from dataclasses import dataclass | ||
| import itertools | ||
| from types import SimpleNamespace | ||
| from typing import Callable, Iterable, List, Optional, Sized, TypeVar | ||
|
|
||
| import rich.progress | ||
| from tenacity import wait_exponential | ||
|
|
||
| from dagshub.common.rich_util import get_rich_progress | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
| T = TypeVar("T") | ||
|
|
||
| MIN_TARGET_BATCH_TIME_SECONDS = 0.01 | ||
|
|
||
|
|
||
| @dataclass(frozen=True) | ||
| class AdaptiveBatchConfig: | ||
| max_batch_size: int | ||
| min_batch_size: int | ||
| initial_batch_size: int | ||
| target_batch_time_seconds: float | ||
| batch_growth_factor: int | ||
| retry_backoff_base_seconds: float | ||
| retry_backoff_max_seconds: float | ||
|
|
||
| @classmethod | ||
| def from_values( | ||
| cls, | ||
| max_batch_size: Optional[int] = None, | ||
| min_batch_size: Optional[int] = None, | ||
| initial_batch_size: Optional[int] = None, | ||
| target_batch_time_seconds: Optional[float] = None, | ||
| batch_growth_factor: Optional[int] = None, | ||
| retry_backoff_base_seconds: Optional[float] = None, | ||
| retry_backoff_max_seconds: Optional[float] = None, | ||
| ) -> "AdaptiveBatchConfig": | ||
| import dagshub.common.config as dgs_config | ||
guysmoilov marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| if max_batch_size is None: | ||
| max_batch_size = dgs_config.dataengine_metadata_upload_batch_size | ||
| if min_batch_size is None: | ||
| min_batch_size = dgs_config.dataengine_metadata_upload_batch_size_min | ||
| if initial_batch_size is None: | ||
| initial_batch_size = dgs_config.dataengine_metadata_upload_batch_size_initial | ||
| if target_batch_time_seconds is None: | ||
| target_batch_time_seconds = dgs_config.dataengine_metadata_upload_target_batch_time_seconds | ||
| if batch_growth_factor is None: | ||
| batch_growth_factor = dgs_config.adaptive_batch_growth_factor | ||
| if retry_backoff_base_seconds is None: | ||
| retry_backoff_base_seconds = dgs_config.adaptive_batch_retry_backoff_base_seconds | ||
| if retry_backoff_max_seconds is None: | ||
| retry_backoff_max_seconds = dgs_config.adaptive_batch_retry_backoff_max_seconds | ||
|
|
||
| normalized_max_batch_size = max(1, max_batch_size) | ||
| normalized_min_batch_size = max(1, min(min_batch_size, normalized_max_batch_size)) | ||
| normalized_initial_batch_size = max( | ||
| normalized_min_batch_size, | ||
| min(initial_batch_size, normalized_max_batch_size), | ||
| ) | ||
| normalized_target_batch_time_seconds = max(target_batch_time_seconds, MIN_TARGET_BATCH_TIME_SECONDS) | ||
| return cls( | ||
| max_batch_size=normalized_max_batch_size, | ||
| min_batch_size=normalized_min_batch_size, | ||
| initial_batch_size=normalized_initial_batch_size, | ||
| target_batch_time_seconds=normalized_target_batch_time_seconds, | ||
| batch_growth_factor=max(2, batch_growth_factor), | ||
| retry_backoff_base_seconds=max(0.0, retry_backoff_base_seconds), | ||
| retry_backoff_max_seconds=max(0.0, retry_backoff_max_seconds), | ||
| ) | ||
|
Comment on lines
+43
to
+84
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This could all be redone with a
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How would it not cover it? It is literally just a function that runs at init-time.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Because default_factory as far as I understand can't initialize fields based on a function of other fields - it's only a function to initialize each field individually. For interactions between fields that determine field values oyu need.... a constructor
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The only value that actually is dependent on another value in that whole constructor is If you really want to clamp |
||
|
|
||
|
|
||
| def _clamp(value: int, lo: int, hi: int) -> int: | ||
| return max(lo, min(hi, value)) | ||
|
|
||
|
|
||
def _next_batch_after_success(
    batch_size: int,
    config: AdaptiveBatchConfig,
    bad_batch_size: Optional[int],
) -> int:
    """Pick the next batch size after a successful (fast) batch.

    Strategy:
    - If we have a known-bad upper bound, binary-search toward it.
    - Otherwise, multiply by the growth factor.
    - Always guarantee at least +1 progress (so we never stall).
    """
    have_upper_bound = bad_batch_size is not None and batch_size < bad_batch_size
    if have_upper_bound:
        # Bisect the gap between the size that just worked and the known-bad one.
        proposed = (batch_size + bad_batch_size) // 2
    else:
        # No ceiling discovered yet (or we are already past it): grow geometrically.
        proposed = batch_size * config.batch_growth_factor

    # Forward progress is mandatory — never repeat or shrink after a success.
    if proposed <= batch_size:
        proposed = batch_size + 1

    return _clamp(proposed, config.min_batch_size, config.max_batch_size)
|
|
||
|
|
||
def _next_batch_after_retryable_failure(
    batch_size: int,
    config: AdaptiveBatchConfig,
    good_batch_size: Optional[int],
    bad_batch_size: Optional[int],
) -> int:
    """Pick the next batch size after a failed or slow batch.

    Strategy:
    - If we have a known-good lower bound, binary-search between it and the
      failing size.
    - Otherwise, probe the midpoint between config.min_batch_size and the
      largest allowed size below the failing batch.
    - Must be strictly less than the current size (so we converge downward).
    """
    floor = config.min_batch_size
    if batch_size <= floor:
        return floor

    # The next attempt has to be strictly smaller than what just failed, and
    # also below any previously recorded bad size.
    upper = batch_size - 1
    if bad_batch_size is not None and bad_batch_size - 1 < upper:
        upper = bad_batch_size - 1

    if good_batch_size is not None and good_batch_size < upper:
        # Bisect between the proven-good size and the shrunken ceiling.
        midpoint = (good_batch_size + upper) // 2
    else:
        # Nothing proven good yet — probe the middle of the remaining range.
        midpoint = (floor + upper) // 2

    return _clamp(midpoint, floor, upper)
guysmoilov marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
|
|
||
def _get_retry_delay_seconds(consecutive_retryable_failures: int, config: AdaptiveBatchConfig) -> float:
    """Compute how long to sleep before the next retry attempt, in seconds.

    Delegates to tenacity's ``wait_exponential`` strategy, configured from
    ``config``: ``retry_backoff_base_seconds`` is used as both the multiplier
    and the lower bound, and ``retry_backoff_max_seconds`` caps the delay.

    :param consecutive_retryable_failures: How many retryable failures have
        happened in a row; values below 1 are treated as the first attempt.
    :param config: Source of the backoff multiplier/min/max settings.
    :return: The delay as a float number of seconds.
    """
    # SimpleNamespace duck-types the .attempt_number attribute that tenacity's
    # wait strategies read, avoiding the heavier RetryCallState constructor.
    # NOTE(review): this relies on tenacity only reading .attempt_number —
    # confirm against the pinned tenacity version if it is ever upgraded.
    strategy = wait_exponential(
        multiplier=config.retry_backoff_base_seconds,
        min=config.retry_backoff_base_seconds,
        max=config.retry_backoff_max_seconds,
    )
    retry_state = SimpleNamespace(attempt_number=max(1, consecutive_retryable_failures))
    return float(strategy(retry_state))  # type: ignore[arg-type]
|
||
|
|
||
|
|
||
class AdaptiveBatcher:
    """Sends items in adaptively-sized batches, growing on success and shrinking on failure."""

    def __init__(
        self,
        is_retryable: Callable[[Exception], bool],
        config: Optional[AdaptiveBatchConfig] = None,
        progress_label: str = "Uploading",
    ):
        """
        :param is_retryable: Predicate deciding, per raised exception, whether a
            failed batch may be retried with a smaller size; a falsy result
            aborts the whole run by re-raising.
        :param config: Batch-sizing configuration; when ``None``, built from the
            global dagshub configuration via ``AdaptiveBatchConfig.from_values()``.
        :param progress_label: Prefix used in progress-bar descriptions and logs.
        """
        self._config = config if config is not None else AdaptiveBatchConfig.from_values()
        self._is_retryable = is_retryable
        self._progress_label = progress_label

    def run(self, items: Iterable[T], operation: Callable[[List[T]], None]) -> None:
        """Feed *items* to *operation* in adaptively-sized batches.

        Batch sizes start at ``config.initial_batch_size``, grow after fast
        successful batches, and shrink (binary-search style) after failures or
        batches slower than ``config.target_batch_time_seconds``. Failed-batch
        items are re-queued and retried at a smaller size, with exponential
        backoff between consecutive retryable failures. A rich progress bar is
        shown for the duration of the run.

        :param items: Source of items; if it is ``Sized``, its length seeds the
            progress-bar total (an empty sized input returns immediately).
        :param operation: Called once per batch with a list of items; expected
            to raise on failure.
        :raises Exception: Re-raises the operation's exception when
            ``is_retryable`` rejects it, or when a batch already at the minimum
            size fails again.
        """
        total: Optional[int] = len(items) if isinstance(items, Sized) else None
        if total == 0:
            return

        config = self._config
        current_batch_size = config.initial_batch_size
        it = iter(items)
        # Items from a failed batch wait here and are consumed before new ones.
        pending: List[T] = []

        progress = get_rich_progress(rich.progress.MofNCompleteColumn())
        total_task = progress.add_task(
            f"{self._progress_label} (adaptive batch {config.min_batch_size}-{config.max_batch_size})...",
            total=total,
        )

        # Largest size proven fast, and smallest size seen to fail/run slow —
        # together they bound the binary search done by the _next_batch_* helpers.
        last_good_batch_size: Optional[int] = None
        last_bad_batch_size: Optional[int] = None
        consecutive_retryable_failures = 0
        processed = 0

        with progress:
            while True:
                # Draw from pending (failed-batch leftovers) first, then the source iterator
                batch = pending[:current_batch_size]
                pending = pending[current_batch_size:]
                if len(batch) < current_batch_size:
                    batch.extend(itertools.islice(it, current_batch_size - len(batch)))
                if not batch:
                    # Both the leftovers and the source iterator are exhausted.
                    break
                batch_size = len(batch)

                progress.update(
                    total_task,
                    description=f"{self._progress_label} (batch size {batch_size})...",
                )
                logger.debug(f"{self._progress_label}: {batch_size} entries...")

                start_time = time.monotonic()
                try:
                    operation(batch)
                except Exception as exc:
                    if not self._is_retryable(exc):
                        logger.error(
                            f"{self._progress_label} failed with a non-retryable error; aborting.",
                            exc_info=True,
                        )
                        raise

                    # If this batch was already at (or below) the minimum size and we
                    # did not shrink it since the last attempt, there is nothing
                    # smaller to try — give up.
                    exhausted_shrink = batch_size <= config.min_batch_size and batch_size == current_batch_size
                    if exhausted_shrink:
                        logger.error(
                            f"{self._progress_label} failed at minimum batch size ({batch_size}); aborting.",
                            exc_info=True,
                        )
                        raise

                    consecutive_retryable_failures += 1
                    # Exponential backoff before retrying, scaled by the failure streak.
                    time.sleep(_get_retry_delay_seconds(consecutive_retryable_failures, config))

                    # Record this size as bad; drop a good bound it contradicts.
                    last_bad_batch_size = (
                        batch_size if last_bad_batch_size is None else min(last_bad_batch_size, batch_size)
                    )
                    if last_good_batch_size is not None and last_good_batch_size >= last_bad_batch_size:
                        last_good_batch_size = None
                    if batch_size < config.min_batch_size:
                        # Tail batches below configured min cannot be split further.
                        # Retry that exact size once before treating it as exhausted.
                        current_batch_size = batch_size
                    else:
                        current_batch_size = _next_batch_after_retryable_failure(
                            batch_size, config, last_good_batch_size, last_bad_batch_size
                        )
                    logger.warning(
                        f"{self._progress_label} failed for batch size {batch_size} "
                        f"({exc.__class__.__name__}: {exc}). Retrying with batch size {current_batch_size}."
                    )
                    # Re-queue the failed batch items for retry with smaller batch size
                    pending = batch + pending
                    continue

                # Success path: the batch went through, so count it and adapt.
                elapsed = time.monotonic() - start_time
                consecutive_retryable_failures = 0
                processed += batch_size
                progress.update(total_task, advance=batch_size)

                if elapsed <= config.target_batch_time_seconds:
                    # Fast enough — remember this size as good and try growing.
                    last_good_batch_size = (
                        batch_size if last_good_batch_size is None else max(last_good_batch_size, batch_size)
                    )
                    # Clear stale bad bound if we succeeded fast at or above it
                    if last_bad_batch_size is not None and batch_size >= last_bad_batch_size:
                        last_bad_batch_size = None
                    current_batch_size = _next_batch_after_success(batch_size, config, last_bad_batch_size)
                else:
                    # Succeeded but too slow — treat the size like a failure for
                    # sizing purposes (shrink), without re-queuing the items.
                    last_bad_batch_size = (
                        batch_size if last_bad_batch_size is None else min(last_bad_batch_size, batch_size)
                    )
                    if last_good_batch_size is not None and last_good_batch_size >= last_bad_batch_size:
                        last_good_batch_size = None
                    current_batch_size = _next_batch_after_retryable_failure(
                        batch_size, config, last_good_batch_size, last_bad_batch_size
                    )

            # Snap the bar to a consistent final state (total may have been unknown).
            progress.update(total_task, completed=processed, total=processed, refresh=True)
Uh oh!
There was an error while loading. Please reload this page.