|
| 1 | +import itertools |
| 2 | +import logging |
| 3 | +import math |
| 4 | +import time |
| 5 | +from dataclasses import dataclass |
| 6 | +from typing import Callable, Iterable, List, Optional, Sized, Tuple, TypeVar |
| 7 | + |
| 8 | +import rich.progress |
| 9 | + |
| 10 | +import dagshub.common.config as dgs_config |
| 11 | +from dagshub.common.rich_util import get_rich_progress |
| 12 | + |
logger = logging.getLogger(__name__)

# Generic item type carried through AdaptiveBatcher.run().
T = TypeVar("T")

# Floor for the per-batch time target; smaller targets would make every batch "slow".
MIN_TARGET_BATCH_TIME_SECONDS = 0.01
# A fast batch closer to the soft upper limit than this fraction of it is "near" the limit.
SOFT_UPPER_LIMIT_MIN_STEP_FRACTION = 0.05
# Number of consecutive fast batches near the limit before re-probing the limit itself.
SOFT_UPPER_LIMIT_RETRY_AFTER_SUCCESSES = 3
| 20 | + |
| 21 | +# Overall strategy: |
| 22 | +# - Grow aggressively on fast successes until we hit a slow or failing batch. |
| 23 | +# - A slow or failing batch becomes last_bad_batch_size, and a fast batch becomes |
| 24 | +# last_fast_batch_size. |
| 25 | +# - last_bad_batch_size acts as a soft_upper_limit while the gap to |
| 26 | +# last_fast_batch_size is still meaningful, and we probe within that range. |
| 27 | +# - Once that gap is below the search resolution, hold the current fast batch size |
| 28 | +# instead of micro-searching. |
| 29 | +# - Several consecutive fast batches near last_bad_batch_size trigger one more |
| 30 | +# probe at soft_upper_limit, since the earlier failure may have been transient. |
| 31 | + |
| 32 | + |
@dataclass
class AdaptiveBatchConfig:
    """Tuning knobs for AdaptiveBatcher: batch-size bounds, the per-batch time
    target, the growth factor used on fast successes, and retry backoff limits.
    """

    max_batch_size: int
    min_batch_size: int
    initial_batch_size: int
    target_batch_time_seconds: float
    batch_growth_factor: int
    retry_backoff_base_seconds: float
    retry_backoff_max_seconds: float

    @classmethod
    def from_values(
        cls,
        max_batch_size: Optional[int] = None,
        min_batch_size: Optional[int] = None,
        initial_batch_size: Optional[int] = None,
        target_batch_time_seconds: Optional[float] = None,
        batch_growth_factor: Optional[int] = None,
        retry_backoff_base_seconds: Optional[float] = None,
        retry_backoff_max_seconds: Optional[float] = None,
    ) -> "AdaptiveBatchConfig":
        """Build a config, filling any unspecified value from dagshub's global
        config and normalizing everything into a mutually consistent range
        (1 <= min <= initial <= max; growth factor >= 2; non-negative backoff).
        """

        def pick(value, compute_default):
            # Defaults are computed lazily so dgs_config is only consulted
            # for values the caller left unset.
            return compute_default() if value is None else value

        max_batch_size = pick(max_batch_size, lambda: dgs_config.dataengine_metadata_upload_batch_size_max)
        min_batch_size = pick(min_batch_size, lambda: dgs_config.dataengine_metadata_upload_batch_size_min)
        initial_batch_size = pick(initial_batch_size, lambda: dgs_config.dataengine_metadata_upload_batch_size_initial)
        target_batch_time_seconds = pick(
            target_batch_time_seconds, lambda: dgs_config.dataengine_metadata_upload_target_batch_time_seconds
        )
        batch_growth_factor = pick(batch_growth_factor, lambda: dgs_config.adaptive_batch_growth_factor)
        retry_backoff_base_seconds = pick(
            retry_backoff_base_seconds, lambda: dgs_config.adaptive_batch_retry_backoff_base_seconds
        )
        retry_backoff_max_seconds = pick(
            retry_backoff_max_seconds, lambda: dgs_config.adaptive_batch_retry_backoff_max_seconds
        )

        # Normalize so the three sizes always form a valid ordered range.
        hi = max(1, max_batch_size)
        lo = max(1, min(min_batch_size, hi))
        start = max(lo, min(initial_batch_size, hi))
        return cls(
            max_batch_size=hi,
            min_batch_size=lo,
            initial_batch_size=start,
            target_batch_time_seconds=max(target_batch_time_seconds, MIN_TARGET_BATCH_TIME_SECONDS),
            batch_growth_factor=max(2, batch_growth_factor),
            retry_backoff_base_seconds=max(0.0, retry_backoff_base_seconds),
            retry_backoff_max_seconds=max(0.0, retry_backoff_max_seconds),
        )
