Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions src/databricks/sql/cloudfetch/downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,14 @@ class DownloadableResultSettings:
link_expiry_buffer_secs (int): Time in seconds to prevent download of a link before it expires. Default 0 secs.
download_timeout (int): Timeout for download requests. Default 60 secs.
max_consecutive_file_download_retries (int): Number of consecutive download retries before shutting down.
speed_warning_threshold_mbps (float): Threshold in MB/s below which to log warning. Default 0.1 MB/s.
"""

is_lz4_compressed: bool
link_expiry_buffer_secs: int = 0
download_timeout: int = 60
max_consecutive_file_download_retries: int = 0
speed_warning_threshold_mbps: float = 0.1


class ResultSetDownloadHandler:
Expand Down Expand Up @@ -90,6 +92,8 @@ def run(self) -> DownloadedFile:
self.link, self.settings.link_expiry_buffer_secs
)

start_time = time.time()

with self._http_client.execute(
method=HttpMethod.GET,
url=self.link.fileLink,
Expand All @@ -102,6 +106,13 @@ def run(self) -> DownloadedFile:

# Save (and decompress if needed) the downloaded file
compressed_data = response.content

# Log download metrics
download_duration = time.time() - start_time
self._log_download_metrics(
self.link.fileLink, len(compressed_data), download_duration
)

decompressed_data = (
ResultSetDownloadHandler._decompress_data(compressed_data)
if self.settings.is_lz4_compressed
Expand All @@ -128,6 +139,30 @@ def run(self) -> DownloadedFile:
self.link.rowCount,
)

def _log_download_metrics(
self, url: str, bytes_downloaded: int, duration_seconds: float
):
"""Log download speed metrics at INFO/WARN levels."""
if duration_seconds <= 0:
return

# Calculate speed in MB/s (ensure float division for precision)
speed_mbps = (float(bytes_downloaded) / (1024 * 1024)) / duration_seconds

urlEndpoint = url.split("?")[0]
# INFO level logging
logger.info(
f"CloudFetch download completed: {speed_mbps:.4f} MB/s, "
f"{bytes_downloaded} bytes in {duration_seconds:.3f}s from {urlEndpoint}"
)

# WARN level logging if below threshold
if speed_mbps < self.settings.speed_warning_threshold_mbps:
logger.warning(
f"CloudFetch download slower than threshold: {speed_mbps:.4f} MB/s "
f"(threshold: {self.settings.speed_warning_threshold_mbps:.1f} MB/s) from {url}"
)

@staticmethod
def _validate_link(link: TSparkArrowResultLink, expiry_buffer_secs: int):
"""
Expand Down
Loading