Skip to content
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
84632f2
Add adaptive metadata upload batch sizing
guysmoilov Feb 23, 2026
ded5007
Handle non-retryable metadata upload failures
guysmoilov Mar 3, 2026
567c830
Allow retry shrink on partial metadata batches
guysmoilov Mar 3, 2026
4a767ab
Add retry backoff for metadata upload failures
guysmoilov Mar 3, 2026
0f5e9d7
Honor min batch size on retry failures
guysmoilov Mar 3, 2026
82e4503
Mock retry backoff sleep in upload tests
guysmoilov Mar 3, 2026
996b7fc
Avoid reusing known-bad upload batch size
guysmoilov Mar 3, 2026
2338c94
Extract metadata upload retry backoff constants
guysmoilov Mar 4, 2026
b884e9e
Align metadata upload backoff cap with schedule
guysmoilov Mar 4, 2026
4c1132c
Refactor adaptive metadata upload sizing and retries
guysmoilov Mar 5, 2026
6ae0a58
Fix adaptive upload expected batch sequence test
guysmoilov Mar 5, 2026
1b6356b
Extract AdaptiveBatcher into dagshub.common for reuse
guysmoilov Mar 10, 2026
3a944ae
Support unbounded iterables in AdaptiveBatcher
guysmoilov Mar 10, 2026
6cea20f
Format config.py line length (black)
guysmoilov Mar 10, 2026
a52a0aa
Fix review issues in AdaptiveBatcher
guysmoilov Mar 10, 2026
f6dfec9
Make batch growth factor and retry backoff configurable, add tests
guysmoilov Mar 10, 2026
349c25c
Fix batch size stall at bad_batch_size - 1, clear stale bounds
guysmoilov Mar 10, 2026
3e0ae68
Simplify batch sizing functions with clear strategy comments
guysmoilov Mar 10, 2026
3c1c817
Clear stale good/bad bounds when they become incoherent
guysmoilov Mar 10, 2026
232ebbf
Improve failure fallback convergence, add bad-bound clearing test
guysmoilov Mar 10, 2026
2c44d83
Fix retryable error check for wrapped exceptions in DataEngineGqlError
guysmoilov Mar 10, 2026
0453a33
Align metadata upload tests with adaptive batching behavior
guysmoilov Mar 10, 2026
00294fe
Handle tail-batch retry edge case and docstring mismatch
guysmoilov Mar 10, 2026
5cd02a7
Raise adaptive upload max default and clarify max indirection
guysmoilov Mar 11, 2026
dbf7a68
Use explicit max config name for adaptive upload sizing
guysmoilov Mar 11, 2026
53ca65a
Handle missing TransportConnectionFailed in supported gql versions
guysmoilov Mar 11, 2026
946134f
Review fixes
guysmoilov Mar 18, 2026
841972c
removed LEGACY_ nonsense
guysmoilov Mar 18, 2026
ebde20c
Refine adaptive batching search behavior
guysmoilov Mar 22, 2026
e07343c
Show adaptive batch size progress again
guysmoilov Mar 23, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
259 changes: 259 additions & 0 deletions dagshub/common/adaptive_batching.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,259 @@
import logging
import time
from dataclasses import dataclass
import itertools
from types import SimpleNamespace
from typing import Callable, Iterable, List, Optional, Sized, TypeVar

import rich.progress
from tenacity import wait_exponential

from dagshub.common.rich_util import get_rich_progress

logger = logging.getLogger(__name__)

T = TypeVar("T")

# Floor for the configurable per-batch target time, guarding against
# degenerate zero/negative targets.
MIN_TARGET_BATCH_TIME_SECONDS = 0.01