| 85 | + |
| 86 | + |
| 87 | +def _clamp(value: int, lo: int, hi: int) -> int: |
| 88 | + return max(lo, min(hi, value)) |
| 89 | + |
| 90 | + |
def _next_batch_after_success(
    batch_size: int,
    config: AdaptiveBatchConfig,
    soft_upper_limit: Optional[int],
) -> int:
    """Pick the next batch size after a fast successful batch.

    With a soft upper hint (a previously slow/failing size) above us, probe
    the midpoint toward it (binary search). Without one, grow by the
    configured factor. The result is always clamped to the configured bounds.
    """
    below_hint = soft_upper_limit is not None and batch_size < soft_upper_limit
    if below_hint:
        # Binary search toward the known-problematic size.
        candidate = (batch_size + soft_upper_limit) // 2
    else:
        # No hint (or already at/past it): grow aggressively.
        candidate = config.batch_growth_factor * batch_size
    return _clamp(candidate, config.min_batch_size, config.max_batch_size)
| 110 | + |
| 111 | + |
def _next_batch_after_retryable_failure(
    batch_size: int,
    config: AdaptiveBatchConfig,
    last_fast_batch_size: Optional[int],
    soft_upper_limit: Optional[int],
) -> int:
    """Pick the next batch size after a failed or slow batch.

    Binary-searches between the best-known-good size (or min_batch_size when
    none is known) and a ceiling strictly below both the failing size and any
    soft upper limit, so the search converges downward. At or below
    min_batch_size there is nothing left to shrink.
    """
    if batch_size <= config.min_batch_size:
        return config.min_batch_size

    # Must be strictly smaller than the failing size...
    ceiling = batch_size - 1
    # ...and strictly below any known-bad size.
    if soft_upper_limit is not None:
        bad_ceiling = soft_upper_limit - 1
        if bad_ceiling < ceiling:
            ceiling = bad_ceiling

    have_good_floor = last_fast_batch_size is not None and last_fast_batch_size < ceiling
    # Midpoint of [known-good, ceiling] when we have a floor, otherwise of the
    # whole valid range [min, ceiling].
    floor = last_fast_batch_size if have_good_floor else config.min_batch_size
    candidate = (floor + ceiling) // 2
    return _clamp(candidate, config.min_batch_size, ceiling)
| 142 | + |
| 143 | + |
def _get_retry_delay_seconds(consecutive_retryable_failures: int, config: AdaptiveBatchConfig) -> float:
    """Exponential backoff delay: base * 2^(attempt - 1), capped at the max.

    Returns 0.0 when backoff is disabled (non-positive base or cap).
    """
    base = config.retry_backoff_base_seconds
    cap = config.retry_backoff_max_seconds
    if base <= 0.0 or cap <= 0.0:
        return 0.0

    # Treat anything below one failure as the first attempt.
    exponent = max(1, consecutive_retryable_failures) - 1
    return min(base * (2 ** exponent), cap)
| 151 | + |
| 152 | + |
def _min_step_size(soft_upper_limit: int) -> int:
    """Smallest search step that is still meaningful near *soft_upper_limit*:
    a fixed fraction of the limit, never less than one item."""
    fractional_step = math.ceil(soft_upper_limit * SOFT_UPPER_LIMIT_MIN_STEP_FRACTION)
    return fractional_step if fractional_step > 1 else 1