@dataclass(frozen=True)
class AdaptiveBatchConfig:
    """Validated tuning parameters for AdaptiveBatcher.

    Instances should be built via ``from_values``, which fills unset values
    from ``dagshub.common.config`` and clamps all fields into a coherent range.
    """

    max_batch_size: int
    min_batch_size: int
    initial_batch_size: int
    target_batch_time_seconds: float
    batch_growth_factor: int
    retry_backoff_base_seconds: float
    retry_backoff_max_seconds: float

    @classmethod
    def from_values(
        cls,
        max_batch_size: Optional[int] = None,
        min_batch_size: Optional[int] = None,
        initial_batch_size: Optional[int] = None,
        target_batch_time_seconds: Optional[float] = None,
        batch_growth_factor: Optional[int] = None,
        retry_backoff_base_seconds: Optional[float] = None,
        retry_backoff_max_seconds: Optional[float] = None,
    ) -> "AdaptiveBatchConfig":
        """Build a config, sourcing any unset argument from the runtime config.

        Normalization guarantees 1 <= min <= initial <= max, a growth factor of
        at least 2, a target time of at least MIN_TARGET_BATCH_TIME_SECONDS,
        and non-negative backoff values.
        """
        if (
            max_batch_size is None
            or min_batch_size is None
            or initial_batch_size is None
            or target_batch_time_seconds is None
            or batch_growth_factor is None
            or retry_backoff_base_seconds is None
            or retry_backoff_max_seconds is None
        ):
            # Imported lazily, only when defaults are actually needed.
            import dagshub.common.config as dgs_config

            if max_batch_size is None:
                max_batch_size = dgs_config.dataengine_metadata_upload_batch_size
            if min_batch_size is None:
                min_batch_size = dgs_config.dataengine_metadata_upload_batch_size_min
            if initial_batch_size is None:
                initial_batch_size = dgs_config.dataengine_metadata_upload_batch_size_initial
            if target_batch_time_seconds is None:
                target_batch_time_seconds = dgs_config.dataengine_metadata_upload_target_batch_time_seconds
            if batch_growth_factor is None:
                batch_growth_factor = dgs_config.adaptive_batch_growth_factor
            if retry_backoff_base_seconds is None:
                retry_backoff_base_seconds = dgs_config.adaptive_batch_retry_backoff_base_seconds
            if retry_backoff_max_seconds is None:
                retry_backoff_max_seconds = dgs_config.adaptive_batch_retry_backoff_max_seconds

        # Clamp everything into a coherent range: 1 <= hard_min <= start_size <= hard_max.
        hard_max = max(1, max_batch_size)
        hard_min = max(1, min(min_batch_size, hard_max))
        start_size = max(hard_min, min(initial_batch_size, hard_max))
        return cls(
            max_batch_size=hard_max,
            min_batch_size=hard_min,
            initial_batch_size=start_size,
            target_batch_time_seconds=max(target_batch_time_seconds, MIN_TARGET_BATCH_TIME_SECONDS),
            batch_growth_factor=max(2, batch_growth_factor),
            retry_backoff_base_seconds=max(0.0, retry_backoff_base_seconds),
            retry_backoff_max_seconds=max(0.0, retry_backoff_max_seconds),
        )
Comment on lines +43 to +84
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could all be redone with a default_factory on each dataclass field:
https://docs.python.org/3/library/dataclasses.html#default-factory-functions

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

default_factory would not cover what from_values() is doing here: reading runtime config plus normalizing and clamping related fields.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How would it not cover it? It is literally just a function that runs at init-time.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because default_factory, as far as I understand, can't initialize fields based on a function of other fields - it's only a function to initialize each field individually. For interactions between fields that determine field values you need... a constructor

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The only value that actually is dependent on another value in that whole constructor is min_batch_size, everything else is derived from configuration values.

If you really want to clamp min_batch_size, you can do it in __post_init__(): https://docs.python.org/3/library/dataclasses.html#dataclasses.__post_init__



def _clamp(value: int, lo: int, hi: int) -> int:
return max(lo, min(hi, value))


def _next_batch_after_success(
batch_size: int,
config: AdaptiveBatchConfig,
bad_batch_size: Optional[int],
) -> int:
"""Pick the next batch size after a successful (fast) batch.

Strategy:
- If we have a known-bad upper bound, binary-search toward it.
- Otherwise, multiply by the growth factor.
- Always guarantee at least +1 progress (so we never stall).
"""
if bad_batch_size is not None and batch_size < bad_batch_size:
# Binary search: try the midpoint between current and bad
candidate = (batch_size + bad_batch_size) // 2
else:
# No upper bound (or we've already passed it): grow aggressively
candidate = batch_size * config.batch_growth_factor

# Must advance by at least 1 to avoid stalling
candidate = max(candidate, batch_size + 1)

return _clamp(candidate, config.min_batch_size, config.max_batch_size)


def _next_batch_after_retryable_failure(
batch_size: int,
config: AdaptiveBatchConfig,
good_batch_size: Optional[int],
bad_batch_size: Optional[int],
) -> int:
"""Pick the next batch size after a failed or slow batch.

Strategy:
- If we have a known-good lower bound, binary-search between it and the
failing size.
- Otherwise, halve.
- Must be strictly less than the current size (so we converge downward).
"""
if batch_size <= config.min_batch_size:
return config.min_batch_size

ceiling = batch_size - 1 # must shrink
if bad_batch_size is not None:
ceiling = min(ceiling, bad_batch_size - 1)

if good_batch_size is not None and good_batch_size < ceiling:
# Binary search: try the midpoint between good and failing
candidate = (good_batch_size + ceiling) // 2
else:
# No good lower bound — probe midpoint of the valid range
candidate = (config.min_batch_size + ceiling) // 2

return _clamp(candidate, config.min_batch_size, ceiling)


def _get_retry_delay_seconds(consecutive_retryable_failures: int, config: "AdaptiveBatchConfig") -> float:
    """Return how long to sleep before the next retry, per tenacity's exponential schedule.

    The wait grows exponentially with the number of consecutive retryable
    failures, starting from ``retry_backoff_base_seconds`` and capped at
    ``retry_backoff_max_seconds``.
    """
    schedule = wait_exponential(
        multiplier=config.retry_backoff_base_seconds,
        min=config.retry_backoff_base_seconds,
        max=config.retry_backoff_max_seconds,
    )
    # tenacity wait strategies only read .attempt_number off the retry state,
    # so a SimpleNamespace duck-types it without the heavier RetryCallState.
    attempt = consecutive_retryable_failures if consecutive_retryable_failures > 1 else 1
    return float(schedule(SimpleNamespace(attempt_number=attempt)))  # type: ignore[arg-type]


class AdaptiveBatcher:
    """Sends items in adaptively-sized batches, growing on success and shrinking on failure.

    A batch that completes within ``config.target_batch_time_seconds`` counts as
    "good" and lets the next batch grow; a slow batch or a retryable exception
    counts as "bad" and shrinks the next one. The best good size and worst bad
    size seen so far are kept as bounds for a binary search (see
    ``_next_batch_after_success`` / ``_next_batch_after_retryable_failure``).
    """

    def __init__(
        self,
        is_retryable: Callable[[Exception], bool],
        config: Optional[AdaptiveBatchConfig] = None,
        progress_label: str = "Uploading",
    ):
        # When no config is given, defaults are read from dagshub.common.config
        # at construction time via AdaptiveBatchConfig.from_values().
        self._config = config if config is not None else AdaptiveBatchConfig.from_values()
        self._is_retryable = is_retryable
        self._progress_label = progress_label

    def run(self, items: Iterable[T], operation: Callable[[List[T]], None]) -> None:
        """Feed ``items`` to ``operation`` in adaptively-sized batches.

        ``items`` may be an unsized iterable; then the progress bar has no known
        total. On a retryable failure the failed batch is re-queued and retried
        with a smaller size after an exponential backoff. The exception from
        ``operation`` is re-raised when it is non-retryable, or when a retryable
        failure happens at the minimum batch size (nothing smaller to try).
        """
        total: Optional[int] = len(items) if isinstance(items, Sized) else None
        if total == 0:
            return

        config = self._config
        current_batch_size = config.initial_batch_size
        it = iter(items)
        # Items from a failed batch waiting to be retried, ahead of the iterator.
        pending: List[T] = []

        progress = get_rich_progress(rich.progress.MofNCompleteColumn())
        total_task = progress.add_task(
            f"{self._progress_label} (adaptive batch {config.min_batch_size}-{config.max_batch_size})...",
            total=total,
        )

        # Binary-search bounds: largest size known fast, smallest known slow/failing.
        last_good_batch_size: Optional[int] = None
        last_bad_batch_size: Optional[int] = None
        consecutive_retryable_failures = 0
        processed = 0

        with progress:
            while True:
                # Draw from pending (failed-batch leftovers) first, then the source iterator
                batch = pending[:current_batch_size]
                pending = pending[current_batch_size:]
                if len(batch) < current_batch_size:
                    batch.extend(itertools.islice(it, current_batch_size - len(batch)))
                if not batch:
                    break
                # May be smaller than current_batch_size on the tail of the input.
                batch_size = len(batch)

                progress.update(
                    total_task,
                    description=f"{self._progress_label} (batch size {batch_size})...",
                )
                logger.debug(f"{self._progress_label}: {batch_size} entries...")

                start_time = time.monotonic()
                try:
                    operation(batch)
                except Exception as exc:
                    # Non-retryable errors abort immediately; batch items are not re-queued.
                    if not self._is_retryable(exc):
                        logger.error(
                            f"{self._progress_label} failed with a non-retryable error; aborting.",
                            exc_info=True,
                        )
                        raise

                    # Already at the floor — shrinking further is impossible, so give up.
                    if batch_size <= config.min_batch_size:
                        logger.error(
                            f"{self._progress_label} failed at minimum batch size ({batch_size}); aborting.",
                            exc_info=True,
                        )
                        raise

                    # Exponential backoff before retrying, scaled by the failure streak.
                    consecutive_retryable_failures += 1
                    time.sleep(_get_retry_delay_seconds(consecutive_retryable_failures, config))

                    # Record this size as bad (keep the smallest bad size seen).
                    last_bad_batch_size = (
                        batch_size if last_bad_batch_size is None else min(last_bad_batch_size, batch_size)
                    )
                    # Drop a good bound that is now incoherent (>= the bad bound).
                    if last_good_batch_size is not None and last_good_batch_size >= last_bad_batch_size:
                        last_good_batch_size = None
                    current_batch_size = _next_batch_after_retryable_failure(
                        batch_size, config, last_good_batch_size, last_bad_batch_size
                    )
                    logger.warning(
                        f"{self._progress_label} failed for batch size {batch_size} "
                        f"({exc.__class__.__name__}: {exc}). Retrying with batch size {current_batch_size}."
                    )
                    # Re-queue the failed batch items for retry with smaller batch size
                    pending = batch + pending
                    continue

                elapsed = time.monotonic() - start_time
                consecutive_retryable_failures = 0
                processed += batch_size
                progress.update(total_task, advance=batch_size)

                if elapsed <= config.target_batch_time_seconds:
                    # Fast batch: record as good (keep the largest good size seen) and grow.
                    last_good_batch_size = (
                        batch_size if last_good_batch_size is None else max(last_good_batch_size, batch_size)
                    )
                    # Clear stale bad bound if we succeeded fast at or above it
                    if last_bad_batch_size is not None and batch_size >= last_bad_batch_size:
                        last_bad_batch_size = None
                    current_batch_size = _next_batch_after_success(batch_size, config, last_bad_batch_size)
                else:
                    # Slow batch: treated like a retryable failure for sizing purposes,
                    # but the work already succeeded so nothing is re-queued.
                    last_bad_batch_size = (
                        batch_size if last_bad_batch_size is None else min(last_bad_batch_size, batch_size)
                    )
                    if last_good_batch_size is not None and last_good_batch_size >= last_bad_batch_size:
                        last_good_batch_size = None
                    current_batch_size = _next_batch_after_retryable_failure(
                        batch_size, config, last_good_batch_size, last_bad_batch_size
                    )

        # Pin the bar to "done" using the actual processed count (total may have been unknown).
        progress.update(total_task, completed=processed, total=processed, refresh=True)