| 155 | + |
| 156 | + |
def _is_next_step_above_limit(batch_size: int, soft_upper_limit: Optional[int]) -> bool:
    """Return True when *batch_size* sits just under *soft_upper_limit* —
    closer than the minimum useful step — so probing upward is pointless.

    False when there is no limit or the size already meets/exceeds it.
    """
    if soft_upper_limit is None:
        return False

    gap = soft_upper_limit - batch_size
    return 0 < gap <= _min_step_size(soft_upper_limit)
| 162 | + |
| 163 | + |
| 164 | +def _update_bounds_after_bad_batch( |
| 165 | + batch_size: int, |
| 166 | + last_fast_batch_size: Optional[int], |
| 167 | + last_bad_batch_size: Optional[int], |
| 168 | +) -> Tuple[Optional[int], int]: |
| 169 | + updated_last_bad_batch_size = batch_size if last_bad_batch_size is None else min(last_bad_batch_size, batch_size) |
| 170 | + if last_fast_batch_size is not None and last_fast_batch_size >= updated_last_bad_batch_size: |
| 171 | + last_fast_batch_size = None |
| 172 | + return last_fast_batch_size, updated_last_bad_batch_size |
| 173 | + |
| 174 | + |
class AdaptiveBatcher:
    """Sends items in adaptively-sized batches, growing on success and shrinking on failure."""

    def __init__(
        self,
        is_retryable: Callable[[Exception], bool],
        config: Optional[AdaptiveBatchConfig] = None,
        progress_label: str = "Uploading",
    ):
        # is_retryable decides, per raised exception, whether to retry with a
        # smaller batch (True) or abort the whole run by re-raising (False).
        self._config = config if config is not None else AdaptiveBatchConfig.from_values()
        self._is_retryable = is_retryable
        self._progress_label = progress_label

    def run(self, items: Iterable[T], operation: Callable[[List[T]], None]) -> None:
        """Feed *items* to *operation* in adaptively-sized batches.

        Batch size grows after fast successes and shrinks (binary search)
        after slow batches or retryable failures; failed batches are re-queued
        and retried at the smaller size. Raises when *operation* fails with a
        non-retryable error, or retryably at the minimum batch size.
        A rich progress bar tracks the number of items processed.
        """
        total: Optional[int] = len(items) if isinstance(items, Sized) else None
        if total == 0:
            return

        config = self._config
        desired_batch_size = config.initial_batch_size
        # Consume the source iterable incrementally across retries and successes.
        it = iter(items)
        # Items from failed batches awaiting retry; always drained before `it`.
        pending: List[T] = []

        progress = get_rich_progress(rich.progress.MofNCompleteColumn())
        # total may be None for unsized iterables; the bar is finalized at the end.
        total_task = progress.add_task(f"{self._progress_label}...", total=total)

        # Search state: largest size known to succeed quickly, and smallest
        # size known to be slow/failing (acts as a soft upper limit).
        last_fast_batch_size: Optional[int] = None
        last_bad_batch_size: Optional[int] = None
        consecutive_retryable_failures = 0
        consecutive_fast_successes_near_upper_limit = 0
        processed = 0

        with progress:
            while True:
                # Draw from pending (failed-batch leftovers) first, then the source iterator
                batch = pending[:desired_batch_size]
                pending = pending[desired_batch_size:]
                if len(batch) < desired_batch_size:
                    batch.extend(itertools.islice(it, desired_batch_size - len(batch)))
                if not batch:
                    # Both the retry queue and the source are exhausted.
                    break
                actual_batch_size = len(batch)

                progress.update(total_task, description=f"{self._progress_label} (batch size: {actual_batch_size})...")
                logger.debug(f"{self._progress_label}: {actual_batch_size} entries...")

                start_time = time.monotonic()
                try:
                    operation(batch)
                except Exception as exc:
                    if not self._is_retryable(exc):
                        logger.error(
                            f"{self._progress_label} failed with a non-retryable error; aborting.",
                            exc_info=True,
                        )
                        raise

                    # A tail batch is "short" when the source simply ran out of
                    # items, not because we chose a small size.
                    is_short_tail_batch = (
                        actual_batch_size <= config.min_batch_size and actual_batch_size < desired_batch_size
                    )
                    if not is_short_tail_batch and actual_batch_size <= config.min_batch_size:
                        # Already at the floor and still failing retryably: give up.
                        logger.error(
                            f"{self._progress_label} failed at minimum batch size ({actual_batch_size}); aborting.",
                            exc_info=True,
                        )
                        raise

                    consecutive_fast_successes_near_upper_limit = 0

                    # Exponential backoff
                    consecutive_retryable_failures += 1
                    time.sleep(_get_retry_delay_seconds(consecutive_retryable_failures, config))

                    last_fast_batch_size, last_bad_batch_size = _update_bounds_after_bad_batch(
                        actual_batch_size, last_fast_batch_size, last_bad_batch_size
                    )
                    if is_short_tail_batch:
                        # A naturally short tail batch cannot be shrunk further in a useful way.
                        # Retry that exact size once before treating it as exhausted.
                        desired_batch_size = actual_batch_size
                    else:
                        # Binary search downwards
                        desired_batch_size = _next_batch_after_retryable_failure(
                            actual_batch_size, config, last_fast_batch_size, last_bad_batch_size
                        )
                    logger.warning(
                        f"{self._progress_label} failed for batch size {actual_batch_size} "
                        f"({exc.__class__.__name__}: {exc}). Retrying with batch size {desired_batch_size}."
                    )
                    # Re-queue the failed batch items for retry with smaller batch size
                    pending = batch + pending
                    continue

                # On success.
                elapsed = time.monotonic() - start_time
                consecutive_retryable_failures = 0
                processed += actual_batch_size
                progress.update(total_task, advance=actual_batch_size)

                if elapsed <= config.target_batch_time_seconds:
                    # Fast batch: this size (or larger) is a new known-good size.
                    if last_fast_batch_size is None or actual_batch_size > last_fast_batch_size:
                        last_fast_batch_size = actual_batch_size
                        if last_bad_batch_size is not None and actual_batch_size >= last_bad_batch_size:
                            # A fast success at the upper limit means the last_bad_batch_size is stale.
                            # We can resume unconstrained growth.
                            last_bad_batch_size = None
                        consecutive_fast_successes_near_upper_limit = 0
                        desired_batch_size = _next_batch_after_success(
                            actual_batch_size, config, last_bad_batch_size
                        )
                    elif _is_next_step_above_limit(actual_batch_size, last_bad_batch_size):
                        # Once the gap is smaller than our useful search resolution,
                        # hold the current known-good size and only re-probe the hint
                        # after a few stable fast successes.
                        consecutive_fast_successes_near_upper_limit += 1
                        if consecutive_fast_successes_near_upper_limit >= SOFT_UPPER_LIMIT_RETRY_AFTER_SUCCESSES:
                            # We've had enough stable fast successes to re-probe the last_bad_batch_size.
                            desired_batch_size = last_bad_batch_size
                            consecutive_fast_successes_near_upper_limit = 0
                        else:
                            # Hold current size for one more iteration
                            desired_batch_size = actual_batch_size
                    else:
                        # Binary search or unconstrained growth upwards
                        consecutive_fast_successes_near_upper_limit = 0
                        desired_batch_size = _next_batch_after_success(
                            actual_batch_size, config, last_bad_batch_size
                        )
                else:
                    # Binary search downwards due to a slow batch
                    consecutive_fast_successes_near_upper_limit = 0
                    logger.debug(
                        f"{self._progress_label} batch size {actual_batch_size} took {elapsed:.2f}s "
                        f"(target {config.target_batch_time_seconds:.2f}s); shrinking."
                    )
                    last_fast_batch_size, last_bad_batch_size = _update_bounds_after_bad_batch(
                        actual_batch_size, last_fast_batch_size, last_bad_batch_size
                    )
                    desired_batch_size = _next_batch_after_retryable_failure(
                        actual_batch_size, config, last_fast_batch_size, last_bad_batch_size
                    )

            # Finalize the bar: for unsized inputs this also sets the real total.
            progress.update(total_task, completed=processed, total=processed, refresh=True)
0 commit comments