30 changes: 29 additions & 1 deletion dagshub/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,35 @@ def set_host(new_host: str):
recommended_annotate_limit = int(os.environ.get(RECOMMENDED_ANNOTATE_LIMIT_KEY, 1e5))

DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_KEY = "DAGSHUB_DE_METADATA_UPLOAD_BATCH_SIZE"
dataengine_metadata_upload_batch_size = int(os.environ.get(DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_KEY, 15000))
DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_MAX_KEY = "DAGSHUB_DE_METADATA_UPLOAD_BATCH_SIZE_MAX"
dataengine_metadata_upload_batch_size = int(
os.environ.get(
DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_MAX_KEY,
os.environ.get(DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_KEY, 15000),
)
)

DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_MIN_KEY = "DAGSHUB_DE_METADATA_UPLOAD_BATCH_SIZE_MIN"
dataengine_metadata_upload_batch_size_min = int(os.environ.get(DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_MIN_KEY, 1))

DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_INITIAL_KEY = "DAGSHUB_DE_METADATA_UPLOAD_BATCH_SIZE_INITIAL"
dataengine_metadata_upload_batch_size_initial = int(
os.environ.get(DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_INITIAL_KEY, dataengine_metadata_upload_batch_size_min)
)

DATAENGINE_METADATA_UPLOAD_TARGET_BATCH_TIME_SECONDS_KEY = "DAGSHUB_DE_METADATA_UPLOAD_TARGET_BATCH_TIME_SECONDS"
dataengine_metadata_upload_target_batch_time_seconds = float(
os.environ.get(DATAENGINE_METADATA_UPLOAD_TARGET_BATCH_TIME_SECONDS_KEY, 5.0)
)

ADAPTIVE_BATCH_GROWTH_FACTOR_KEY = "DAGSHUB_ADAPTIVE_BATCH_GROWTH_FACTOR"
adaptive_batch_growth_factor = int(os.environ.get(ADAPTIVE_BATCH_GROWTH_FACTOR_KEY, 10))

ADAPTIVE_BATCH_RETRY_BACKOFF_BASE_KEY = "DAGSHUB_ADAPTIVE_BATCH_RETRY_BACKOFF_BASE"
adaptive_batch_retry_backoff_base_seconds = float(os.environ.get(ADAPTIVE_BATCH_RETRY_BACKOFF_BASE_KEY, 0.25))

ADAPTIVE_BATCH_RETRY_BACKOFF_MAX_KEY = "DAGSHUB_ADAPTIVE_BATCH_RETRY_BACKOFF_MAX"
adaptive_batch_retry_backoff_max_seconds = float(os.environ.get(ADAPTIVE_BATCH_RETRY_BACKOFF_MAX_KEY, 4.0))

DISABLE_ANALYTICS_KEY = "DAGSHUB_DISABLE_ANALYTICS"
disable_analytics = "DAGSHUB_DISABLE_ANALYTICS" in os.environ
Expand Down
21 changes: 7 additions & 14 deletions dagshub/data_engine/model/datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, List, Literal, Optional, Set, Tuple, Union

import rich.progress
from dataclasses_json import DataClassJsonMixin, LetterCase, config
from pathvalidate import sanitize_filepath

Expand Down Expand Up @@ -53,6 +52,8 @@
run_preupload_transforms,
validate_uploading_metadata,
)
from dagshub.common.adaptive_batching import AdaptiveBatcher
from dagshub.data_engine.model.metadata.util import is_retryable_metadata_upload_error
from dagshub.data_engine.model.metadata.dtypes import DatapointMetadataUpdateEntry
from dagshub.data_engine.model.metadata.transforms import DatasourceFieldInfo, _add_metadata
from dagshub.data_engine.model.metadata_field_builder import MetadataFieldBuilder
Expand Down Expand Up @@ -753,19 +754,11 @@ def _upload_metadata(self, metadata_entries: List[DatapointMetadataUpdateEntry])
validate_uploading_metadata(precalculated_info)
run_preupload_transforms(self, metadata_entries, precalculated_info)

progress = get_rich_progress(rich.progress.MofNCompleteColumn())

upload_batch_size = dagshub.common.config.dataengine_metadata_upload_batch_size
total_entries = len(metadata_entries)
total_task = progress.add_task(f"Uploading metadata (batch size {upload_batch_size})...", total=total_entries)

with progress:
for start in range(0, total_entries, upload_batch_size):
entries = metadata_entries[start : start + upload_batch_size]
logger.debug(f"Uploading {len(entries)} metadata entries...")
self.source.client.update_metadata(self, entries)
progress.update(total_task, advance=upload_batch_size)
progress.update(total_task, completed=total_entries, refresh=True)
batcher = AdaptiveBatcher(
is_retryable=is_retryable_metadata_upload_error,
progress_label="Uploading metadata",
)
batcher.run(metadata_entries, lambda batch: self.source.client.update_metadata(self, batch))

# Update the status from dagshub, so we get back the new metadata columns
self.source.get_from_dagshub()
Expand Down
21 changes: 21 additions & 0 deletions dagshub/data_engine/model/metadata/util.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import datetime
from gql.transport.exceptions import TransportServerError, TransportConnectionFailed
from requests import ConnectionError as RequestsConnectionError, Timeout as RequestsTimeout
from typing import Optional

from dagshub.data_engine.model.errors import DataEngineGqlError


def _get_datetime_utc_offset(t: datetime.datetime) -> Optional[str]:
"""
Expand All @@ -19,3 +23,20 @@ def _get_datetime_utc_offset(t: datetime.datetime) -> Optional[str]:
offset_minutes = int((offset.total_seconds() % 3600) // 60)
offset_str = f"{offset_hours:+03d}:{offset_minutes:02d}"
return offset_str


def is_retryable_metadata_upload_error(exc: Exception) -> bool:
    """Decide whether a metadata-upload failure is transient and worth retrying.

    DataEngineGqlError wrappers are unwrapped first, so the classification
    applies to the underlying transport/network error.
    """
    # Iteratively unwrap nested DataEngineGqlError wrappers.
    while isinstance(exc, DataEngineGqlError) and isinstance(exc.original_exception, Exception):
        exc = exc.original_exception

    retryable_types = (
        TransportServerError,
        TransportConnectionFailed,
        TimeoutError,
        ConnectionError,
        RequestsConnectionError,
        RequestsTimeout,
    )
    return isinstance(exc, retryable_types)
Loading