diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md index 9c0031211..9f5b3428e 100644 --- a/NEXT_CHANGELOG.md +++ b/NEXT_CHANGELOG.md @@ -4,6 +4,11 @@ ### New Features and Improvements +* Add a new interface `upload_from` to `databricks.sdk.mixins.FilesExt` to support uploading from a file on the local filesystem. +* Improve `databricks.sdk.mixins.FilesExt` upload throughput by uploading data in parallel by default. +* Add a new interface `download_to` to `databricks.sdk.mixins.FilesExt` to support downloading to a file on the local filesystem. This interface can also download the file in parallel. Parallel downloading is currently unavailable on Windows. +* Improve `databricks.sdk.mixins.FilesExt.upload` to support uploading when presigned URLs are not enabled for the workspace by falling back to single-part upload. + ### Bug Fixes ### Documentation @@ -11,3 +16,14 @@ ### Internal Changes ### API Changes + +* Add `upload_from()` and `download_to()` methods to `databricks.sdk.mixins.FilesExt`. +* Add `use_parallel`, `parallelism`, and `part_size` parameters to `databricks.sdk.mixins.FilesExt.upload`. +* [Breaking] Rename `files_api_client_download_max_total_recovers` to `files_ext_client_download_max_total_recovers` in `databricks.sdk.Config`. +* [Breaking] Rename `files_api_client_download_max_total_recovers_without_progressing` to `files_ext_client_download_max_total_recovers_without_progressing` in `databricks.sdk.Config`. +* [Breaking] Rename `multipart_upload_min_stream_size` to `files_ext_multipart_upload_min_stream_size` in `databricks.sdk.Config`. +* [Breaking] Rename `multipart_upload_batch_url_count` to `files_ext_multipart_upload_batch_url_count` in `databricks.sdk.Config`. +* [Breaking] Rename `multipart_upload_chunk_size` to `files_ext_multipart_upload_default_part_size` in `databricks.sdk.Config`. +* [Breaking] Rename `multipart_upload_url_expiration_duration` to `files_ext_multipart_upload_url_expiration_duration` in `databricks.sdk.Config`. +* [Breaking] Rename `multipart_upload_max_retries` to `files_ext_multipart_upload_max_retries` in `databricks.sdk.Config`. +* Add `files_ext_client_download_streaming_chunk_size`, `files_ext_multipart_upload_part_size_options`, `files_ext_multipart_upload_max_part_size`, `files_ext_multipart_upload_default_parallelism`, `files_ext_presigned_download_url_expiration_duration`, `files_ext_parallel_download_default_parallelism`, `files_ext_parallel_download_min_file_size`, `files_ext_parallel_download_default_part_size`, `files_ext_parallel_download_max_retries` to `databricks.sdk.Config`. diff --git a/databricks/sdk/__init__.py b/databricks/sdk/__init__.py index 8f510ed2d..dd8b7f796 100755 --- a/databricks/sdk/__init__.py +++ b/databricks/sdk/__init__.py @@ -181,11 +181,7 @@ def _make_dbutils(config: client.Config): def _make_files_client(apiClient: client.ApiClient, config: client.Config): - if config.enable_experimental_files_api_client: - _LOG.info("Experimental Files API client is enabled") - return FilesExt(apiClient, config) - else: - return FilesAPI(apiClient) + return FilesExt(apiClient, config) class WorkspaceClient: @@ -603,11 +599,6 @@ def feature_store(self) -> pkg_ml.FeatureStoreAPI: """A feature store is a centralized repository that enables data scientists to find and share features.""" return self._feature_store - @property - def files(self) -> pkg_files.FilesAPI: - """The Files API is a standard HTTP API that allows you to read, write, list, and delete files and directories by referring to their URI.""" - return self._files -
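To make the changelog entries above concrete, here is a minimal usage sketch of the new `FilesExt` interfaces. The volume and local paths are placeholders, the keyword arguments follow the signatures this diff adds to `databricks/sdk/mixins/files.py`, and it assumes authentication is already configured for `WorkspaceClient` (environment variables or `~/.databrickscfg`).

```python
# Hypothetical usage sketch for the new FilesExt interfaces; paths are placeholders.
from databricks.sdk import WorkspaceClient

w = WorkspaceClient()

# Upload a local file; parallel multipart upload is used by default.
w.files.upload_from(
    "/Volumes/main/default/my_volume/data.bin",  # remote target path
    "/tmp/data.bin",                             # local source path
    overwrite=True,
    use_parallel=True,  # default
    parallelism=8,      # optional; defaults to files_ext_multipart_upload_default_parallelism
)

# Download the file back to the local filesystem, optionally in parallel.
w.files.download_to(
    "/Volumes/main/default/my_volume/data.bin",
    "/tmp/data_copy.bin",
    overwrite=True,
    use_parallel=True,
)

# Streaming upload still works and now accepts tuning parameters.
with open("/tmp/data.bin", "rb") as f:
    w.files.upload(
        "/Volumes/main/default/my_volume/data2.bin",
        f,
        part_size=50 * 1024 * 1024,  # 50 MiB parts
    )
```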
@property def functions(self) -> pkg_catalog.FunctionsAPI: """Functions implement User-Defined Functions (UDFs) in Unity Catalog.""" @@ -1013,6 +1004,11 @@ def users(self) -> pkg_iam.UsersAPI: """User identities recognized by Databricks and represented by email addresses.""" return self._users + @property + def files(self) -> FilesExt: + """The Files API is a standard HTTP API that allows you to read, write, list, and delete files and directories by referring to their URI.""" + return self._files + def get_workspace_id(self) -> int: """Get the workspace ID of the workspace that this client is connected to.""" response = self._api_client.do("GET", "/api/2.0/preview/scim/v2/Me", response_headers=["X-Databricks-Org-Id"]) diff --git a/databricks/sdk/config.py b/databricks/sdk/config.py index 4cfd8b4f9..879ba64ec 100644 --- a/databricks/sdk/config.py +++ b/databricks/sdk/config.py @@ -6,7 +6,7 @@ import pathlib import sys import urllib.parse -from typing import Dict, Iterable, Optional +from typing import Dict, Iterable, List, Optional import requests @@ -110,18 +110,27 @@ class Config: disable_async_token_refresh: bool = ConfigAttribute(env="DATABRICKS_DISABLE_ASYNC_TOKEN_REFRESH") - enable_experimental_files_api_client: bool = ConfigAttribute(env="DATABRICKS_ENABLE_EXPERIMENTAL_FILES_API_CLIENT") - files_api_client_download_max_total_recovers = None - files_api_client_download_max_total_recovers_without_progressing = 1 + disable_experimental_files_api_client: bool = ConfigAttribute( + env="DATABRICKS_DISABLE_EXPERIMENTAL_FILES_API_CLIENT" + ) + + files_ext_client_download_streaming_chunk_size: int = 2 * 1024 * 1024 # 2 MiB + + # When downloading a file, the maximum number of attempts to retry downloading the whole file. Default is no limit. + files_ext_client_download_max_total_recovers: Optional[int] = None - # File multipart upload parameters + # When downloading a file, the maximum number of attempts to retry downloading from the same offset without progressing. + # This is to avoid infinite retrying when the download is not making any progress. Default is 1. + files_ext_client_download_max_total_recovers_without_progressing = 1 + + # File multipart upload/download parameters # ---------------------- # Minimal input stream size (bytes) to use multipart / resumable uploads. # For small files it's more efficient to make one single-shot upload request. # When uploading a file, SDK will initially buffer this many bytes from input stream. # This parameter can be less or bigger than multipart_upload_chunk_size. - multipart_upload_min_stream_size: int = 5 * 1024 * 1024 + files_ext_multipart_upload_min_stream_size: int = 50 * 1024 * 1024 # Maximum number of presigned URLs that can be requested at a time. # @@ -131,23 +140,59 @@ class Config: # the stream back. In case of a non-seekable stream we cannot rewind, so we'll abort # the upload. To reduce the chance of this, we're requesting presigned URLs one by one # and using them immediately. - multipart_upload_batch_url_count: int = 1 + files_ext_multipart_upload_batch_url_count: int = 1 - # Size of the chunk to use for multipart uploads. + # Size of the chunk to use for multipart uploads & downloads. # # The smaller chunk is, the less chance for network errors (or URL get expired), # but the more requests we'll make. 
# For AWS, minimum is 5 MiB: https://docs.aws.amazon.com/AmazonS3/latest/userguide/qfacts.html # For GCP, minimum is 256 KiB (and the recommended multiple is also 256 KiB) # boto uses 8 MiB: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/customizations/s3.html#boto3.s3.transfer.TransferConfig - multipart_upload_chunk_size: int = 10 * 1024 * 1024 - - # use maximum duration of 1 hour - multipart_upload_url_expiration_duration: datetime.timedelta = datetime.timedelta(hours=1) + files_ext_multipart_upload_default_part_size: int = 10 * 1024 * 1024 # 10 MiB + + # List of multipart upload part sizes that can be automatically selected + files_ext_multipart_upload_part_size_options: List[int] = [ + 10 * 1024 * 1024, # 10 MiB + 20 * 1024 * 1024, # 20 MiB + 50 * 1024 * 1024, # 50 MiB + 100 * 1024 * 1024, # 100 MiB + 200 * 1024 * 1024, # 200 MiB + 500 * 1024 * 1024, # 500 MiB + 1 * 1024 * 1024 * 1024, # 1 GiB + 2 * 1024 * 1024 * 1024, # 2 GiB + 4 * 1024 * 1024 * 1024, # 4 GiB + ] + + # Maximum size of a single part in multipart upload. + # For AWS, maximum is 5 GiB: https://docs.aws.amazon.com/AmazonS3/latest/userguide/qfacts.html + # For Azure, maximum is 4 GiB: https://learn.microsoft.com/en-us/rest/api/storageservices/put-block + # For CloudFlare R2, maximum is 5 GiB: https://developers.cloudflare.com/r2/objects/multipart-objects/ + files_ext_multipart_upload_max_part_size: int = 4 * 1024 * 1024 * 1024 # 4 GiB + + # Default parallel multipart upload concurrency. Set to 10 because experiments show that it + # gives good performance. + files_ext_multipart_upload_default_parallelism: int = 10 + + # The expiration duration for presigned URLs used in multipart uploads and downloads. + # The client will request new presigned URLs if the previous one has expired. The duration should be long enough + # to complete the upload or download of a single part. + files_ext_multipart_upload_url_expiration_duration: datetime.timedelta = datetime.timedelta(hours=1) + files_ext_presigned_download_url_expiration_duration: datetime.timedelta = datetime.timedelta(hours=1) + + # When downloading a file in parallel, how many worker threads to use. + files_ext_parallel_download_default_parallelism: int = 10 + + # When downloading a file, if the file size is smaller than this threshold, + # we'll use a single-threaded download even if parallel download is enabled. + files_ext_parallel_download_min_file_size: int = 50 * 1024 * 1024 # 50 MiB + + # Default chunk size to use when downloading a file in parallel. Has no effect on single-threaded downloads. + files_ext_parallel_download_default_part_size: int = 10 * 1024 * 1024 # 10 MiB # This is not a "wall time" cutoff for the whole upload request, # but a maximum time between consecutive data reception events (even 1 byte) from the server - multipart_upload_single_chunk_upload_timeout_seconds: float = 60 + files_ext_network_transfer_inactivity_timeout_seconds: float = 60 # Cap on the number of custom retries during incremental uploads: # 1) multipart: upload part URL is expired, so new upload URLs must be requested to continue upload @@ -155,7 +200,10 @@ class Config: # retrieved to continue the upload. # In these two cases standard SDK retries (which are capped by the `retry_timeout_seconds` option) are not used. # Note that retry counter is reset when upload is successfully resumed. - multipart_upload_max_retries = 3 + files_ext_multipart_upload_max_retries = 3 + + # Cap on the number of custom retries during parallel downloads.
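Since most of these new `files_ext_*` settings are plain class attributes on `Config`, one way to tune them is to assign values on a `Config` instance before the client is constructed (`FilesExt` copies the config in its constructor, so later changes would not be picked up). The sketch below rests on that assumption, and the chosen values are illustrative only. For reference, with the part-size options above and the selection heuristic introduced later in this diff (smallest option yielding at most 100 parts), a 3 GiB upload would select 50 MiB parts (62 parts) and request presigned URLs in batches of ceil(sqrt(62)) = 8.

```python
from databricks.sdk import WorkspaceClient
from databricks.sdk.config import Config

# Assumes auth is already configured (env vars or ~/.databrickscfg).
cfg = Config()

# Override the renamed files_ext_* knobs before creating the client,
# because FilesExt copies the Config object at construction time.
cfg.files_ext_multipart_upload_default_part_size = 32 * 1024 * 1024   # 32 MiB parts
cfg.files_ext_multipart_upload_default_parallelism = 4                # fewer upload threads
cfg.files_ext_parallel_download_min_file_size = 16 * 1024 * 1024      # parallelize downloads above 16 MiB

w = WorkspaceClient(config=cfg)
```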
+ files_ext_parallel_download_max_retries = 3 def __init__( self, diff --git a/databricks/sdk/mixins/files.py b/databricks/sdk/mixins/files.py index fdf08a839..1d3b295ef 100644 --- a/databricks/sdk/mixins/files.py +++ b/databricks/sdk/mixins/files.py @@ -3,6 +3,7 @@ import base64 import datetime import logging +import math import os import pathlib import platform @@ -13,8 +14,13 @@ from abc import ABC, abstractmethod from collections import deque from collections.abc import Iterator +from concurrent.futures import ThreadPoolExecutor, as_completed +from dataclasses import dataclass from datetime import timedelta from io import BytesIO +from queue import Empty, Full, Queue +from tempfile import mkstemp +from threading import Event, Thread from types import TracebackType from typing import (TYPE_CHECKING, AnyStr, BinaryIO, Callable, Generator, Iterable, Optional, Type, Union) @@ -27,12 +33,14 @@ from .._base_client import _BaseClient, _RawResponse, _StreamingResponse from .._property import _cached_property from ..config import Config -from ..errors import AlreadyExists, NotFound +from ..errors import AlreadyExists, NotFound, PermissionDenied from ..errors.mapper import _error_mapper from ..retries import retried from ..service import files from ..service._internal import _escape_multi_segment_path_parameter from ..service.files import DownloadResponse +from .files_utils import (CreateDownloadUrlResponse, _ConcatenatedInputStream, + _PresignedUrlDistributor) if TYPE_CHECKING: from _typeshed import Self @@ -710,18 +718,70 @@ def delete(self, path: str, *, recursive=False): p.delete(recursive=recursive) +class FallbackToUploadUsingFilesApi(Exception): + """Custom exception that signals to fallback to FilesAPI for upload""" + + def __init__(self, buffer, message): + super().__init__(message) + self.buffer = buffer + + +class FallbackToDownloadUsingFilesApi(Exception): + """Custom exception that signals to fallback to FilesAPI for download""" + + def __init__(self, message): + super().__init__(message) + + +@dataclass +class UploadStreamResult: + """Result of an upload from stream operation. Currently empty, but can be extended in the future.""" + + +@dataclass +class UploadFileResult: + """Result of an upload from file operation. Currently empty, but can be extended in the future.""" + + +@dataclass +class DownloadFileResult: + """Result of a download to file operation. Currently empty, but can be extended in the future.""" + + class FilesExt(files.FilesAPI): __doc__ = files.FilesAPI.__doc__ # note that these error codes are retryable only for idempotent operations - _RETRYABLE_STATUS_CODES = [408, 429, 500, 502, 503, 504] + _RETRYABLE_STATUS_CODES: list[int] = [408, 429, 500, 502, 503, 504] + + @dataclass(frozen=True) + class _UploadContext: + target_path: str + """The absolute remote path of the target file, e.g. /Volumes/path/to/your/file.""" + overwrite: Optional[bool] + """If true, an existing file will be overwritten. 
When unspecified, default behavior of the cloud storage provider is performed.""" + part_size: int + """The size of each part in bytes for multipart upload.""" + batch_size: int + """The number of presigned URLs to request in a single batch.""" + content_length: Optional[int] = None + """The total size of the content being uploaded, if known.""" + source_file_path: Optional[str] = None + """The local path of the file being uploaded, if applicable.""" + use_parallel: Optional[bool] = None + """If true, the upload will be performed using multiple threads.""" + parallelism: Optional[int] = None + """The number of threads to use for parallel upload, if applicable.""" def __init__(self, api_client, config: Config): super().__init__(api_client) self._config = config.copy() self._multipart_upload_read_ahead_bytes = 1 - def download(self, file_path: str) -> DownloadResponse: + def download( + self, + file_path: str, + ) -> DownloadResponse: """Download a file. Downloads a file of any size. The file contents are the response body. @@ -736,48 +796,462 @@ def download(self, file_path: str) -> DownloadResponse: :returns: :class:`DownloadResponse` """ + if self._config.disable_experimental_files_api_client: + _LOG.info("Experimental files API client is disabled, falling back to the original download method.") + return super().download(file_path) initial_response: DownloadResponse = self._open_download_stream( - file_path=file_path, - start_byte_offset=0, - if_unmodified_since_timestamp=None, + file_path=file_path, start_byte_offset=0, if_unmodified_since_timestamp=None ) wrapped_response = self._wrap_stream(file_path, initial_response) initial_response.contents._response = wrapped_response return initial_response - def upload(self, file_path: str, contents: BinaryIO, *, overwrite: Optional[bool] = None): - """Upload a file. + def download_to( + self, + file_path: str, + destination: str, + *, + overwrite: bool = True, + use_parallel: bool = False, + parallelism: Optional[int] = None, + ) -> DownloadFileResult: + """Download a file to a local path. The file contents are written to `destination`; no response body is returned. + + :param file_path: str + The remote path of the file, e.g. /Volumes/path/to/your/file + :param destination: str + The local path where the file will be saved. + :param overwrite: bool + If true, an existing file will be overwritten. Defaults to True. + :param use_parallel: bool + If true, the download will be performed using multiple threads. + :param parallelism: int + The number of parallel threads to use for downloading. If not specified, defaults to + `Config.files_ext_parallel_download_default_parallelism`. + + :returns: :class:`DownloadFileResult` + """ + if self._config.disable_experimental_files_api_client: + raise NotImplementedError( + "Experimental files API features are disabled, download_to is not supported. Please use download instead." + ) + + # The existence of the target file is checked before starting the download. This is a best-effort check + # to avoid overwriting an existing file. However, there is nothing preventing a file from being created + # at the destination path after this check and before the file is written, and no way to prevent another + # actor from writing to the destination path concurrently. + if not overwrite and os.path.exists(destination): + raise FileExistsError(destination) + if use_parallel: + # Parallel download is not supported on Windows due to the limit of only one open file handle + # for writing.
If parallel download is requested on Windows, fall back to sequential download with + # a warning. + if platform.system() == "Windows": + _LOG.warning("Parallel download is not supported on Windows. Falling back to sequential download.") + self._sequential_download_to_file(destination, remote_path=file_path) + return DownloadFileResult() + if parallelism is None: + parallelism = self._config.files_ext_parallel_download_default_parallelism + if parallelism < 1 or parallelism > 64: + raise ValueError("parallelism must be between 1 and 64") + self._parallel_download_with_fallback(file_path, destination, parallelism=parallelism) + else: + self._sequential_download_to_file(destination, remote_path=file_path) + return DownloadFileResult() + + def _parallel_download_with_fallback(self, remote_path: str, destination: str, parallelism: int) -> None: + """Download a file in parallel to a local path. There would be no responses returned if the download is successful. + This method first tries to use the Presigned URL for parallel download. If it fails due to permission issues, + it falls back to using Files API. + + :param remote_path: str + The remote path of the file, e.g. /Volumes/path/to/your/file + :param destination: str + The local path where the file will be saved. + :param parallelism: int + The number of parallel threads to use for downloading. + + :returns: None + """ + try: + self._parallel_download_presigned_url(remote_path, destination, parallelism) + except FallbackToDownloadUsingFilesApi as e: + _LOG.info("Falling back to Files API download due to permission issues with Presigned URL: %s", e) + self._parallel_download_files_api(remote_path, destination, parallelism) + + def _sequential_download_to_file( + self, destination: str, remote_path: str, last_modified: Optional[str] = None + ) -> None: + with open(destination, "wb") as f: + response = self._open_download_stream( + file_path=remote_path, + start_byte_offset=0, + if_unmodified_since_timestamp=last_modified, + ) + wrapped_response = self._wrap_stream(remote_path, response, 0) + response.contents._response = wrapped_response + shutil.copyfileobj(response.contents, f) + + def _do_parallel_download( + self, remote_path: str, destination: str, parallelism: int, download_chunk: Callable + ) -> None: + + file_info = self.get_metadata(remote_path) + file_size = file_info.content_length + last_modified = file_info.last_modified + # If the file is smaller than the threshold, do not use parallel download. + if file_size <= self._config.files_ext_parallel_download_min_file_size: + self._sequential_download_to_file(destination, remote_path, last_modified) + return + part_size = self._config.files_ext_parallel_download_default_part_size + part_count = int(math.ceil(file_size / part_size)) + + fd, temp_file = mkstemp() + # We are preallocate the file size to the same as the remote file to avoid seeking beyond the file size. 
+ os.truncate(temp_file, file_size) + os.close(fd) + try: + aborted = Event() + + def wrapped_download_chunk(start: int, end: int, last_modified: Optional[str], temp_file: str) -> None: + if aborted.is_set(): + return + additional_headers = { + "Range": f"bytes={start}-{end}", + "If-Unmodified-Since": last_modified, + } + try: + contents = download_chunk(additional_headers) + with open(temp_file, "r+b") as f: + f.seek(start) + shutil.copyfileobj(contents, f) + except Exception as e: + aborted.set() + raise e + + with ThreadPoolExecutor(max_workers=parallelism) as executor: + futures = [] + # Start the threads to download parts of the file. + for i in range(part_count): + start = i * part_size + end = min(start + part_size - 1, file_size - 1) + futures.append(executor.submit(wrapped_download_chunk, start, end, last_modified, temp_file)) + + # Wait for all threads to complete and check for exceptions. + for future in as_completed(futures): + exception = future.exception() + if exception: + raise exception + # Finally, move the temp file to the destination. + shutil.move(temp_file, destination) + finally: + if os.path.exists(temp_file): + os.remove(temp_file) + + def _parallel_download_presigned_url(self, remote_path: str, destination: str, parallelism: int) -> None: + """Download a file in parallel to a local path. There would be no responses returned if the download is successful. + + :param remote_path: str + The remote path of the file, e.g. /Volumes/path/to/your/file + :param destination: str + The local path where the file will be saved. + :param parallelism: int + The number of parallel threads to use for downloading. + + :returns: None + """ + + cloud_session = self._create_cloud_provider_session() + url_distributor = _PresignedUrlDistributor(lambda: self._create_download_url(remote_path)) - Uploads a file. The file contents should be sent as the request body as raw bytes (an - octet stream); do not encode or otherwise modify the bytes before sending. The contents of the - resulting file will be exactly the bytes sent in the request body. If the request is successful, there - is no response body. + def download_chunk(additional_headers: dict[str, str]) -> BinaryIO: + retry_count = 0 + while retry_count < self._config.files_ext_parallel_download_max_retries: + url_and_header, version = url_distributor.get_url() + + headers = {**url_and_header.headers, **additional_headers} + + def get_content() -> requests.Response: + return cloud_session.get(url_and_header.url, headers=headers) + + raw_resp = self._retry_cloud_idempotent_operation(get_content) + + if FilesExt._is_url_expired_response(raw_resp): + _LOG.info("Presigned URL expired, fetching a new one.") + url_distributor.invalidate_url(version) + retry_count += 1 + continue + elif raw_resp.status_code == 403: + raise FallbackToDownloadUsingFilesApi("Received 403 Forbidden from presigned URL") + + raw_resp.raise_for_status() + return BytesIO(raw_resp.content) + raise ValueError("Exceeded maximum retries for downloading with presigned URL: URL expired too many times") + + self._do_parallel_download(remote_path, destination, parallelism, download_chunk) + + def _parallel_download_files_api(self, remote_path: str, destination: str, parallelism: int) -> None: + """Download a file in parallel to a local path using FilesAPI. There would be no responses returned if the download is successful. + + :param remote_path: str + The remote path of the file, e.g. 
/Volumes/path/to/your/file + :param destination: str + The local path where the file will be saved. + :param parallelism: int + The number of parallel threads to use for downloading. + + :returns: None + """ + + def download_chunk(additional_headers: dict[str, str]) -> BinaryIO: + raw_response: dict = self._api.do( + method="GET", + path=f"/api/2.0/fs/files{remote_path}", + headers=additional_headers, + raw=True, + ) + return raw_response["contents"] + + self._do_parallel_download(remote_path, destination, parallelism, download_chunk) + + def _get_optimized_performance_parameters_for_upload( + self, content_length: Optional[int], part_size_overwrite: Optional[int] + ) -> (int, int): + """Get optimized part size and batch size for upload based on content length and provided part size. + + Returns tuple of (part_size, batch_size). + """ + chosen_part_size = None + + # 1. decide on the part size + if part_size_overwrite is not None: # If a part size is provided, we use it directly after validation. + if part_size_overwrite > self._config.files_ext_multipart_upload_max_part_size: + raise ValueError( + f"Part size {part_size_overwrite} exceeds maximum allowed size {self._config.files_ext_multipart_upload_max_part_size} bytes." + ) + chosen_part_size = part_size_overwrite + _LOG.debug(f"Using provided part size: {chosen_part_size} bytes") + else: # If no part size is provided, we will optimize based on the content length. + if content_length is not None: + # Choosing the smallest part size that allows for a maximum of 100 parts. + for part_size in self._config.files_ext_multipart_upload_part_size_options: + part_num = (content_length + part_size - 1) // part_size + if part_num <= 100: + chosen_part_size = part_size + _LOG.debug( + f"Optimized part size for upload: {chosen_part_size} bytes for content length {content_length} bytes" + ) + break + if chosen_part_size is None: # If no part size was chosen, we default to the maximum allowed part size. + chosen_part_size = self._config.files_ext_multipart_upload_max_part_size + + # Use defaults if not determined yet + if chosen_part_size is None: + chosen_part_size = self._config.files_ext_multipart_upload_default_part_size + + # 2. decide on the batch size + if content_length is not None and chosen_part_size is not None: + part_num = (content_length + chosen_part_size - 1) // chosen_part_size + chosen_batch_size = int( + math.ceil(math.sqrt(part_num)) + ) # Using the square root of the number of parts as a heuristic for batch size. + else: + chosen_batch_size = self._config.files_ext_multipart_upload_batch_url_count + + return chosen_part_size, chosen_batch_size + + def upload( + self, + file_path: str, + content: BinaryIO, + *, + overwrite: Optional[bool] = None, + part_size: Optional[int] = None, + use_parallel: bool = True, + parallelism: Optional[int] = None, + ) -> UploadStreamResult: + """ + Upload a file with stream interface. + + :param file_path: str + The absolute remote path of the target file, e.g. /Volumes/path/to/your/file + :param content: BinaryIO + The contents of the file to upload. This must be a BinaryIO stream. + :param overwrite: bool (optional) + If true, an existing file will be overwritten. When not specified, assumed True. + :param part_size: int (optional) + If set, multipart upload will use the value as its size per uploading part. + :param use_parallel: bool (optional) + If true, the upload will be performed using multiple threads. 
Be aware that this will consume more memory + because multiple parts will be buffered in memory before being uploaded. The amount of memory used is proportional + to `parallelism * part_size`. + If false, the upload will be performed in a single thread. + Default is True. + :param parallelism: int (optional) + The number of threads to use for parallel uploads. This is only used if `use_parallel` is True. + + :returns: :class:`UploadStreamResult` + """ + + if self._config.disable_experimental_files_api_client: + _LOG.info("Disable experimental files API client, will use the original upload method.") + super().upload(file_path=file_path, contents=content, overwrite=overwrite) + return UploadStreamResult() + + _LOG.debug(f"Uploading file from BinaryIO stream") + if parallelism is not None and not use_parallel: + raise ValueError("parallelism can only be set if use_parallel is True") + if parallelism is None and use_parallel: + parallelism = self._config.files_ext_multipart_upload_default_parallelism + + # Determine content length if the stream is seekable + content_length = None + if content.seekable(): + _LOG.debug(f"Uploading using seekable mode") + # If the stream is seekable, we can read its size. + content.seek(0, os.SEEK_END) + content_length = content.tell() + content.seek(0) + + # Get optimized part size and batch size based on content length and provided part size + optimized_part_size, optimized_batch_size = self._get_optimized_performance_parameters_for_upload( + content_length, part_size + ) + + # Create context with all final parameters + ctx = self._UploadContext( + target_path=file_path, + overwrite=overwrite, + part_size=optimized_part_size, + batch_size=optimized_batch_size, + content_length=content_length, + use_parallel=use_parallel, + parallelism=parallelism, + ) + + _LOG.debug( + f"Upload context: part_size={ctx.part_size}, batch_size={ctx.batch_size}, content_length={ctx.content_length}" + ) + + if ctx.use_parallel: + self._parallel_upload_from_stream(ctx, content) + return UploadStreamResult() + elif ctx.content_length is not None: + self._upload_single_thread_with_known_size(ctx, content) + return UploadStreamResult() + else: + _LOG.debug(f"Uploading using non-seekable mode") + # If the stream is not seekable, we cannot determine its size. + # We will use a multipart upload. + _LOG.debug(f"Using multipart upload for non-seekable input stream of unknown size for file {file_path}") + self._single_thread_multipart_upload(ctx, content) + return UploadStreamResult() + + def upload_from( + self, + file_path: str, + source_path: str, + *, + overwrite: Optional[bool] = None, + part_size: Optional[int] = None, + use_parallel: bool = True, + parallelism: Optional[int] = None, + ) -> UploadFileResult: + """Upload a file directly from a local path. :param file_path: str The absolute remote path of the target file. - :param contents: BinaryIO + :param source_path: str + The local path of the file to upload. This must be a path to a local file. + :param part_size: int + The size of each part in bytes for multipart upload. This is a required parameter for multipart uploads. :param overwrite: bool (optional) If true, an existing file will be overwritten. When not specified, assumed True. - """ + :param use_parallel: bool (optional) + If true, the upload will be performed using multiple threads. Default is True. + :param parallelism: int (optional) + The number of threads to use for parallel uploads. This is only used if `use_parallel` is True. 
+ If not specified, the default parallelism will be set to config.multipart_upload_default_parallelism - # Upload empty and small files with one-shot upload. - pre_read_buffer = contents.read(self._config.multipart_upload_min_stream_size) - if len(pre_read_buffer) < self._config.multipart_upload_min_stream_size: - _LOG.debug( - f"Using one-shot upload for input stream of size {len(pre_read_buffer)} below {self._config.multipart_upload_min_stream_size} bytes" + :returns: :class:`UploadFileResult` + """ + if self._config.disable_experimental_files_api_client: + raise NotImplementedError( + "Experimental files API features are disabled, upload_from is not supported. Please use upload instead." ) - return super().upload(file_path=file_path, contents=BytesIO(pre_read_buffer), overwrite=overwrite) + _LOG.debug(f"Uploading file from local path: {source_path}") + + if parallelism is not None and not use_parallel: + raise ValueError("parallelism can only be set if use_parallel is True") + if parallelism is None and use_parallel: + parallelism = self._config.files_ext_multipart_upload_default_parallelism + # Get the file size + file_size = os.path.getsize(source_path) + + # Get optimized part size and batch size based on content length and provided part size + optimized_part_size, optimized_batch_size = self._get_optimized_performance_parameters_for_upload( + file_size, part_size + ) + + # Create context with all final parameters + ctx = self._UploadContext( + target_path=file_path, + overwrite=overwrite, + part_size=optimized_part_size, + batch_size=optimized_batch_size, + content_length=file_size, + source_file_path=source_path, + use_parallel=use_parallel, + parallelism=parallelism, + ) + if ctx.use_parallel: + self._parallel_upload_from_file(ctx) + return UploadFileResult() + else: + with open(source_path, "rb") as f: + self._upload_single_thread_with_known_size(ctx, f) + return UploadFileResult() + + def _upload_single_thread_with_known_size(self, ctx: _UploadContext, contents: BinaryIO) -> None: + """Upload a file with a known size.""" + if ctx.content_length < self._config.files_ext_multipart_upload_min_stream_size: + _LOG.debug(f"Using single-shot upload for input stream of size {ctx.content_length} bytes") + return self._single_thread_single_shot_upload(ctx, contents) + else: + _LOG.debug(f"Using multipart upload for input stream of size {ctx.content_length} bytes") + return self._single_thread_multipart_upload(ctx, contents) + + def _single_thread_single_shot_upload(self, ctx: _UploadContext, contents: BinaryIO) -> None: + """Upload a file with a known size.""" + _LOG.debug(f"Using single-shot upload for input stream") + return super().upload(file_path=ctx.target_path, contents=contents, overwrite=ctx.overwrite) + + def _initiate_multipart_upload(self, ctx: _UploadContext) -> dict: + """Initiate a multipart upload and return the response.""" query = {"action": "initiate-upload"} - if overwrite is not None: - query["overwrite"] = overwrite + if ctx.overwrite is not None: + query["overwrite"] = ctx.overwrite # Method _api.do() takes care of retrying and will raise an exception in case of failure. 
initiate_upload_response = self._api.do( - "POST", f"/api/2.0/fs/files{_escape_multi_segment_path_parameter(file_path)}", query=query + "POST", f"/api/2.0/fs/files{_escape_multi_segment_path_parameter(ctx.target_path)}", query=query ) + return initiate_upload_response + + def _single_thread_multipart_upload(self, ctx: _UploadContext, contents: BinaryIO) -> None: + + # Upload empty and small files with one-shot upload. + pre_read_buffer = contents.read(self._config.files_ext_multipart_upload_min_stream_size) + if len(pre_read_buffer) < self._config.files_ext_multipart_upload_min_stream_size: + _LOG.debug( + f"Using one-shot upload for input stream of size {len(pre_read_buffer)} below {self._config.files_ext_multipart_upload_min_stream_size} bytes" + ) + return self._single_thread_single_shot_upload(ctx, BytesIO(pre_read_buffer)) + + # Initiate the multipart upload. + initiate_upload_response = self._initiate_multipart_upload(ctx) if initiate_upload_response.get("multipart_upload"): cloud_provider_session = self._create_cloud_provider_session() @@ -786,37 +1260,451 @@ def upload(self, file_path: str, contents: BinaryIO, *, overwrite: Optional[bool raise ValueError(f"Unexpected server response: {initiate_upload_response}") try: - self._perform_multipart_upload( - file_path, contents, session_token, pre_read_buffer, cloud_provider_session - ) + self._perform_multipart_upload(ctx, contents, session_token, pre_read_buffer, cloud_provider_session) + except FallbackToUploadUsingFilesApi as e: + try: + self._abort_multipart_upload(ctx, session_token, cloud_provider_session) + except BaseException as ex: + # Ignore abort exceptions as it is a best-effort. + _LOG.warning(f"Failed to abort upload: {ex}") + + _LOG.info(f"Falling back to single-shot upload with Files API: {e}") + # Concatenate the buffered part and the rest of the stream. + full_stream = _ConcatenatedInputStream(BytesIO(e.buffer), contents) + return self._single_thread_single_shot_upload(ctx, full_stream) + except Exception as e: _LOG.info(f"Aborting multipart upload on error: {e}") try: - self._abort_multipart_upload(file_path, session_token, cloud_provider_session) + self._abort_multipart_upload(ctx, session_token, cloud_provider_session) except BaseException as ex: + # Ignore abort exceptions as it is a best-effort. _LOG.warning(f"Failed to abort upload: {ex}") - # ignore, abort is a best-effort finally: - # rethrow original exception + # Rethrow the original exception raise e from None elif initiate_upload_response.get("resumable_upload"): cloud_provider_session = self._create_cloud_provider_session() session_token = initiate_upload_response["resumable_upload"]["session_token"] - self._perform_resumable_upload( - file_path, contents, session_token, overwrite, pre_read_buffer, cloud_provider_session - ) + + try: + self._perform_resumable_upload(ctx, contents, session_token, pre_read_buffer, cloud_provider_session) + except FallbackToUploadUsingFilesApi as e: + _LOG.info(f"Falling back to single-shot upload with Files API: {e}") + # Concatenate the buffered part and the rest of the stream. + full_stream = _ConcatenatedInputStream(BytesIO(e.buffer), contents) + return self._single_thread_single_shot_upload(ctx, full_stream) else: raise ValueError(f"Unexpected server response: {initiate_upload_response}") + def _parallel_upload_from_stream(self, ctx: _UploadContext, contents: BinaryIO) -> None: + """ + Upload a stream using multipart upload with multiple threads. 
+ Parts of the stream are read by a producer thread and uploaded concurrently by a pool of worker threads. + Falls back to a single-threaded upload for resumable (GCP) uploads, and to the Files API when presigned URLs are unavailable. + """ + initiate_upload_response = self._initiate_multipart_upload(ctx) + + if initiate_upload_response.get("resumable_upload"): + _LOG.warning("GCP does not support parallel resumable uploads, falling back to single-threaded upload") + return self._single_thread_multipart_upload(ctx, contents) + elif initiate_upload_response.get("multipart_upload"): + session_token = initiate_upload_response["multipart_upload"].get("session_token") + cloud_provider_session = self._create_cloud_provider_session() + if not session_token: + raise ValueError(f"Unexpected server response: {initiate_upload_response}") + try: + self._parallel_multipart_upload_from_stream(ctx, session_token, contents, cloud_provider_session) + except FallbackToUploadUsingFilesApi as e: + try: + self._abort_multipart_upload(ctx, session_token, cloud_provider_session) + except Exception as abort_ex: + _LOG.warning(f"Failed to abort upload: {abort_ex}") + _LOG.info(f"Falling back to single-shot upload with Files API: {e}") + # Concatenate the buffered part and the rest of the stream. + full_stream = _ConcatenatedInputStream(BytesIO(e.buffer), contents) + return self._single_thread_single_shot_upload(ctx, full_stream) + except Exception as e: + _LOG.info(f"Aborting multipart upload on error: {e}") + try: + self._abort_multipart_upload(ctx, session_token, cloud_provider_session) + except Exception as abort_ex: + _LOG.warning(f"Failed to abort upload: {abort_ex}") + finally: + # Rethrow the original exception. + raise e from None + else: + raise ValueError(f"Unexpected server response: {initiate_upload_response}") + + def _parallel_upload_from_file( + self, + ctx: _UploadContext, + ) -> None: + """ + Upload a file using multipart upload with multiple threads. + Parts of the file are read and uploaded concurrently by a pool of worker threads. + Falls back to a single-threaded upload for resumable (GCP) uploads, and to the Files API when presigned URLs are unavailable. + """ + + initiate_upload_response = self._initiate_multipart_upload(ctx) + + if initiate_upload_response.get("multipart_upload"): + cloud_provider_session = self._create_cloud_provider_session() + session_token = initiate_upload_response["multipart_upload"].get("session_token") + if not session_token: + raise ValueError(f"Unexpected server response: {initiate_upload_response}") + try: + self._parallel_multipart_upload_from_file(ctx, session_token) + except FallbackToUploadUsingFilesApi as e: + try: + self._abort_multipart_upload(ctx, session_token, cloud_provider_session) + except Exception as abort_ex: + _LOG.warning(f"Failed to abort upload: {abort_ex}") + + _LOG.info(f"Falling back to single-shot upload with Files API: {e}") + # Re-open the source file and upload it in a single shot. + with open(ctx.source_file_path, "rb") as f: + return self._single_thread_single_shot_upload(ctx, f) + + except Exception as e: + _LOG.info(f"Aborting multipart upload on error: {e}") + try: + self._abort_multipart_upload(ctx, session_token, cloud_provider_session) + except Exception as abort_ex: + _LOG.warning(f"Failed to abort upload: {abort_ex}") + finally: + # Rethrow the original exception.
+ raise e from None + + elif initiate_upload_response.get("resumable_upload"): + _LOG.warning("GCP does not support parallel resumable uploads, falling back to single-threaded upload") + with open(ctx.source_file_path, "rb") as f: + return self._upload_single_thread_with_known_size(ctx, f) + else: + raise ValueError(f"Unexpected server response: {initiate_upload_response}") + + @dataclass + class _MultipartUploadPart: + ctx: FilesExt._UploadContext + part_index: int + part_offset: int + part_size: int + session_token: str + + def _parallel_multipart_upload_from_file( + self, + ctx: _UploadContext, + session_token: str, + ) -> None: + # Calculate the number of parts. + file_size = os.path.getsize(ctx.source_file_path) + part_size = ctx.part_size + num_parts = (file_size + part_size - 1) // part_size + _LOG.debug(f"Uploading file of size {file_size} bytes in {num_parts} parts using {ctx.parallelism} threads") + + # Create queues and worker threads. + task_queue = Queue() + etags_result_queue = Queue() + exception_queue = Queue() + aborted = Event() + workers = [ + Thread(target=self._upload_file_consumer, args=(task_queue, etags_result_queue, exception_queue, aborted)) + for _ in range(ctx.parallelism) + ] + _LOG.debug(f"Starting {len(workers)} worker threads for parallel upload") + + # Enqueue all parts. Since the task queue is populated before starting the workers, we don't need to signal completion. + for part_index in range(1, num_parts + 1): + part_offset = (part_index - 1) * part_size + part_size = min(part_size, file_size - part_offset) + part = self._MultipartUploadPart(ctx, part_index, part_offset, part_size, session_token) + task_queue.put(part) + + # Start the worker threads for parallel upload. + for worker in workers: + worker.start() + + # Wait for all tasks to be processed. + for worker in workers: + worker.join() + + # Check for exceptions: if any worker encountered an exception, raise the first one. + if not exception_queue.empty(): + first_exception = exception_queue.get() + raise first_exception + + # Collect results from the etags queue. 
+ etags: dict = {} + while not etags_result_queue.empty(): + part_number, etag = etags_result_queue.get() + etags[part_number] = etag + + self._complete_multipart_upload(ctx, etags, session_token) + + def _parallel_multipart_upload_from_stream( + self, + ctx: _UploadContext, + session_token: str, + content: BinaryIO, + cloud_provider_session: requests.Session, + ) -> None: + + task_queue = Queue(maxsize=ctx.parallelism) # Limit queue size to control memory usage + etags_result_queue = Queue() + exception_queue = Queue() + all_produced = Event() + aborted = Event() + + # Do the first part read ahead + pre_read_buffer = content.read(ctx.part_size) + if not pre_read_buffer: + self._complete_multipart_upload(ctx, {}, session_token) + return + try: + etag = self._do_upload_one_part( + ctx, cloud_provider_session, 1, 0, len(pre_read_buffer), session_token, BytesIO(pre_read_buffer) + ) + etags_result_queue.put((1, etag)) + except FallbackToUploadUsingFilesApi as e: + raise FallbackToUploadUsingFilesApi( + pre_read_buffer, "Falling back to single-shot upload with Files API" + ) from e + + if len(pre_read_buffer) < ctx.part_size: + self._complete_multipart_upload(ctx, {1: etag}, session_token) + return + + def producer() -> None: + part_index = 2 + part_size = ctx.part_size + while not aborted.is_set(): + part_content = content.read(part_size) + if not part_content: + break + part_offset = (part_index - 1) * part_size + part = self._MultipartUploadPart(ctx, part_index, part_offset, len(part_content), session_token) + while not aborted.is_set(): + try: + task_queue.put((part, part_content), timeout=0.1) + break + except Full: + continue + part_index += 1 + all_produced.set() + + producer_thread = Thread(target=producer) + consumers = [ + Thread( + target=self._upload_stream_consumer, + args=(task_queue, etags_result_queue, exception_queue, all_produced, aborted), + ) + for _ in range(ctx.parallelism) + ] + _LOG.debug(f"Starting {len(consumers)} worker threads for parallel upload") + # Start producer and consumer threads + producer_thread.start() + for consumer in consumers: + consumer.start() + + # Wait for producer to finish + _LOG.debug(f"threads started, waiting for producer to finish") + producer_thread.join() + # Wait for all tasks to be processed + _LOG.debug(f"producer finished, waiting for consumers to finish") + # task_queue.join() + for consumer in consumers: + consumer.join() + + # Check for exceptions: if any worker encountered an exception, raise the first one. 
+ if not exception_queue.empty(): + first_exception = exception_queue.get() + raise first_exception + + # Collect results from the etags queue + etags: dict = {} + while not etags_result_queue.empty(): + part_number, etag = etags_result_queue.get() + etags[part_number] = etag + + self._complete_multipart_upload(ctx, etags, session_token) + + def _complete_multipart_upload(self, ctx, etags, session_token): + query = {"action": "complete-upload", "upload_type": "multipart", "session_token": session_token} + headers = {"Content-Type": "application/json"} + body: dict = {} + parts = [] + for part_number, etag in sorted(etags.items()): + part = {"part_number": part_number, "etag": etag} + parts.append(part) + body["parts"] = parts + self._api.do( + "POST", + f"/api/2.0/fs/files{_escape_multi_segment_path_parameter(ctx.target_path)}", + query=query, + headers=headers, + body=body, + ) + + def _upload_file_consumer( + self, + task_queue: Queue[FilesExt._MultipartUploadPart], + etags_queue: Queue[tuple[int, str]], + exception_queue: Queue[Exception], + aborted: Event, + ) -> None: + cloud_provider_session = self._create_cloud_provider_session() + while not aborted.is_set(): + try: + part = task_queue.get(block=False) + except Empty: + # The task_queue was populated before the workers were started, so we can exit if it's empty. + break + + try: + with open(part.ctx.source_file_path, "rb") as f: + f.seek(part.part_offset, os.SEEK_SET) + part_content = BytesIO(f.read(part.part_size)) + etag = self._do_upload_one_part( + part.ctx, + cloud_provider_session, + part.part_index, + part.part_offset, + part.part_size, + part.session_token, + part_content, + ) + etags_queue.put((part.part_index, etag)) + except Exception as e: + aborted.set() + exception_queue.put(e) + finally: + task_queue.task_done() + + def _upload_stream_consumer( + self, + task_queue: Queue[tuple[FilesExt._MultipartUploadPart, bytes]], + etags_queue: Queue[tuple[int, str]], + exception_queue: Queue[Exception], + all_produced: Event, + aborted: Event, + ) -> None: + cloud_provider_session = self._create_cloud_provider_session() + while not aborted.is_set(): + try: + (part, content) = task_queue.get(block=False, timeout=0.1) + except Empty: + if all_produced.is_set(): + break # No more parts will be produced and the queue is empty + else: + continue + try: + etag = self._do_upload_one_part( + part.ctx, + cloud_provider_session, + part.part_index, + part.part_offset, + part.part_size, + part.session_token, + BytesIO(content), + ) + etags_queue.put((part.part_index, etag)) + except Exception as e: + aborted.set() + exception_queue.put(e) + finally: + task_queue.task_done() + + def _do_upload_one_part( + self, + ctx: _UploadContext, + cloud_provider_session: requests.Session, + part_index: int, + part_offset: int, + part_size: int, + session_token: str, + part_content: BinaryIO, + ) -> str: + retry_count = 0 + + # Try to upload the part, retrying if the upload URL expires. + while True: + body: dict = { + "path": ctx.target_path, + "session_token": session_token, + "start_part_number": part_index, + "count": 1, + "expire_time": self._get_upload_url_expire_time(), + } + + headers = {"Content-Type": "application/json"} + + # Requesting URLs for the same set of parts is an idempotent operation and is safe to retry. + try: + # The _api.do() method handles retries and will raise an exception in case of failure. 
+ upload_part_urls_response = self._api.do( + "POST", "/api/2.0/fs/create-upload-part-urls", headers=headers, body=body + ) + except PermissionDenied as e: + if self._is_presigned_urls_disabled_error(e): + raise FallbackToUploadUsingFilesApi(None, "Presigned URLs are disabled") + else: + raise e from None + + upload_part_urls = upload_part_urls_response.get("upload_part_urls", []) + if len(upload_part_urls) == 0: + raise ValueError(f"Unexpected server response: {upload_part_urls_response}") + upload_part_url = upload_part_urls[0] + url = upload_part_url["url"] + required_headers = upload_part_url.get("headers", []) + assert part_index == upload_part_url["part_number"] + + headers: dict = {"Content-Type": "application/octet-stream"} + for h in required_headers: + headers[h["name"]] = h["value"] + + _LOG.debug(f"Uploading part {part_index}: [{part_offset}, {part_offset + part_size - 1}]") + + def rewind() -> None: + part_content.seek(0, os.SEEK_SET) + + def perform_upload() -> requests.Response: + return cloud_provider_session.request( + "PUT", + url, + headers=headers, + data=part_content, + timeout=self._config.files_ext_network_transfer_inactivity_timeout_seconds, + ) + + upload_response = self._retry_cloud_idempotent_operation(perform_upload, rewind) + + if upload_response.status_code in (200, 201): + etag = upload_response.headers.get("ETag", "") + return etag + elif FilesExt._is_url_expired_response(upload_response): + if retry_count < self._config.files_ext_multipart_upload_max_retries: + retry_count += 1 + _LOG.debug("Upload URL expired, retrying...") + continue + else: + raise ValueError(f"Unsuccessful chunk upload: upload URL expired after {retry_count} retries") + elif upload_response.status_code == 403: + raise FallbackToUploadUsingFilesApi(None, f"Direct upload forbidden: {upload_response.content}") + else: + message = f"Unsuccessful chunk upload. Response status: {upload_response.status_code}, body: {upload_response.content}" + _LOG.warning(message) + mapped_error = _error_mapper(upload_response, {}) + raise mapped_error or ValueError(message) + def _perform_multipart_upload( self, - target_path: str, + ctx: _UploadContext, input_stream: BinaryIO, session_token: str, pre_read_buffer: bytes, cloud_provider_session: requests.Session, - ): + ) -> None: """ Performs multipart upload using presigned URLs on AWS and Azure: https://docs.aws.amazon.com/AmazonS3/latest/userguide/mpuoverview.html @@ -832,7 +1720,7 @@ def _perform_multipart_upload( # AWS signed chunked upload: https://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-streaming.html # https://learn.microsoft.com/en-us/azure/storage/blobs/storage-blobs-tune-upload-download-python#buffering-during-uploads - chunk_offset = 0 # used only for logging + chunk_offset = 0 # This buffer is expected to contain at least multipart_upload_chunk_size bytes. # Note that initially buffer can be bigger (from pre_read_buffer). @@ -842,37 +1730,43 @@ def _perform_multipart_upload( eof = False while not eof: # If needed, buffer the next chunk. - buffer = FilesExt._fill_buffer(buffer, self._config.multipart_upload_chunk_size, input_stream) + buffer = FilesExt._fill_buffer(buffer, ctx.part_size, input_stream) if len(buffer) == 0: # End of stream, no need to request the next block of upload URLs. 
break _LOG.debug( - f"Multipart upload: requesting next {self._config.multipart_upload_batch_url_count} upload URLs starting from part {current_part_number}" + f"Multipart upload: requesting next {ctx.batch_size} upload URLs starting from part {current_part_number}" ) body: dict = { - "path": target_path, + "path": ctx.target_path, "session_token": session_token, "start_part_number": current_part_number, - "count": self._config.multipart_upload_batch_url_count, - "expire_time": self._get_url_expire_time(), + "count": ctx.batch_size, + "expire_time": self._get_upload_url_expire_time(), } headers = {"Content-Type": "application/json"} # Requesting URLs for the same set of parts is an idempotent operation, safe to retry. - # Method _api.do() takes care of retrying and will raise an exception in case of failure. - upload_part_urls_response = self._api.do( - "POST", "/api/2.0/fs/create-upload-part-urls", headers=headers, body=body - ) + try: + # Method _api.do() takes care of retrying and will raise an exception in case of failure. + upload_part_urls_response = self._api.do( + "POST", "/api/2.0/fs/create-upload-part-urls", headers=headers, body=body + ) + except PermissionDenied as e: + if chunk_offset == 0 and self._is_presigned_urls_disabled_error(e): + raise FallbackToUploadUsingFilesApi(buffer, "Presigned URLs are disabled") + else: + raise e from None upload_part_urls = upload_part_urls_response.get("upload_part_urls", []) if len(upload_part_urls) == 0: raise ValueError(f"Unexpected server response: {upload_part_urls_response}") for upload_part_url in upload_part_urls: - buffer = FilesExt._fill_buffer(buffer, self._config.multipart_upload_chunk_size, input_stream) + buffer = FilesExt._fill_buffer(buffer, ctx.part_size, input_stream) actual_buffer_length = len(buffer) if actual_buffer_length == 0: eof = True @@ -886,7 +1780,7 @@ def _perform_multipart_upload( for h in required_headers: headers[h["name"]] = h["value"] - actual_chunk_length = min(actual_buffer_length, self._config.multipart_upload_chunk_size) + actual_chunk_length = min(actual_buffer_length, ctx.part_size) _LOG.debug( f"Uploading part {current_part_number}: [{chunk_offset}, {chunk_offset + actual_chunk_length - 1}]" ) @@ -902,7 +1796,7 @@ def perform(): url, headers=headers, data=chunk, - timeout=self._config.multipart_upload_single_chunk_upload_timeout_seconds, + timeout=self._config.files_ext_network_transfer_inactivity_timeout_seconds, ) upload_response = self._retry_cloud_idempotent_operation(perform, rewind) @@ -922,7 +1816,7 @@ def perform(): retry_count = 0 elif FilesExt._is_url_expired_response(upload_response): - if retry_count < self._config.multipart_upload_max_retries: + if retry_count < self._config.files_ext_multipart_upload_max_retries: retry_count += 1 _LOG.debug("Upload URL expired") # Preserve the buffer so we'll upload the current part again using next upload URL @@ -930,6 +1824,13 @@ def perform(): # don't confuse user with unrelated "Permission denied" error. raise ValueError(f"Unsuccessful chunk upload: upload URL expired") + elif upload_response.status_code == 403 and chunk_offset == 0: + # We got 403 failure when uploading the very first chunk (we can't tell if it is Azure for sure yet). + # This might happen due to Azure firewall enabled for the customer bucket. + # Let's fallback to using Files API which might be allowlisted to upload, passing + # currently buffered (but not yet uploaded) part of the stream. 
+ raise FallbackToUploadUsingFilesApi(buffer, f"Direct upload forbidden: {upload_response.content}") + else: message = f"Unsuccessful chunk upload. Response status: {upload_response.status_code}, body: {upload_response.content}" _LOG.warning(message) @@ -938,9 +1839,7 @@ def perform(): current_part_number += 1 - _LOG.debug( - f"Completing multipart upload after uploading {len(etags)} parts of up to {self._config.multipart_upload_chunk_size} bytes" - ) + _LOG.debug(f"Completing multipart upload after uploading {len(etags)} parts of up to {ctx.part_size} bytes") query = {"action": "complete-upload", "upload_type": "multipart", "session_token": session_token} headers = {"Content-Type": "application/json"} @@ -957,14 +1856,14 @@ def perform(): # Method _api.do() takes care of retrying and will raise an exception in case of failure. self._api.do( "POST", - f"/api/2.0/fs/files{_escape_multi_segment_path_parameter(target_path)}", + f"/api/2.0/fs/files{_escape_multi_segment_path_parameter(ctx.target_path)}", query=query, headers=headers, body=body, ) @staticmethod - def _fill_buffer(buffer: bytes, desired_min_size: int, input_stream: BinaryIO): + def _fill_buffer(buffer: bytes, desired_min_size: int, input_stream: BinaryIO) -> bytes: """ Tries to fill given buffer to contain at least `desired_min_size` bytes by reading from input stream. """ @@ -978,7 +1877,7 @@ def _fill_buffer(buffer: bytes, desired_min_size: int, input_stream: BinaryIO): return buffer @staticmethod - def _is_url_expired_response(response: requests.Response): + def _is_url_expired_response(response: requests.Response) -> bool: """ Checks if response matches one of the known "URL expired" responses from the cloud storage providers. """ @@ -1011,15 +1910,21 @@ def _is_url_expired_response(response: requests.Response): return False + def _is_presigned_urls_disabled_error(self, e: PermissionDenied) -> bool: + error_infos = e.get_error_info() + for error_info in error_infos: + if error_info.reason == "FILES_API_API_IS_NOT_ENABLED": + return True + return False + def _perform_resumable_upload( self, - target_path: str, + ctx: _UploadContext, input_stream: BinaryIO, session_token: str, - overwrite: bool, pre_read_buffer: bytes, cloud_provider_session: requests.Session, - ): + ) -> None: """ Performs resumable upload on GCP: https://cloud.google.com/storage/docs/performing-resumable-uploads """ @@ -1047,14 +1952,20 @@ def _perform_resumable_upload( # On the contrary, in multipart upload we can decide to complete upload *after* # last chunk has been sent. - body: dict = {"path": target_path, "session_token": session_token} + body: dict = {"path": ctx.target_path, "session_token": session_token} headers = {"Content-Type": "application/json"} - # Method _api.do() takes care of retrying and will raise an exception in case of failure. - resumable_upload_url_response = self._api.do( - "POST", "/api/2.0/fs/create-resumable-upload-url", headers=headers, body=body - ) + try: + # Method _api.do() takes care of retrying and will raise an exception in case of failure. 
+ resumable_upload_url_response = self._api.do( + "POST", "/api/2.0/fs/create-resumable-upload-url", headers=headers, body=body + ) + except PermissionDenied as e: + if self._is_presigned_urls_disabled_error(e): + raise FallbackToUploadUsingFilesApi(pre_read_buffer, "Presigned URLs are disabled") + else: + raise e from None resumable_upload_url_node = resumable_upload_url_response.get("resumable_upload_url") if not resumable_upload_url_node: @@ -1069,7 +1980,7 @@ def _perform_resumable_upload( try: # We will buffer this many bytes: one chunk + read-ahead block. # Note buffer may contain more data initially (from pre_read_buffer). - min_buffer_size = self._config.multipart_upload_chunk_size + self._multipart_upload_read_ahead_bytes + min_buffer_size = ctx.part_size + self._multipart_upload_read_ahead_bytes buffer = pre_read_buffer @@ -1094,7 +2005,7 @@ def _perform_resumable_upload( file_size = chunk_offset + actual_chunk_length else: # More chunks expected, let's upload current chunk (excluding read-ahead block). - actual_chunk_length = self._config.multipart_upload_chunk_size + actual_chunk_length = ctx.part_size file_size = "*" headers: dict = {"Content-Type": "application/octet-stream"} @@ -1113,7 +2024,7 @@ def perform(): resumable_upload_url, headers={"Content-Range": "bytes */*"}, data=b"", - timeout=self._config.multipart_upload_single_chunk_upload_timeout_seconds, + timeout=self._config.files_ext_network_transfer_inactivity_timeout_seconds, ) try: @@ -1128,7 +2039,7 @@ def perform(): resumable_upload_url, headers=headers, data=BytesIO(buffer[:actual_chunk_length]), - timeout=self._config.multipart_upload_single_chunk_upload_timeout_seconds, + timeout=self._config.files_ext_network_transfer_inactivity_timeout_seconds, ) # https://cloud.google.com/storage/docs/performing-resumable-uploads#resume-upload @@ -1136,9 +2047,8 @@ def perform(): # a 503 or 500 response, then you need to resume the interrupted upload from where it left off. # Let's follow that for all potentially retryable status codes. - # Together with the catch block below we replicate the logic in _retry_databricks_idempotent_operation(). if upload_response.status_code in self._RETRYABLE_STATUS_CODES: - if retry_count < self._config.multipart_upload_max_retries: + if retry_count < self._config.files_ext_multipart_upload_max_retries: retry_count += 1 # let original upload_response be handled as an error upload_response = retrieve_upload_status() or upload_response @@ -1148,7 +2058,10 @@ def perform(): except RequestException as e: # Let's do the same for retryable network errors. - if _BaseClient._is_retryable(e) and retry_count < self._config.multipart_upload_max_retries: + if ( + _BaseClient._is_retryable(e) + and retry_count < self._config.files_ext_multipart_upload_max_retries + ): retry_count += 1 upload_response = retrieve_upload_status() if not upload_response: @@ -1194,7 +2107,7 @@ def perform(): uploaded_bytes_count = next_chunk_offset - chunk_offset chunk_offset = next_chunk_offset - elif upload_response.status_code == 412 and not overwrite: + elif upload_response.status_code == 412 and not ctx.overwrite: # Assuming this is only possible reason # Full message in this case: "At least one of the pre-conditions you specified did not hold." 
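The status probe above (`retrieve_upload_status`) follows the GCP resumable-upload protocol: an empty `PUT` with `Content-Range: bytes */*` reports how many bytes the server has persisted via the `Range` response header, e.g. `bytes=0-524287`. A stand-alone illustration of turning that header into the next offset to resume from (a simplified stand-in for `_extract_range_offset` below, not the SDK's parser):

.. code-block:: python

    import re
    from typing import Optional


    def next_offset_from_range(range_header: Optional[str]) -> int:
        # GCP reports persisted bytes as "bytes=0-<last_byte>"; the next chunk starts
        # at <last_byte> + 1. A missing header means nothing has been persisted yet.
        if not range_header:
            return 0
        match = re.match(r"^bytes=0-(\d+)$", range_header)
        if not match:
            raise ValueError(f"Cannot parse Range header: {range_header}")
        return int(match.group(1)) + 1


    assert next_offset_from_range(None) == 0
    assert next_offset_from_range("bytes=0-524287") == 524288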
raise AlreadyExists("The file being created already exists.") @@ -1227,19 +2140,38 @@ def _extract_range_offset(range_string: Optional[str]) -> Optional[int]: else: raise ValueError(f"Cannot parse response header: Range: {range_string}") - def _get_url_expire_time(self): - """Generates expiration time and save it in the required format.""" - current_time = datetime.datetime.now(datetime.timezone.utc) - expire_time = current_time + self._config.multipart_upload_url_expiration_duration + def _get_rfc339_timestamp_with_future_offset(self, base_time: datetime.datetime, offset: timedelta) -> str: + """Generates an offset timestamp in an RFC3339 format suitable for URL generation""" + offset_timestamp = base_time + offset # From Google Protobuf doc: # In JSON format, the Timestamp type is encoded as a string in the # * [RFC 3339](https://www.ietf.org/rfc/rfc3339.txt) format. That is, the # * format is "{year}-{month}-{day}T{hour}:{min}:{sec}[.{frac_sec}]Z" - return expire_time.strftime("%Y-%m-%dT%H:%M:%SZ") + return offset_timestamp.strftime("%Y-%m-%dT%H:%M:%SZ") + + def _get_upload_url_expire_time(self) -> str: + """Generates expiration time in the required format.""" + current_time = datetime.datetime.now(datetime.timezone.utc) + return self._get_rfc339_timestamp_with_future_offset( + current_time, self._config.files_ext_multipart_upload_url_expiration_duration + ) + + def _get_download_url_expire_time(self) -> str: + """Generates expiration time in the required format.""" + current_time = datetime.datetime.now(datetime.timezone.utc) + return self._get_rfc339_timestamp_with_future_offset( + current_time, self._config.files_ext_presigned_download_url_expiration_duration + ) - def _abort_multipart_upload(self, target_path: str, session_token: str, cloud_provider_session: requests.Session): + def _abort_multipart_upload( + self, ctx: _UploadContext, session_token: str, cloud_provider_session: requests.Session + ) -> None: """Aborts ongoing multipart upload session to clean up incomplete file.""" - body: dict = {"path": target_path, "session_token": session_token, "expire_time": self._get_url_expire_time()} + body: dict = { + "path": ctx.target_path, + "session_token": session_token, + "expire_time": self._get_upload_url_expire_time(), + } headers = {"Content-Type": "application/json"} @@ -1254,13 +2186,13 @@ def _abort_multipart_upload(self, target_path: str, session_token: str, cloud_pr for h in required_headers: headers[h["name"]] = h["value"] - def perform(): + def perform() -> requests.Response: return cloud_provider_session.request( "DELETE", abort_url, headers=headers, data=b"", - timeout=self._config.multipart_upload_single_chunk_upload_timeout_seconds, + timeout=self._config.files_ext_network_transfer_inactivity_timeout_seconds, ) abort_response = self._retry_cloud_idempotent_operation(perform) @@ -1270,19 +2202,19 @@ def perform(): def _abort_resumable_upload( self, resumable_upload_url: str, required_headers: list, cloud_provider_session: requests.Session - ): + ) -> None: """Aborts ongoing resumable upload session to clean up incomplete file.""" headers: dict = {} for h in required_headers: headers[h["name"]] = h["value"] - def perform(): + def perform() -> requests.Response: return cloud_provider_session.request( "DELETE", resumable_upload_url, headers=headers, data=b"", - timeout=self._config.multipart_upload_single_chunk_upload_timeout_seconds, + timeout=self._config.files_ext_network_transfer_inactivity_timeout_seconds, ) abort_response = 
self._retry_cloud_idempotent_operation(perform) @@ -1290,7 +2222,7 @@ def perform(): if abort_response.status_code not in (200, 201): raise ValueError(abort_response) - def _create_cloud_provider_session(self): + def _create_cloud_provider_session(self) -> requests.Session: """Creates a separate session which does not inherit auth headers from BaseClient session.""" session = requests.Session() @@ -1304,7 +2236,7 @@ def _create_cloud_provider_session(self): return session def _retry_cloud_idempotent_operation( - self, operation: Callable[[], requests.Response], before_retry: Callable = None + self, operation: Callable[[], requests.Response], before_retry: Optional[Callable] = None ) -> requests.Response: """Perform given idempotent operation with necessary retries for requests to non Databricks APIs. For cloud APIs, we will retry on network errors and on server response codes. @@ -1341,7 +2273,10 @@ def extended_is_retryable(e: BaseException) -> Optional[str]: )(delegate)() def _open_download_stream( - self, file_path: str, start_byte_offset: int, if_unmodified_since_timestamp: Optional[str] = None + self, + file_path: str, + start_byte_offset: int, + if_unmodified_since_timestamp: Optional[str] = None, ) -> DownloadResponse: """Opens a download stream from given offset, performing necessary retries.""" headers = { @@ -1351,7 +2286,7 @@ def _open_download_stream( if start_byte_offset and not if_unmodified_since_timestamp: raise Exception("if_unmodified_since_timestamp is required if start_byte_offset is specified") - if start_byte_offset: + if start_byte_offset > 0: headers["Range"] = f"bytes={start_byte_offset}-" if if_unmodified_since_timestamp: @@ -1362,6 +2297,23 @@ def _open_download_stream( "content-type", "last-modified", ] + + result = self._init_download_response_mode_csp_with_fallback(file_path, headers, response_headers) + + if not isinstance(result.contents, _StreamingResponse): + raise Exception( + "Internal error: response contents is of unexpected type: " + type(result.contents).__name__ + ) + + return result + + def _init_download_response_files_api( + self, file_path: str, headers: dict[str, str], response_headers: list[str] + ) -> DownloadResponse: + """ + Initiates a download response using the Files API. + """ + # Method _api.do() takes care of retrying and will raise an exception in case of failure. res = self._api.do( "GET", @@ -1370,22 +2322,119 @@ def _open_download_stream( response_headers=response_headers, raw=True, ) + return DownloadResponse.from_dict(res) - result = DownloadResponse.from_dict(res) - if not isinstance(result.contents, _StreamingResponse): - raise Exception( - "Internal error: response contents is of unexpected type: " + type(result.contents).__name__ + def _create_download_url(self, file_path: str) -> CreateDownloadUrlResponse: + """ + Creates a presigned download URL using the CSP presigned URL API. + + Wrapped in similar retry logic to the internal API.do call: + 1. Call _.api.do to obtain the presigned URL + 2. Return the presigned URL + """ + + # Method _api.do() takes care of retrying and will raise an exception in case of failure. 
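The resume contract enforced by `_open_download_stream` above is worth spelling out: a non-zero offset must be paired with `If-Unmodified-Since`, so a resumed request fails instead of silently mixing bytes from two different versions of the remote file. Roughly (a sketch, not the method's actual header construction):

.. code-block:: python

    from typing import Dict, Optional


    def resume_headers(start_byte_offset: int, last_modified: Optional[str]) -> Dict[str, str]:
        # Resuming mid-file without a freshness guard could interleave bytes from two
        # different versions of the file, hence the guard is mandatory.
        if start_byte_offset and not last_modified:
            raise ValueError("last_modified is required when resuming from an offset")
        headers: Dict[str, str] = {}
        if start_byte_offset > 0:
            headers["Range"] = f"bytes={start_byte_offset}-"
        if last_modified:
            headers["If-Unmodified-Since"] = last_modified
        return headers


    assert resume_headers(1024, "Thu, 28 Nov 2024 16:39:14 GMT")["Range"] == "bytes=1024-"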
+ try: + raw_response = self._api.do( + "POST", + f"/api/2.0/fs/create-download-url", + query={ + "path": file_path, + "expire_time": self._get_download_url_expire_time(), + }, ) - return result + return CreateDownloadUrlResponse.from_dict(raw_response) + except PermissionDenied as e: + if self._is_presigned_urls_disabled_error(e): + raise FallbackToDownloadUsingFilesApi(f"Presigned URLs are disabled") + else: + raise e from None + + def _init_download_response_presigned_api(self, file_path: str, added_headers: dict[str, str]) -> DownloadResponse: + """ + Initiates a download response using the CSP presigned URL API. + + Wrapped in similar retry logic to the internal API.do call: + 1. Call _.api.do to obtain the presigned URL + 2. Attempt to establish a streaming connection via the presigned URL + 3. Construct a StreamingResponse from the presigned URL + """ + + url_and_headers = self._create_download_url(file_path) + cloud_provider_session = self._create_cloud_provider_session() + + header_overlap = added_headers.keys() & url_and_headers.headers.keys() + if header_overlap: + raise ValueError( + f"Provided headers overlap with required headers from the CSP API bundle: {header_overlap}" + ) + + merged_headers = {**added_headers, **url_and_headers.headers} + + def perform() -> requests.Response: + return cloud_provider_session.request( + "GET", + url_and_headers.url, + headers=merged_headers, + timeout=self._config.files_ext_network_transfer_inactivity_timeout_seconds, + stream=True, + ) + + csp_response: _RawResponse = self._retry_cloud_idempotent_operation(perform) - def _wrap_stream(self, file_path: str, download_response: DownloadResponse): + # Mapping the error if the response is not successful. + if csp_response.status_code in (200, 201, 206): + resp = DownloadResponse( + content_length=int(csp_response.headers.get("content-length")), + content_type=csp_response.headers.get("content-type"), + last_modified=csp_response.headers.get("last-modified"), + contents=_StreamingResponse(csp_response, self._config.files_ext_client_download_streaming_chunk_size), + ) + return resp + elif csp_response.status_code == 403: + # We got 403 failure when downloading the file. This might happen due to Azure firewall enabled for the customer bucket. + # Let's fallback to using Files API which might be allowlisted to download. + raise FallbackToDownloadUsingFilesApi(f"Direct download forbidden: {csp_response.content}") + else: + message = ( + f"Unsuccessful download. Response status: {csp_response.status_code}, body: {csp_response.content}" + ) + _LOG.warning(message) + mapped_error = _error_mapper(csp_response, {}) + raise mapped_error or ValueError(message) + + def _init_download_response_mode_csp_with_fallback( + self, file_path: str, headers: dict[str, str], response_headers: list[str] + ) -> DownloadResponse: + """ + Initiates a download response using the CSP presigned URL API or the Files API, depending on the configuration. + If the CSP presigned download API is enabled, it will attempt to use that first. + If the CSP API call fails, it will fall back to the Files API. + If the CSP presigned download API is disabled, it will use the Files API directly. + """ + + try: + _LOG.debug(f"Attempting download of {file_path} via CSP APIs") + return self._init_download_response_presigned_api(file_path, headers) + except FallbackToDownloadUsingFilesApi as e: + _LOG.info(f"Falling back to download via Files API: {e}") + _LOG.debug(f"Attempt via CSP APIs for {file_path} failed. 
Falling back to download via Files API") + ret = self._init_download_response_files_api(file_path, headers, response_headers) + return ret + + def _wrap_stream( + self, + file_path: str, + download_response: DownloadResponse, + start_byte_offset: int = 0, + ) -> "_ResilientResponse": underlying_response = _ResilientIterator._extract_raw_response(download_response) return _ResilientResponse( self, file_path, download_response.last_modified, - offset=0, + offset=start_byte_offset, underlying_response=underlying_response, ) @@ -1399,29 +2448,24 @@ def __init__( file_last_modified: str, offset: int, underlying_response: _RawResponse, - ): + ) -> None: self.api = api self.file_path = file_path self.underlying_response = underlying_response self.offset = offset self.file_last_modified = file_last_modified - def iter_content(self, chunk_size=1, decode_unicode=False): + def iter_content(self, chunk_size: int = 1, decode_unicode: bool = False) -> Iterator[bytes]: if decode_unicode: raise ValueError("Decode unicode is not supported") iterator = self.underlying_response.iter_content(chunk_size=chunk_size, decode_unicode=False) self.iterator = _ResilientIterator( - iterator, - self.file_path, - self.file_last_modified, - self.offset, - self.api, - chunk_size, + iterator, self.file_path, self.file_last_modified, self.offset, self.api, chunk_size ) return self.iterator - def close(self): + def close(self) -> None: self.iterator.close() @@ -1433,18 +2477,18 @@ class _ResilientIterator(Iterator): def _extract_raw_response( download_response: DownloadResponse, ) -> _RawResponse: - streaming_response: _StreamingResponse = download_response.contents # this is an instance of _StreamingResponse + streaming_response: _StreamingResponse = download_response.contents return streaming_response._response def __init__( self, - underlying_iterator, + underlying_iterator: Iterator[bytes], file_path: str, file_last_modified: str, offset: int, api: FilesExt, chunk_size: int, - ): + ) -> None: self._underlying_iterator = underlying_iterator self._api = api self._file_path = file_path @@ -1460,13 +2504,13 @@ def __init__( self._closed: bool = False def _should_recover(self) -> bool: - if self._total_recovers_count == self._api._config.files_api_client_download_max_total_recovers: + if self._total_recovers_count == self._api._config.files_ext_client_download_max_total_recovers: _LOG.debug("Total recovers limit exceeded") return False if ( - self._api._config.files_api_client_download_max_total_recovers_without_progressing is not None + self._api._config.files_ext_client_download_max_total_recovers_without_progressing is not None and self._recovers_without_progressing_count - >= self._api._config.files_api_client_download_max_total_recovers_without_progressing + >= self._api._config.files_ext_client_download_max_total_recovers_without_progressing ): _LOG.debug("No progression recovers limit exceeded") return False @@ -1482,7 +2526,7 @@ def _recover(self) -> bool: try: self._underlying_iterator.close() - _LOG.debug("Trying to recover from offset " + str(self._offset)) + _LOG.debug(f"Trying to recover from offset {self._offset}") # following call includes all the required network retries downloadResponse = self._api._open_download_stream(self._file_path, self._offset, self._file_last_modified) @@ -1495,7 +2539,7 @@ def _recover(self) -> bool: except: return False # recover failed, rethrow original exception - def __next__(self): + def __next__(self) -> bytes: if self._closed: # following _BaseClient raise ValueError("I/O 
operation on closed file") @@ -1515,6 +2559,6 @@ def __next__(self): if not self._recover(): raise - def close(self): + def close(self) -> None: self._underlying_iterator.close() self._closed = True diff --git a/databricks/sdk/mixins/files_utils.py b/databricks/sdk/mixins/files_utils.py new file mode 100644 index 000000000..fd07bcdce --- /dev/null +++ b/databricks/sdk/mixins/files_utils.py @@ -0,0 +1,293 @@ +from __future__ import annotations + +import os +import threading +from dataclasses import dataclass +from typing import Any, BinaryIO, Callable, Iterable, Optional + + +@dataclass +class CreateDownloadUrlResponse: + """Response from the download URL API call.""" + + url: str + """The presigned URL to download the file.""" + headers: dict[str, str] + """Headers to use when making the download request.""" + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> CreateDownloadUrlResponse: + """Create an instance from a dictionary.""" + if "url" not in data: + raise ValueError("Missing 'url' in response data") + headers = data["headers"] if "headers" in data else {} + parsed_headers = {x["name"]: x["value"] for x in headers} + return cls(url=data["url"], headers=parsed_headers) + + +class _ConcatenatedInputStream(BinaryIO): + """This class joins two input streams into one.""" + + def __init__(self, head_stream: BinaryIO, tail_stream: BinaryIO): + if not head_stream.readable(): + raise ValueError("head_stream is not readable") + if not tail_stream.readable(): + raise ValueError("tail_stream is not readable") + + self._head_stream = head_stream + self._tail_stream = tail_stream + self._head_size = None + self._tail_size = None + + def close(self) -> None: + try: + self._head_stream.close() + finally: + self._tail_stream.close() + + def fileno(self) -> int: + raise AttributeError() + + def flush(self) -> None: + raise NotImplementedError("Stream is not writable") + + def isatty(self) -> bool: + raise NotImplementedError() + + def read(self, __n: int = -1) -> bytes: + head = self._head_stream.read(__n) + remaining_bytes = __n - len(head) if __n >= 0 else __n + tail = self._tail_stream.read(remaining_bytes) + return head + tail + + def readable(self) -> bool: + return True + + def readline(self, __limit: int = -1) -> bytes: + # Read and return one line from the stream. + # If __limit is specified, at most __limit bytes will be read. + # The line terminator is always b'\n' for binary files. + head = self._head_stream.readline(__limit) + if len(head) > 0 and head[-1:] == b"\n": + # end of line happened before (or at) the limit + return head + + # if __limit >= 0, len(head) can't exceed limit + remaining_bytes = __limit - len(head) if __limit >= 0 else __limit + tail = self._tail_stream.readline(remaining_bytes) + return head + tail + + def readlines(self, __hint: int = -1) -> list[bytes]: + # Read and return a list of lines from the stream. + # Hint can be specified to control the number of lines read: no more lines will be read + # If the total size (in bytes/characters) of all lines so far exceeds hint. + + # In fact, BytesIO(bytes) will not read next line if total size of all lines + # *equals or* exceeds hint. + + head_result = self._head_stream.readlines(__hint) + head_total_bytes = sum(len(line) for line in head_result) + + if 0 < __hint <= head_total_bytes and head_total_bytes > 0: + # We reached (or passed) the hint by reading from head_stream, or exhausted head_stream. 
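A quick, self-contained illustration of what `_ConcatenatedInputStream` guarantees at the seam between the two underlying streams (usage sketch only, relying on the behaviour defined in this class):

.. code-block:: python

    from io import BytesIO

    from databricks.sdk.mixins.files_utils import _ConcatenatedInputStream

    stream = _ConcatenatedInputStream(BytesIO(b"hello "), BytesIO(b"world\n!"))

    # read() spans the boundary between head and tail transparently.
    assert stream.read(8) == b"hello wo"

    # seek()/tell() address the concatenation as one contiguous byte sequence.
    stream.seek(0)
    assert stream.readline() == b"hello world\n"
    assert stream.read() == b"!"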
+ + if head_result[-1][-1:] == b"\n": + # If we reached/passed the hint and also stopped at the line break, return. + return head_result + + # Reading from head_stream could have stopped only because the stream was exhausted + if len(self._head_stream.read(1)) > 0: + raise ValueError( + f"Stream reading finished prematurely after reading {head_total_bytes} bytes, reaching or exceeding hint {__hint}" + ) + + # We need to finish reading the current line, now from tail_stream. + + tail_result = self._tail_stream.readlines(1) # We will only read the first line from tail_stream. + assert len(tail_result) <= 1 + if len(tail_result) > 0: + # We will then append the tail as the last line of the result. + return head_result[:-1] + [head_result[-1] + tail_result[0]] + else: + return head_result + + # We did not reach the hint by reading head_stream but exhausted it, continue reading from tail_stream + # with an adjusted hint + if __hint >= 0: + remaining_bytes = __hint - head_total_bytes + else: + remaining_bytes = __hint + + tail_result = self._tail_stream.readlines(remaining_bytes) + + if head_total_bytes > 0 and head_result[-1][-1:] != b"\n" and len(tail_result) > 0: + # If head stream does not end with the line break, we need to concatenate + # the last line of the head result and the first line of tail result + return head_result[:-1] + [head_result[-1] + tail_result[0]] + tail_result[1:] + else: + # Otherwise, just append two lists of lines. + return head_result + tail_result + + def _get_stream_size(self, stream: BinaryIO) -> int: + prev_offset = stream.tell() + try: + stream.seek(0, os.SEEK_END) + return stream.tell() + finally: + stream.seek(prev_offset, os.SEEK_SET) + + def _get_head_size(self) -> int: + if self._head_size is None: + self._head_size = self._get_stream_size(self._head_stream) + return self._head_size + + def _get_tail_size(self) -> int: + if self._tail_size is None: + self._tail_size = self._get_stream_size(self._tail_stream) + return self._tail_size + + def seek(self, __offset: int, __whence: int = os.SEEK_SET) -> int: + if not self.seekable(): + raise NotImplementedError("Stream is not seekable") + + if __whence == os.SEEK_SET: + if __offset < 0: + # Follow native buffer behavior + raise ValueError(f"Negative seek value: {__offset}") + + head_size = self._get_head_size() + + if __offset <= head_size: + self._head_stream.seek(__offset, os.SEEK_SET) + self._tail_stream.seek(0, os.SEEK_SET) + else: + self._head_stream.seek(0, os.SEEK_END) # move head stream to the end + self._tail_stream.seek(__offset - head_size, os.SEEK_SET) + + elif __whence == os.SEEK_CUR: + current_offset = self.tell() + new_offset = current_offset + __offset + if new_offset < 0: + # gracefully don't seek before start + new_offset = 0 + self.seek(new_offset, os.SEEK_SET) + + elif __whence == os.SEEK_END: + if __offset > 0: + # Python allows to seek beyond the end of stream. 
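Continuing that sketch, the `SEEK_END` rules can be checked against plain `BytesIO` backings; positions are reported as if the two streams were a single buffer of combined length:

.. code-block:: python

    import os
    from io import BytesIO

    from databricks.sdk.mixins.files_utils import _ConcatenatedInputStream

    stream = _ConcatenatedInputStream(BytesIO(b"abc"), BytesIO(b"defgh"))

    # Seeking to the end reports the combined length (3 + 5).
    assert stream.seek(0, os.SEEK_END) == 8

    # Seeking back to the start of the tail yields the logical offset 3.
    assert stream.seek(-5, os.SEEK_END) == 3
    assert stream.read(2) == b"de"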
+ + # Move head to EOF and tail to (EOF + offset), so subsequent tell() + # returns len(head) + len(tail) + offset, same as for native buffer + self._head_stream.seek(0, os.SEEK_END) + self._tail_stream.seek(__offset, os.SEEK_END) + else: + self._tail_stream.seek(__offset, os.SEEK_END) + tail_pos = self._tail_stream.tell() + if tail_pos > 0: + # target position lies within the tail, move head to EOF + self._head_stream.seek(0, os.SEEK_END) + else: + tail_size = self._get_tail_size() + self._head_stream.seek(__offset + tail_size, os.SEEK_END) + else: + raise ValueError(__whence) + return self.tell() + + def seekable(self) -> bool: + return self._head_stream.seekable() and self._tail_stream.seekable() + + def __getattribute__(self, name: str) -> Any: + if name == "fileno": + raise AttributeError() + elif name in ["tell", "seek"] and not self.seekable(): + raise AttributeError() + + return super().__getattribute__(name) + + def tell(self) -> int: + if not self.seekable(): + raise NotImplementedError() + + # Assuming that tail stream stays at 0 until head stream is exhausted + return self._head_stream.tell() + self._tail_stream.tell() + + def truncate(self, __size: Optional[int] = None) -> int: + raise NotImplementedError("Stream is not writable") + + def writable(self) -> bool: + return False + + def write(self, __s: bytes) -> int: + raise NotImplementedError("Stream is not writable") + + def writelines(self, __lines: Iterable[bytes]) -> None: + raise NotImplementedError("Stream is not writable") + + def __next__(self) -> bytes: + # IOBase [...] supports the iterator protocol, meaning that an IOBase object can be + # iterated over yielding the lines in a stream. [...] See readline(). + result = self.readline() + if len(result) == 0: + raise StopIteration + return result + + def __iter__(self) -> "BinaryIO": + return self + + def __enter__(self) -> "BinaryIO": + self._head_stream.__enter__() + self._tail_stream.__enter__() + return self + + def __exit__(self, __type, __value, __traceback) -> None: + self._head_stream.__exit__(__type, __value, __traceback) + self._tail_stream.__exit__(__type, __value, __traceback) + + def __str__(self) -> str: + return f"Concat: {self._head_stream}, {self._tail_stream}]" + + +class _PresignedUrlDistributor: + """ + Distributes and manages presigned URLs for downloading files. + + This class ensures thread-safe access to a presigned URL, allowing retrieval and invalidation. + When the URL is invalidated, a new one will be fetched using the provided function. + """ + + def __init__(self, get_new_url_func: Callable[[], CreateDownloadUrlResponse]): + """ + Initialize the distributor. + + Args: + get_new_url_func: A callable that returns a new presigned URL response. + """ + self._get_new_url_func = get_new_url_func + self._current_url = None + self.current_version = 0 + self.lock = threading.RLock() + + def get_url(self) -> tuple[CreateDownloadUrlResponse, int]: + """ + Get the current presigned URL and its version. + + Returns: + A tuple containing the current presigned URL response and its version. + """ + with self.lock: + if self._current_url is None: + self._current_url = self._get_new_url_func() + return self._current_url, self.current_version + + def invalidate_url(self, version: int) -> None: + """ + Invalidate the current presigned URL if the version matches. If the version does not match, + the URL remains unchanged. This ensures that only the most recent version can invalidate the URL. 
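A usage sketch from a download worker's perspective (`fetch_url` and the ranged-GET body below are hypothetical stand-ins): each worker keeps the version it was handed and, on an expired-URL failure, asks to invalidate exactly that version, so several workers hitting the same stale URL trigger only one refresh.

.. code-block:: python

    from databricks.sdk.mixins.files_utils import (CreateDownloadUrlResponse,
                                                   _PresignedUrlDistributor)


    def fetch_url() -> CreateDownloadUrlResponse:
        # Hypothetical refresh; in the SDK this would call the presigned-URL API.
        return CreateDownloadUrlResponse.from_dict(
            {"url": "https://example.invalid/presigned", "headers": []}
        )


    distributor = _PresignedUrlDistributor(fetch_url)


    def download_part(part_number: int) -> None:
        url, version = distributor.get_url()
        try:
            ...  # perform the ranged GET for this part against url.url with url.headers
        except Exception:
            # Only the first worker that observed this version refreshes the URL;
            # later invalidations carrying the same stale version are no-ops.
            distributor.invalidate_url(version)
            raise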
+ + Args: + version: The version to check before invalidating the URL. + """ + with self.lock: + if version == self.current_version: + self._current_url = None + self.current_version += 1 diff --git a/docs/account/iam/workspace_assignment.rst b/docs/account/iam/workspace_assignment.rst index ca78b86df..133b16f3d 100644 --- a/docs/account/iam/workspace_assignment.rst +++ b/docs/account/iam/workspace_assignment.rst @@ -74,9 +74,9 @@ spn_id = spn.id - workspace_id = os.environ["TEST_WORKSPACE_ID"] + workspace_id = os.environ["DUMMY_WORKSPACE_ID"] - a.workspace_assignment.update( + _ = a.workspace_assignment.update( workspace_id=workspace_id, principal_id=spn_id, permissions=[iam.WorkspacePermission.USER], diff --git a/docs/account/provisioning/storage.rst b/docs/account/provisioning/storage.rst index 25ee5abaa..b9f080e36 100644 --- a/docs/account/provisioning/storage.rst +++ b/docs/account/provisioning/storage.rst @@ -16,6 +16,7 @@ .. code-block:: + import os import time from databricks.sdk import AccountClient @@ -25,8 +26,11 @@ storage = a.storage.create( storage_configuration_name=f"sdk-{time.time_ns()}", - root_bucket_info=provisioning.RootBucketInfo(bucket_name=f"sdk-{time.time_ns()}"), + root_bucket_info=provisioning.RootBucketInfo(bucket_name=os.environ["TEST_ROOT_BUCKET"]), ) + + # cleanup + a.storage.delete(storage_configuration_id=storage.storage_configuration_id) Creates a Databricks storage configuration for an account. diff --git a/docs/workspace/catalog/storage_credentials.rst b/docs/workspace/catalog/storage_credentials.rst index 48666f7ab..928ad39e5 100644 --- a/docs/workspace/catalog/storage_credentials.rst +++ b/docs/workspace/catalog/storage_credentials.rst @@ -32,11 +32,11 @@ created = w.storage_credentials.create( name=f"sdk-{time.time_ns()}", - aws_iam_role=catalog.AwsIamRoleRequest(role_arn=os.environ["TEST_METASTORE_DATA_ACCESS_ARN"]), + aws_iam_role=catalog.AwsIamRole(role_arn=os.environ["TEST_METASTORE_DATA_ACCESS_ARN"]), ) # cleanup - w.storage_credentials.delete(name=created.name) + w.storage_credentials.delete(delete=created.name) Creates a new storage credential. @@ -123,11 +123,10 @@ .. code-block:: from databricks.sdk import WorkspaceClient - from databricks.sdk.service import catalog w = WorkspaceClient() - all = w.storage_credentials.list(catalog.ListStorageCredentialsRequest()) + all = w.storage_credentials.list() Gets an array of storage credentials (as __StorageCredentialInfo__ objects). The array is limited to only those storage credentials the caller has permission to access. If the caller is a metastore diff --git a/docs/workspace/catalog/tables.rst b/docs/workspace/catalog/tables.rst index 75d4138fd..c5c3a131d 100644 --- a/docs/workspace/catalog/tables.rst +++ b/docs/workspace/catalog/tables.rst @@ -156,7 +156,7 @@ created_schema = w.schemas.create(name=f"sdk-{time.time_ns()}", catalog_name=created_catalog.name) - summaries = w.tables.list_summaries(catalog_name=created_catalog.name, schema_name_pattern=created_schema.name) + all_tables = w.tables.list(catalog_name=created_catalog.name, schema_name=created_schema.name) # cleanup w.schemas.delete(full_name=created_schema.full_name) diff --git a/docs/workspace/files/files.rst b/docs/workspace/files/files.rst index 6118d35e3..3d01566c6 100644 --- a/docs/workspace/files/files.rst +++ b/docs/workspace/files/files.rst @@ -2,7 +2,7 @@ ================== .. currentmodule:: databricks.sdk.service.files -.. py:class:: FilesAPI +.. 
py:class:: FilesExt The Files API is a standard HTTP API that allows you to read, write, list, and delete files and directories by referring to their URI. The API makes working with file content as raw bytes easier and @@ -61,15 +61,39 @@ .. py:method:: download(file_path: str) -> DownloadResponse - Downloads a file. The file contents are the response body. This is a standard HTTP file download, not - a JSON RPC. It supports the Range and If-Unmodified-Since HTTP headers. + Download a file. + + Downloads a file of any size. The file contents are the response body. + This is a standard HTTP file download, not a JSON RPC. + + It is strongly recommended, for fault tolerance reasons, + to iteratively consume from the stream with a maximum read(size) + defined instead of using indefinite-size reads. :param file_path: str - The absolute path of the file. + The remote path of the file, e.g. /Volumes/path/to/your/file :returns: :class:`DownloadResponse` + .. py:method:: download_to(file_path: str, destination: str [, overwrite: bool = True, use_parallel: bool = False, parallelism: Optional[int]]) -> DownloadFileResult + + Download a file to a local path. There would be no responses returned if the download is successful. + + :param file_path: str + The remote path of the file, e.g. /Volumes/path/to/your/file + :param destination: str + The local path where the file will be saved. + :param overwrite: bool + If true, an existing file will be overwritten. When not specified, assumed True. + :param use_parallel: bool + If true, the download will be performed using multiple threads. + :param parallelism: int + The number of parallel threads to use for downloading. If not specified, defaults to the number of CPU cores. + + :returns: :class:`DownloadFileResult` + + .. py:method:: get_directory_metadata(directory_path: str) Get the metadata of a directory. The response HTTP headers contain the metadata. There is no response @@ -124,19 +148,48 @@ :returns: Iterator over :class:`DirectoryEntry` - .. py:method:: upload(file_path: str, contents: BinaryIO [, overwrite: Optional[bool]]) + .. py:method:: upload(file_path: str, content: BinaryIO [, overwrite: Optional[bool], part_size: Optional[int], use_parallel: bool = True, parallelism: Optional[int]]) -> UploadStreamResult - Uploads a file of up to 5 GiB. The file contents should be sent as the request body as raw bytes (an - octet stream); do not encode or otherwise modify the bytes before sending. The contents of the - resulting file will be exactly the bytes sent in the request body. If the request is successful, there - is no response body. + + Upload a file with stream interface. :param file_path: str - The absolute path of the file. - :param contents: BinaryIO + The absolute remote path of the target file, e.g. /Volumes/path/to/your/file + :param content: BinaryIO + The contents of the file to upload. This must be a BinaryIO stream. :param overwrite: bool (optional) - If true or unspecified, an existing file will be overwritten. If false, an error will be returned if - the path points to an existing file. + If true, an existing file will be overwritten. When not specified, assumed True. + :param part_size: int (optional) + If set, multipart upload will use the value as its size per uploading part. + :param use_parallel: bool (optional) + If true, the upload will be performed using multiple threads. Be aware that this will consume more memory + because multiple parts will be buffered in memory before being uploaded. 
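As a rough sizing illustration (the numbers below are examples, not SDK defaults):

.. code-block:: python

    parallelism = 10
    part_size = 10 * 1024 * 1024                    # 10 MiB buffered per in-flight part
    peak_buffered_bytes = parallelism * part_size   # about 100 MiB held in memory at once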
The amount of memory used is proportional + to `parallelism * part_size`. + If false, the upload will be performed in a single thread. + Default is True. + :param parallelism: int (optional) + The number of threads to use for parallel uploads. This is only used if `use_parallel` is True. + + :returns: :class:`UploadStreamResult` + + .. py:method:: upload_from(file_path: str, source_path: str [, overwrite: Optional[bool], part_size: Optional[int], use_parallel: bool = True, parallelism: Optional[int]]) -> UploadFileResult + Upload a file directly from a local path. + + :param file_path: str + The absolute remote path of the target file. + :param source_path: str + The local path of the file to upload. This must be a path to a local file. + :param part_size: int + The size of each part in bytes for multipart upload. This is a required parameter for multipart uploads. + :param overwrite: bool (optional) + If true, an existing file will be overwritten. When not specified, assumed True. + :param use_parallel: bool (optional) + If true, the upload will be performed using multiple threads. Default is True. + :param parallelism: int (optional) + The number of threads to use for parallel uploads. This is only used if `use_parallel` is True. + If not specified, the default parallelism will be set to config.multipart_upload_default_parallelism + + :returns: :class:`UploadFileResult` \ No newline at end of file diff --git a/docs/workspace/iam/permissions.rst b/docs/workspace/iam/permissions.rst index ea24afd1a..15524c53e 100644 --- a/docs/workspace/iam/permissions.rst +++ b/docs/workspace/iam/permissions.rst @@ -44,7 +44,7 @@ obj = w.workspace.get_status(path=notebook_path) - _ = w.permissions.get(request_object_type="notebooks", request_object_id="%d" % (obj.object_id)) + levels = w.permissions.get_permission_levels(request_object_type="notebooks", request_object_id="%d" % (obj.object_id)) Gets the permissions of an object. Objects can inherit permissions from their parent objects or root object. diff --git a/docs/workspace/ml/model_registry.rst b/docs/workspace/ml/model_registry.rst index 9a6c8f286..2d34256e4 100644 --- a/docs/workspace/ml/model_registry.rst +++ b/docs/workspace/ml/model_registry.rst @@ -120,7 +120,7 @@ model = w.model_registry.create_model(name=f"sdk-{time.time_ns()}") - created = w.model_registry.create_model_version(name=model.registered_model.name, source="dbfs:/tmp") + mv = w.model_registry.create_model_version(name=model.registered_model.name, source="dbfs:/tmp") Creates a model version. @@ -734,13 +734,14 @@ w = WorkspaceClient() - created = w.model_registry.create_model(name=f"sdk-{time.time_ns()}") + model = w.model_registry.create_model(name=f"sdk-{time.time_ns()}") - model = w.model_registry.get_model(name=created.registered_model.name) + created = w.model_registry.create_model_version(name=model.registered_model.name, source="dbfs:/tmp") - w.model_registry.update_model( - name=model.registered_model_databricks.name, + w.model_registry.update_model_version( description=f"sdk-{time.time_ns()}", + name=created.model_version.name, + version=created.model_version.version, ) Updates a registered model. diff --git a/docs/workspace/sharing/providers.rst b/docs/workspace/sharing/providers.rst index 1a7c88de9..fd81e1b24 100644 --- a/docs/workspace/sharing/providers.rst +++ b/docs/workspace/sharing/providers.rst @@ -101,25 +101,12 @@ .. 
code-block:: - import time - from databricks.sdk import WorkspaceClient + from databricks.sdk.service import sharing w = WorkspaceClient() - public_share_recipient = """{ - "shareCredentialsVersion":1, - "bearerToken":"dapiabcdefghijklmonpqrstuvwxyz", - "endpoint":"https://sharing.delta.io/delta-sharing/" - } - """ - - created = w.providers.create(name=f"sdk-{time.time_ns()}", recipient_profile_str=public_share_recipient) - - shares = w.providers.list_shares(name=created.name) - - # cleanup - w.providers.delete(name=created.name) + all = w.providers.list(sharing.ListProvidersRequest()) Gets an array of available authentication providers. The caller must either be a metastore admin or the owner of the providers. Providers not owned by the caller are not included in the response. There diff --git a/docs/workspace/sql/queries.rst b/docs/workspace/sql/queries.rst index 0dfb63fbf..f0081b3f2 100644 --- a/docs/workspace/sql/queries.rst +++ b/docs/workspace/sql/queries.rst @@ -29,7 +29,7 @@ display_name=f"sdk-{time.time_ns()}", warehouse_id=srcs[0].warehouse_id, description="test query from Go SDK", - query_text="SELECT 1", + query_text="SHOW TABLES", ) ) diff --git a/docs/workspace/workspace/workspace.rst b/docs/workspace/workspace/workspace.rst index fbcb5374b..4fba581e8 100644 --- a/docs/workspace/workspace/workspace.rst +++ b/docs/workspace/workspace/workspace.rst @@ -175,18 +175,11 @@ notebook_path = f"/Users/{w.current_user.me().user_name}/sdk-{time.time_ns()}" w.workspace.import_( - path=notebook_path, - overwrite=True, + content=base64.b64encode(("CREATE LIVE TABLE dlt_sample AS SELECT 1").encode()).decode(), format=workspace.ImportFormat.SOURCE, - language=workspace.Language.PYTHON, - content=base64.b64encode( - ( - """import time - time.sleep(10) - dbutils.notebook.exit('hello') - """ - ).encode() - ).decode(), + language=workspace.Language.SQL, + overwrite=True, + path=notebook_path, ) Imports a workspace object (for example, a notebook or file) or the contents of an entire directory. diff --git a/tests/test_files.py b/tests/test_files.py index d795c4649..0da24be83 100644 --- a/tests/test_files.py +++ b/tests/test_files.py @@ -4,13 +4,16 @@ import json import logging import os +import platform import random import re import time from dataclasses import dataclass from datetime import datetime, timedelta, timezone -from tempfile import mkstemp -from typing import Callable, List, Optional, Type, Union +from enum import Enum +from tempfile import NamedTemporaryFile +from threading import Lock +from typing import Any, Callable, Dict, List, Optional, Type, Union from urllib.parse import parse_qs, urlparse import pytest @@ -20,21 +23,149 @@ from databricks.sdk import WorkspaceClient from databricks.sdk.core import Config +from databricks.sdk.environments import Cloud, DatabricksEnvironment from databricks.sdk.errors.platform import (AlreadyExists, BadRequest, - InternalError, PermissionDenied, - TooManyRequests) + InternalError, NotImplemented, + PermissionDenied, TooManyRequests) +from databricks.sdk.mixins.files import FallbackToDownloadUsingFilesApi +from databricks.sdk.mixins.files_utils import CreateDownloadUrlResponse +from tests.clock import FakeClock + +from .test_files_utils import Utils logger = logging.getLogger(__name__) +class CustomResponse: + """Custom response allows to override the "default" response generated by the server + with the "custom" response to simulate failure error code, unexpected response body or + network error. 
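For example (illustrative use of this helper), a response override that throttles only the second request while every other request falls through to the default `processor` looks like this:

.. code-block:: python

    import requests

    # Fail only the second call with a 429; all other calls get the default response.
    throttle_once = CustomResponse(code=429, body='{"error": "rate limited"}', only_invocation=2)


    def default_processor() -> list:
        # Stand-in for the mock server's normal behaviour: [status, body, headers].
        return [200, '{"ok": true}', {}]


    first = throttle_once.generate_response(requests.Request("GET", "https://example.invalid"), default_processor)
    second = throttle_once.generate_response(requests.Request("GET", "https://example.invalid"), default_processor)
    assert first.status_code == 200 and second.status_code == 429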
+ + The server is represented by the `processor` parameter in `generate_response()` call. + """ + + def __init__( + self, + # If False, default response is always returned. + # If True, response is defined by the current invocation count + # with respect to first_invocation / last_invocation / only_invocation + enabled: bool = True, + # Custom code to return + code: Optional[int] = 200, + # Custom body to return + body: Optional[str] = None, + # Custom exception to raise + exception: Optional[Type[BaseException]] = None, + # Whether exception should be raised before calling processor() + # (so changing server state) + exception_happened_before_processing: bool = False, + # First invocation (1-based) at which return custom response + first_invocation: Optional[int] = None, + # Last invocation (1-based) at which return custom response + last_invocation: Optional[int] = None, + # Only invocation (1-based) at which return custom response + only_invocation: Optional[int] = None, + # If set, response is delayed by given number of seconds + delayed_response_seconds: Optional[float] = None, + ): + self.enabled = enabled + self.code = code + self.body = body + self.exception = exception + self.exception_happened_before_processing = exception_happened_before_processing + self.first_invocation = first_invocation + self.last_invocation = last_invocation + self.only_invocation = only_invocation + self.delayed_response_seconds = delayed_response_seconds + + if self.only_invocation and (self.first_invocation or self.last_invocation): + raise ValueError("Cannot set both only invocation and first/last invocation") + + if self.exception_happened_before_processing and not self.exception: + raise ValueError("Exception is not defined") + + self.invocation_count = 0 + + def invocation_matches(self) -> bool: + if not self.enabled: + return False + + self.invocation_count += 1 + + if self.only_invocation: + return self.invocation_count == self.only_invocation + + if self.first_invocation and self.invocation_count < self.first_invocation: + return False + if self.last_invocation and self.invocation_count > self.last_invocation: + return False + return True + + def generate_response( + self, request: requests.Request, processor: Callable[[], list], stream=False + ) -> requests.Response: + if self.delayed_response_seconds: + time.sleep(self.delayed_response_seconds) + activate_for_current_invocation = self.invocation_matches() + + if activate_for_current_invocation and self.exception and self.exception_happened_before_processing: + # if network exception is thrown while processing a request, it's not defined + # if server actually processed the request (and so changed its state) + raise self.exception + + custom_response = [self.code, self.body or "", {}] + + if activate_for_current_invocation: + if self.code and 400 <= self.code < 500: + # if server returns client error, it's not supposed to change its state, + # so we're not calling processor() + [code, body_or_stream, headers] = custom_response + else: + # we're calling processor() but override its response with the custom one + processor() + [code, body_or_stream, headers] = custom_response + else: + [code, body_or_stream, headers] = processor() + + if activate_for_current_invocation and self.exception: + # self.exception_happened_before_processing is False + raise self.exception + + resp = requests.Response() + + resp.request = request + resp.status_code = code + if stream: + if type(body_or_stream) != bytes: + resp.raw = io.BytesIO(body_or_stream.encode()) + 
else: + resp.raw = io.BytesIO(body_or_stream) + else: + resp._content = body_or_stream.encode() + + for key in headers: + resp.headers[key] = headers[key] + + return resp + + def clear_state(self): + self.invocation_count = 0 + + @dataclass class RequestData: + offset: int + end_byte_offset: Optional[int] = None - def __init__(self, offset: int): - self._offset: int = offset +class DownloadMode(Enum): + """Download mode for the test case. Used to determine how to download the file.""" -class DownloadTestCase: + STREAM = "stream" # download to a stream (existing behavior) + FILE = "file" # download to a file (new download_to behavior) + + +class FilesApiDownloadTestCase: def __init__( self, @@ -42,10 +173,14 @@ def __init__( enable_new_client: bool, file_size: int, failure_at_absolute_offset: List[int], - max_recovers_total: Union[int, None], - max_recovers_without_progressing: Union[int, None], - expected_success: bool, - expected_requested_offsets: List[int], + max_recovers_total: Optional[int] = None, + max_recovers_without_progressing: Optional[int] = None, + expected_requested_offsets: Optional[List[int]] = None, + expected_exception: Optional[Type[BaseException]] = None, + download_mode: DownloadMode = DownloadMode.STREAM, + overwrite: bool = True, + use_parallel: bool = False, + parallelism: Optional[int] = None, ): self.name = name self.enable_new_client = enable_new_client @@ -53,47 +188,92 @@ def __init__( self.failure_at_absolute_offset = failure_at_absolute_offset self.max_recovers_total = max_recovers_total self.max_recovers_without_progressing = max_recovers_without_progressing - self.expected_success = expected_success + self.expected_exception = expected_exception self.expected_requested_offsets = expected_requested_offsets + self.download_mode = download_mode + self.overwrite = overwrite + self.use_parallel = use_parallel + self.parallelism = parallelism @staticmethod - def to_string(test_case: "DownloadTestCase") -> str: + def to_string(test_case: "FilesApiDownloadTestCase") -> str: return test_case.name - def run(self, config: Config) -> None: + def run(self, config: Config, monkeypatch) -> None: + if self.use_parallel and platform.system() == "Windows": + pytest.skip("Skipping parallel download tests on Windows") config = config.copy() - config.enable_experimental_files_api_client = self.enable_new_client - config.files_api_client_download_max_total_recovers = self.max_recovers_total - config.files_api_client_download_max_total_recovers_without_progressing = self.max_recovers_without_progressing + config.disable_experimental_files_api_client = not self.enable_new_client + config.files_ext_client_download_max_total_recovers = self.max_recovers_total + config.files_ext_client_download_max_total_recovers_without_progressing = self.max_recovers_without_progressing + config.enable_presigned_download_api = False w = WorkspaceClient(config=config) - session = MockSession(self) - w.files._api._api_client._session = session + session = MockFilesystemSession(self) + monkeypatch.setattr(w.files._api._api_client, "_session", session) - response = w.files.download("/test").contents - if self.expected_success: - actual_content = response.read() - assert len(actual_content) == len(session.content) - assert actual_content == session.content - else: - with pytest.raises(RequestException): - response.read() + if self.download_mode == DownloadMode.STREAM: + if self.expected_exception is None: + response = w.files.download("/test").contents + actual_content = response.read() + assert 
len(actual_content) == len(session.content) + assert actual_content == session.content + else: + with pytest.raises(self.expected_exception): + response = w.files.download("/test").contents + response.read() + elif self.download_mode == DownloadMode.FILE: # FILE mode + with NamedTemporaryFile(delete=False) as temp_file: + file_path = temp_file.name + + # We can't use 'with' because Windows doesn't allow reopening the file, and `download_to` can open it + # only after we close it here. + try: + if self.expected_exception is None: + w.files.download_to( + "/test", + file_path, + overwrite=self.overwrite, + use_parallel=self.use_parallel, + parallelism=self.parallelism, + ) + + # Verify the downloaded file content + with open(file_path, "rb") as f: + actual_content = f.read() + assert len(actual_content) == len(session.content) + assert actual_content == session.content + else: + with pytest.raises(self.expected_exception): + w.files.download_to( + "/test", + file_path, + overwrite=self.overwrite, + use_parallel=self.use_parallel, + parallelism=self.parallelism, + ) + finally: + if os.path.exists(file_path): + os.remove(file_path) received_requests = session.received_requests - assert len(self.expected_requested_offsets) == len(received_requests) - for idx, requested_offset in enumerate(self.expected_requested_offsets): - assert requested_offset == received_requests[idx]._offset + if self.expected_requested_offsets is not None: + assert len(received_requests) == len(self.expected_requested_offsets) + for idx, requested_offset in enumerate(self.expected_requested_offsets): + assert received_requests[idx].offset == requested_offset -class MockSession: +class MockFilesystemSession: - def __init__(self, test_case: DownloadTestCase): - self.test_case: DownloadTestCase = test_case + def __init__(self, test_case: FilesApiDownloadTestCase): + self.test_case: FilesApiDownloadTestCase = test_case self.received_requests: List[RequestData] = [] - self.content: bytes = os.urandom(self.test_case.file_size) + self.content: bytes = fast_random_bytes(self.test_case.file_size) self.failure_pointer = 0 + self.planned_failures = copy.deepcopy(self.test_case.failure_at_absolute_offset) + self.lock = Lock() self.last_modified = "Thu, 28 Nov 2024 16:39:14 GMT" # following the signature of Session.request() @@ -115,35 +295,52 @@ def request( verify=None, cert=None, json=None, - ) -> "MockResponse": - assert method == "GET" - assert stream == True + ) -> "MockFilesApiDownloadResponse": + + if method == "GET": + assert stream is True + return self._handle_get_file(headers, url) + elif method == "HEAD": + return self._handle_head_file(headers, url) + else: + raise FallbackToDownloadUsingFilesApi("method must be HEAD or GET") + def _handle_head_file(self, headers: Dict[str, str], url: str) -> "MockFilesApiDownloadResponse": + if "If-Unmodified-Since" in headers: + assert headers["If-Unmodified-Since"] == self.last_modified + resp = MockFilesApiDownloadResponse(self, 0, None, MockFilesApiDownloadRequest(url)) + resp.content = "" + return resp + + def _handle_get_file(self, headers: Dict[str, str], url: str) -> "MockFilesApiDownloadResponse": offset = 0 + end_byte_offset = None if "Range" in headers: - range = headers["Range"] - match = re.search("^bytes=(\\d+)-$", range) - if match: - offset = int(match.group(1)) - else: - raise Exception("Unexpected range header: " + range) - - if "If-Unmodified-Since" in headers: - assert headers["If-Unmodified-Since"] == self.last_modified - else: - raise 
Exception("If-Unmodified-Since header should be passed along with Range") + offset, end_byte_offset = Utils.parse_range_header(headers["Range"], len(self.content)) - logger.info("Client requested offset: %s", offset) + logger.debug("Client requested range: %s-%s", offset, end_byte_offset) if offset > len(self.content): raise Exception("Offset %s exceeds file length %s", offset, len(self.content)) + if end_byte_offset is not None and end_byte_offset >= len(self.content): + raise Exception("End offset %s exceeds file length %s", end_byte_offset, len(self.content)) + if end_byte_offset is not None and offset > end_byte_offset: + raise Exception("Begin offset %s exceeds end offset %s", offset, end_byte_offset) self.received_requests.append(RequestData(offset)) - return MockResponse(self, offset, MockRequest(url)) + return MockFilesApiDownloadResponse(self, offset, end_byte_offset, MockFilesApiDownloadRequest(url)) + + def get_content(self, offset: int, end_byte_offset: int) -> bytes: + with self.lock: + for failure_after_byte in self.planned_failures: + if offset <= failure_after_byte < end_byte_offset: + self.planned_failures.remove(failure_after_byte) + raise RequestException("Fake error") + return self.content[offset:end_byte_offset] # required only for correct logging -class MockRequest: +class MockFilesApiDownloadRequest: def __init__(self, url: str): self.url = url @@ -152,16 +349,25 @@ def __init__(self, url: str): self.body = None -class MockResponse: +class MockFilesApiDownloadResponse: - def __init__(self, session: MockSession, offset: int, request: MockRequest): + def __init__( + self, + session: MockFilesystemSession, + offset: int, + end_byte_offset: Optional[int], + request: MockFilesApiDownloadRequest, + ): self.session = session self.offset = offset + self.end_byte_offset = end_byte_offset self.request = request self.status_code = 200 self.reason = "OK" self.headers = dict() - self.headers["Content-Length"] = len(session.content) - offset + self.headers["Content-Length"] = ( + len(session.content) if end_byte_offset is None else end_byte_offset + 1 + ) - offset self.headers["Content-Type"] = "application/octet-stream" self.headers["Last-Modified"] = session.last_modified self.ok = True @@ -174,27 +380,25 @@ def iter_content(self, chunk_size: int, decode_unicode: bool) -> "MockIterator": class MockIterator: - def __init__(self, response: MockResponse, chunk_size: int): + def __init__(self, response: MockFilesApiDownloadResponse, chunk_size: int): self.response = response self.chunk_size = chunk_size self.offset = 0 def __next__(self) -> bytes: start_offset = self.response.offset + self.offset - if start_offset == len(self.response.session.content): - raise StopIteration - end_offset = start_offset + self.chunk_size # exclusive, might be out of range + if self.response.end_byte_offset is not None: + end_offset = min( + start_offset + self.chunk_size, self.response.end_byte_offset + 1 + ) # This is an exclusive index that might be out of range + else: + end_offset = start_offset + self.chunk_size # This is an exclusive index that might be out of range - if self.response.session.failure_pointer < len(self.response.session.test_case.failure_at_absolute_offset): - failure_after_byte = self.response.session.test_case.failure_at_absolute_offset[ - self.response.session.failure_pointer - ] - if failure_after_byte < end_offset: - self.response.session.failure_pointer += 1 - raise RequestException("Fake error") + if start_offset == len(self.response.session.content) or start_offset == 
end_offset: + raise StopIteration - result = self.response.session.content[start_offset:end_offset] + result = self.response.session.get_content(start_offset, end_offset) self.offset += len(result) return result @@ -209,88 +413,82 @@ class _Constants: @pytest.mark.parametrize( "test_case", [ - DownloadTestCase( - name="Old client: no failures, file of 5 bytes", + FilesApiDownloadTestCase( + name="Old files client: no failures, file of 5 bytes", enable_new_client=False, file_size=5, failure_at_absolute_offset=[], max_recovers_total=0, max_recovers_without_progressing=0, - expected_success=True, expected_requested_offsets=[0], ), - DownloadTestCase( - name="Old client: no failures, file of 1.5 chunks", + FilesApiDownloadTestCase( + name="Old files client: no failures, file of 1.5 chunks", enable_new_client=False, file_size=int(1.5 * _Constants.underlying_chunk_size), failure_at_absolute_offset=[], max_recovers_total=0, max_recovers_without_progressing=0, - expected_success=True, expected_requested_offsets=[0], ), - DownloadTestCase( - name="Old client: failure", + FilesApiDownloadTestCase( + name="Old files client: failure", enable_new_client=False, file_size=1024, failure_at_absolute_offset=[100], max_recovers_total=None, # unlimited but ignored max_recovers_without_progressing=None, # unlimited but ignored - expected_success=False, + expected_exception=RequestException, expected_requested_offsets=[0], ), - DownloadTestCase( - name="New client: no failures, file of 5 bytes", + FilesApiDownloadTestCase( + name="New files client: no failures, file of 5 bytes", enable_new_client=True, file_size=5, failure_at_absolute_offset=[], max_recovers_total=0, max_recovers_without_progressing=0, - expected_success=True, expected_requested_offsets=[0], ), - DownloadTestCase( - name="New client: no failures, file of 1 Kb", + FilesApiDownloadTestCase( + name="New files client: no failures, file of 1 Kb", enable_new_client=True, file_size=1024, max_recovers_total=None, max_recovers_without_progressing=None, failure_at_absolute_offset=[], - expected_success=True, expected_requested_offsets=[0], ), - DownloadTestCase( - name="New client: no failures, file of 1.5 chunks", + FilesApiDownloadTestCase( + name="New files client: no failures, file of 1.5 parts", enable_new_client=True, file_size=int(1.5 * _Constants.underlying_chunk_size), failure_at_absolute_offset=[], max_recovers_total=0, max_recovers_without_progressing=0, - expected_success=True, expected_requested_offsets=[0], ), - DownloadTestCase( - name="New client: no failures, file of 10 parts", + FilesApiDownloadTestCase( + name="New files client: no failures, file of 10 parts", enable_new_client=True, file_size=10 * _Constants.underlying_chunk_size, failure_at_absolute_offset=[], max_recovers_total=0, max_recovers_without_progressing=0, - expected_success=True, expected_requested_offsets=[0], ), - DownloadTestCase( - name="New client: recovers are disabled, first failure leads to download abort", + FilesApiDownloadTestCase( + name="New files client: recovers are disabled, first failure leads to download abort", enable_new_client=True, file_size=10000, failure_at_absolute_offset=[5], max_recovers_total=0, max_recovers_without_progressing=0, - expected_success=False, + expected_exception=RequestException, expected_requested_offsets=[0], ), - DownloadTestCase( - name="New client: unlimited recovers allowed", + FilesApiDownloadTestCase( + name="New files client: unlimited recovers allowed", enable_new_client=True, file_size=_Constants.underlying_chunk_size 
* 5, # causes errors on requesting the third chunk @@ -303,7 +501,6 @@ class _Constants: ], max_recovers_total=None, max_recovers_without_progressing=None, - expected_success=True, expected_requested_offsets=[ 0, 0, @@ -313,8 +510,8 @@ class _Constants: _Constants.underlying_chunk_size * 3, ], ), - DownloadTestCase( - name="New client: we respect limit on total recovers when progressing", + FilesApiDownloadTestCase( + name="New files client: we respect limit on total recovers when progressing", enable_new_client=True, file_size=_Constants.underlying_chunk_size * 10, failure_at_absolute_offset=[ @@ -325,7 +522,7 @@ class _Constants: ], max_recovers_total=3, max_recovers_without_progressing=None, - expected_success=False, + expected_exception=RequestException, expected_requested_offsets=[ 0, 0, @@ -333,18 +530,18 @@ class _Constants: _Constants.underlying_chunk_size * 2, ], ), - DownloadTestCase( - name="New client: we respect limit on total recovers when not progressing", + FilesApiDownloadTestCase( + name="New files client: we respect limit on total recovers when not progressing", enable_new_client=True, file_size=_Constants.underlying_chunk_size * 10, failure_at_absolute_offset=[1, 1, 1, 1], max_recovers_total=3, max_recovers_without_progressing=None, - expected_success=False, + expected_exception=RequestException, expected_requested_offsets=[0, 0, 0, 0], ), - DownloadTestCase( - name="New client: we respect limit on non-progressing recovers", + FilesApiDownloadTestCase( + name="New files client: we respect limit on non-progressing recovers", enable_new_client=True, file_size=_Constants.underlying_chunk_size * 2, failure_at_absolute_offset=[ @@ -355,11 +552,11 @@ class _Constants: ], max_recovers_total=None, max_recovers_without_progressing=3, - expected_success=False, + expected_exception=RequestException, expected_requested_offsets=[0, 0, 0, 0], ), - DownloadTestCase( - name="New client: non-progressing recovers count is reset when progressing", + FilesApiDownloadTestCase( + name="New files client: non-progressing recovers count is reset when progressing", enable_new_client=True, file_size=_Constants.underlying_chunk_size * 10, failure_at_absolute_offset=[ @@ -371,7 +568,7 @@ class _Constants: ], max_recovers_total=None, max_recovers_without_progressing=2, - expected_success=False, + expected_exception=RequestException, expected_requested_offsets=[ 0, _Constants.underlying_chunk_size, @@ -380,8 +577,8 @@ class _Constants: _Constants.underlying_chunk_size * 2, ], ), - DownloadTestCase( - name="New client: non-progressing recovers count is reset when progressing - 2", + FilesApiDownloadTestCase( + name="New files client: non-progressing recovers count is reset when progressing - 2", enable_new_client=True, file_size=_Constants.underlying_chunk_size * 10, failure_at_absolute_offset=[ @@ -392,7 +589,6 @@ class _Constants: ], max_recovers_total=None, max_recovers_without_progressing=1, - expected_success=True, expected_requested_offsets=[ 0, 0, @@ -401,11 +597,518 @@ class _Constants: _Constants.underlying_chunk_size * 3, ], ), + # Test cases for download_to functionality + FilesApiDownloadTestCase( + name="Download to file: New files client, no failures, file of 1 Kb", + enable_new_client=True, + file_size=1024, + max_recovers_total=None, + max_recovers_without_progressing=None, + failure_at_absolute_offset=[], + expected_requested_offsets=[0], + download_mode=DownloadMode.FILE, + ), + FilesApiDownloadTestCase( + name="Download to file in parallel (1 thread): New files client, no failures, 
file of 1 Kb", + enable_new_client=True, + file_size=1024, + max_recovers_total=None, + max_recovers_without_progressing=None, + failure_at_absolute_offset=[], + expected_requested_offsets=[0], + download_mode=DownloadMode.FILE, + use_parallel=True, + parallelism=1, + ), + FilesApiDownloadTestCase( + name="Download to file in parallel (4 threads): New files client, no failures, file of 1 Kb", + enable_new_client=True, + file_size=1024, + max_recovers_total=None, + max_recovers_without_progressing=None, + failure_at_absolute_offset=[], + download_mode=DownloadMode.FILE, + use_parallel=True, + parallelism=4, + ), + FilesApiDownloadTestCase( + name="Download to file: New files client, no failures, file of 10 parts", + enable_new_client=True, + file_size=10 * _Constants.underlying_chunk_size, + failure_at_absolute_offset=[], + max_recovers_total=0, + max_recovers_without_progressing=0, + expected_requested_offsets=[0], + download_mode=DownloadMode.FILE, + ), + FilesApiDownloadTestCase( + name="Download to file in parallel: New files client, no failures, file of 10 parts", + enable_new_client=True, + file_size=10 * _Constants.underlying_chunk_size, + failure_at_absolute_offset=[], + max_recovers_total=0, + max_recovers_without_progressing=0, + download_mode=DownloadMode.FILE, + use_parallel=True, + ), + FilesApiDownloadTestCase( + name="Download to file: New files client, failure with recovery", + enable_new_client=True, + file_size=_Constants.underlying_chunk_size * 5, + failure_at_absolute_offset=[ + _Constants.underlying_chunk_size - 1, + _Constants.underlying_chunk_size + 1, + _Constants.underlying_chunk_size * 3, + ], + max_recovers_total=None, + max_recovers_without_progressing=None, + expected_requested_offsets=[ + 0, + 0, + _Constants.underlying_chunk_size, + _Constants.underlying_chunk_size * 3, + ], + download_mode=DownloadMode.FILE, + ), + FilesApiDownloadTestCase( + name="Download to file in parallel: New files client, failure with recovery", + enable_new_client=True, + file_size=_Constants.underlying_chunk_size * 5, + failure_at_absolute_offset=[ + _Constants.underlying_chunk_size - 1, + _Constants.underlying_chunk_size + 1, + _Constants.underlying_chunk_size * 3, + ], + max_recovers_total=None, + max_recovers_without_progressing=None, + download_mode=DownloadMode.FILE, + use_parallel=True, + parallelism=2, + ), + FilesApiDownloadTestCase( + name="Download to file: New files client, failure without recovery", + enable_new_client=True, + file_size=10000, + failure_at_absolute_offset=[5], + max_recovers_total=0, + max_recovers_without_progressing=0, + expected_exception=RequestException, + expected_requested_offsets=[0], + download_mode=DownloadMode.FILE, + ), + FilesApiDownloadTestCase( + name="Download to file in parallel: New files client, failure without recovery", + enable_new_client=True, + file_size=10000, + failure_at_absolute_offset=[5], + max_recovers_total=0, + max_recovers_without_progressing=0, + expected_exception=RequestException, + download_mode=DownloadMode.FILE, + use_parallel=True, + ), + FilesApiDownloadTestCase( + name="Download to file: New files client, overwrite = False", + enable_new_client=True, + file_size=100, + failure_at_absolute_offset=[5], + expected_exception=IOError, + download_mode=DownloadMode.FILE, + overwrite=False, + ), + FilesApiDownloadTestCase( + name="Download to file in parallel: New files client, overwrite = False", + enable_new_client=True, + file_size=100, + failure_at_absolute_offset=[5], + expected_exception=IOError, + 
download_mode=DownloadMode.FILE,
+            overwrite=False,
+            use_parallel=True,
+        ),
     ],
-    ids=DownloadTestCase.to_string,
+    ids=FilesApiDownloadTestCase.to_string,
 )
-def test_download_recover(config: Config, test_case: DownloadTestCase) -> None:
-    test_case.run(config)
+def test_download_recover(config: Config, test_case: FilesApiDownloadTestCase, monkeypatch):
+    test_case.run(config, monkeypatch)
+
+
+class PresignedUrlDownloadServerState:
+    HOSTNAME = "mock-presigned-url.com"
+
+    def __init__(self, file_size: int, last_modified: str):
+        self.file_size = file_size
+        self.content = fast_random_bytes(file_size)
+        self.requested = False
+        self.api_used: Optional[str] = None
+        self.last_modified = last_modified
+
+    def get_presigned_url(self, path: str):
+        return f"https://{PresignedUrlDownloadServerState.HOSTNAME}{path}"
+
+    def get_content(self, request: requests.Request, api_used: str):
+        self.requested = True
+        self.api_used = api_used
+        offset = 0
+        end_byte_offset = len(self.content) - 1
+
+        if "Range" in request.headers:
+            offset, end_byte_offset = Utils.parse_range_header(request.headers["Range"], len(self.content))
+
+        resp = self.get_header(request)
+        resp.status_code = 206 if "Range" in request.headers else 200
+        resp._content = self.content[offset : end_byte_offset + 1]
+        resp.headers["Content-Length"] = str(len(resp._content))
+        return resp
+
+    def get_header(self, request: requests.Request) -> requests.Response:
+        resp = requests.Response()
+        resp.status_code = 200
+        resp._content = b""
+        resp.request = request
+        resp.headers["Content-Length"] = str(self.file_size)
+        resp.headers["Content-Type"] = "application/octet-stream"
+        resp.headers["Last-Modified"] = self.last_modified
+        return resp
+
+
+class PresignedUrlDownloadTestCase:
+    _FILE_PATH = "/testfile/remote/path"  # A fake path for the remote location of the file to be downloaded
+    presigned_url_disabled_response = """
+    {
+        "error_code": "PERMISSION_DENIED",
+        "message": "Presigned URLs API is not enabled",
+        "details": [
+            {
+                "@type": "type.googleapis.com/google.rpc.ErrorInfo",
+                "reason": "FILES_API_API_IS_NOT_ENABLED",
+                "domain": "filesystem.databricks.com",
+                "metadata": {
+                    "api_name": "Presigned URLs"
+                }
+            },
+            {
+                "@type": "type.googleapis.com/google.rpc.RequestInfo",
+                "request_id": "9ccb2aa8-621e-42f7-a815-828b70653bf6",
+                "serving_data": ""
+            }
+        ]
+    }
+    """
+
+    expired_url_aws_response = (
+        ''
+        "AuthenticationFailedServer failed to authenticate "
+        "the request. 
Make sure the value of Authorization header is formed " + "correctly including the signature.\nRequestId:1abde581-601e-0028-" + "4a6d-5c3952000000\nTime:2025-01-01T16:54:20.5343181ZSignature not valid in the specified " + "time frame: Start [Wed, 01 Jan 2025 16:38:41 GMT] - Expiry [Wed, " + "01 Jan 2025 16:53:45 GMT] - Current [Wed, 01 Jan 2025 16:54:20 " + "GMT]" + ) + + expired_url_azure_response: str = ( + '\nAccessDenied' + "Request has expired" + "142025-01-01T17:47:13Z" + "2025-01-01T17:48:01Z" + "JY66KDXM4CXBZ7X2n8Qayqg60rbvut9P7pk0" + "" + ) + + def __init__( + self, + name: str, + file_size: int, + custom_response_get_file_status_api: Optional[CustomResponse] = CustomResponse(enabled=False), + custom_response_create_presigned_url: Optional[CustomResponse] = CustomResponse(enabled=False), + custom_response_download_from_url: Optional[CustomResponse] = CustomResponse(enabled=False), + custom_response_download_from_files_api: Optional[CustomResponse] = CustomResponse(enabled=False), + download_mode: Optional[Union[DownloadMode, List[DownloadMode]]] = None, + overwrite: bool = True, + use_parallel: Optional[Union[bool, List[bool]]] = None, + parallelism: Optional[int] = None, + parallel_download_min_file_size: Optional[int] = None, + parallel_upload_part_size: Optional[int] = None, + expected_exception_type: Optional[Type[BaseException]] = None, + expected_download_api: Optional[str] = None, + ): + # Metadata + self.name = name + self.file_size = file_size + self.last_modified = "Thu, 28 Nov 2024 16:39:14 GMT" + + # Function stubs to customize responses for various API calls + self.custom_response_get_file_status_api = custom_response_get_file_status_api + self.custom_response_create_presigned_url = custom_response_create_presigned_url + self.custom_response_download_from_url = custom_response_download_from_url + self.custom_response_download_from_files_api = custom_response_download_from_files_api + + # Parameters + # - By default, we would like to test both download modes for every test case. + if download_mode is None: + self.download_mode = [DownloadMode.STREAM, DownloadMode.FILE] + elif not isinstance(download_mode, list): + self.download_mode = [download_mode] + else: + self.download_mode = download_mode + self.overwrite = overwrite + # - By default, we would like to test both parallel and non-parallel downloads for every test case. 
+ if use_parallel is None: + self.use_parallel = [False, True] + elif not isinstance(use_parallel, list): + self.use_parallel = [use_parallel] + else: + self.use_parallel = use_parallel + self.parallelism = parallelism + + # Expected results + self.expected_exception_type = expected_exception_type + self.expected_download_api = expected_download_api + + # Config overrides + self.parallel_download_min_file_size = parallel_download_min_file_size + self.parallel_upload_part_size = parallel_upload_part_size + + def _clear_state(self): + self.custom_response_get_file_status_api.clear_state() + self.custom_response_create_presigned_url.clear_state() + self.custom_response_download_from_files_api.clear_state() + self.custom_response_download_from_url.clear_state() + + def __str__(self) -> str: + return self.name + + @staticmethod + def to_string(test_case) -> str: + return str(test_case) + + def match_request_to_response( + self, request: requests.Request, server_state: PresignedUrlDownloadServerState + ) -> Optional[requests.Response]: + """Match the request to the server state and return a mock response.""" + request_url = urlparse(request.url) + request_query = parse_qs(request_url.query) + + # Create Download URL request + if ( + request_url.hostname == "localhost" + and request.method == "POST" + and request_url.path == "/api/2.0/fs/create-download-url" + ): + assert "path" in request_query, "Expected 'path' in query parameters" + file_path = request_query.get("path")[0] + + def processor() -> list: + url = server_state.get_presigned_url(file_path) + return [200, json.dumps({"url": url, "headers": {}}), {}] + + return self.custom_response_create_presigned_url.generate_response(request, processor) + + # Get files status request + elif ( + request_url.hostname == "localhost" + and request.method == "HEAD" + and request_url.path == f"/api/2.0/fs/files{self._FILE_PATH}" + ): + # HEAD request to check if file exists + def processor() -> list: + resp = server_state.get_header(request) + return [resp.status_code, resp._content, resp.headers] + + return self.custom_response_get_file_status_api.generate_response(request, processor, stream=True) + + # Direct Files API download request + elif ( + request_url.hostname == "localhost" + and request.method == "GET" + and request_url.path == f"/api/2.0/fs/files{self._FILE_PATH}" + ): + + def processor() -> list: + resp = server_state.get_content(request, api_used="files_api") + return [resp.status_code, resp._content, resp.headers] + + return self.custom_response_download_from_files_api.generate_response(request, processor, stream=True) + + # Download from Presigned URL request + elif request_url.hostname == PresignedUrlDownloadServerState.HOSTNAME and request.method == "GET": + logger.debug(f"headers = {request.headers}") + + def processor() -> list: + resp = server_state.get_content(request, api_used="presigned_url") + return [resp.status_code, resp._content, resp.headers] + + return self.custom_response_download_from_url.generate_response(request, processor, stream=True) + + else: + raise RuntimeError("Unexpected request " + str(request)) + + def run_one_case(self, config: Config, monkeypatch, download_mode: DownloadMode, use_parallel: bool) -> None: + if use_parallel and platform.system() == "Windows": + logger.debug("Parallel download is not supported on Windows. 
Falling back to sequential download.") + return + config = config.copy() + config.enable_presigned_download_api = True + config._clock = FakeClock() + if self.parallel_download_min_file_size is not None: + config.files_ext_parallel_download_min_file_size = self.parallel_download_min_file_size + if self.parallel_upload_part_size is not None: + config.files_ext_parallel_upload_part_size = self.parallel_upload_part_size + + w = WorkspaceClient(config=config) + state = PresignedUrlDownloadServerState(self.file_size, self.last_modified) + + with requests_mock.Mocker() as session_mock: + + def custom_matcher(request: requests.Request) -> Optional[requests.Response]: + """Custom matcher to handle requests and return mock responses.""" + return self.match_request_to_response(request, state) + + session_mock.add_matcher(custom_matcher) + + if download_mode == DownloadMode.STREAM: + if self.expected_exception_type is not None: + with pytest.raises(self.expected_exception_type): + w.files.download(PresignedUrlDownloadTestCase._FILE_PATH) + else: + download_resp = w.files.download(PresignedUrlDownloadTestCase._FILE_PATH) + assert download_resp.content_length == self.file_size + assert download_resp.contents.read() == state.content + if self.expected_download_api is not None: + assert state.api_used == self.expected_download_api + elif download_mode == DownloadMode.FILE: + with NamedTemporaryFile(delete=False) as temp_file: + file_path = temp_file.name + + # We can't use 'with' because Windows doesn't allow reopening the file, and `download_to` can open it + # only after we close it here. + try: + if self.expected_exception_type is not None: + with pytest.raises(self.expected_exception_type): + w.files.download_to( + PresignedUrlDownloadTestCase._FILE_PATH, + file_path, + overwrite=self.overwrite, + use_parallel=use_parallel, + parallelism=self.parallelism, + ) + else: + w.files.download_to( + PresignedUrlDownloadTestCase._FILE_PATH, + file_path, + overwrite=self.overwrite, + use_parallel=use_parallel, + parallelism=self.parallelism, + ) + with open(file_path, "rb") as f: + actual_content = f.read() + assert len(actual_content) == len(state.content) + assert actual_content == state.content + if self.expected_download_api is not None: + assert state.api_used == self.expected_download_api + finally: + if os.path.exists(file_path): + os.remove(file_path) + else: + raise ValueError("Unexpected download mode") + + def run(self, config: Config, monkeypatch) -> None: + # Run all combinations of download modes and parallelism settings + for download_mode in self.download_mode: + for use_parallel in self.use_parallel: + logger.info(f"Downloading {download_mode.name} with parallelism={use_parallel}") + self.run_one_case(config, monkeypatch, download_mode, use_parallel) + self._clear_state() + + +@pytest.mark.parametrize( + "test_case", + [ + # Happy cases + PresignedUrlDownloadTestCase( + name="Presigned URL download succeeds", + file_size=100 * 1024 * 1024, + ), + PresignedUrlDownloadTestCase( + name="Presigned URL download to File succeeds", + file_size=100 * 1024 * 1024, + download_mode=DownloadMode.FILE, + ), + PresignedUrlDownloadTestCase( + name="Presigned URL download to File in parallel succeeds", + file_size=100 * 1024 * 1024, + download_mode=DownloadMode.FILE, + use_parallel=True, + parallelism=2, + ), + # Sad cases + PresignedUrlDownloadTestCase( + name="Presigned URL download fails with 403", + file_size=100 * 1024 * 1024, + expected_exception_type=PermissionDenied, + 
custom_response_create_presigned_url=CustomResponse(code=403, only_invocation=1), + ), + PresignedUrlDownloadTestCase( + name="Presigned URL download fails with 500 when creating presigned URL", + file_size=100 * 1024 * 1024, + expected_exception_type=InternalError, + custom_response_create_presigned_url=CustomResponse(code=500, only_invocation=1), + ), + PresignedUrlDownloadTestCase( + name="Presigned URL download fails with 500 when downloading from URL", + file_size=100 * 1024 * 1024, + expected_exception_type=TimeoutError, # TimeoutError is raised after retries are exhausted + custom_response_download_from_url=CustomResponse(code=500), + ), + PresignedUrlDownloadTestCase( + name="Intermittent error fails after retry: Presigned URL expires with 403 when downloading from URL", + file_size=100 * 1024 * 1024, + expected_exception_type=ValueError, + custom_response_download_from_url=CustomResponse( + code=403, body=PresignedUrlDownloadTestCase.expired_url_aws_response + ), + expected_download_api="presigned_url", + use_parallel=True, + download_mode=DownloadMode.FILE, + ), + # Recoverable errors + PresignedUrlDownloadTestCase( + name="Intermittent error should succeed after retry: Presigned URL download fails with 500 when downloading from URL", + file_size=100 * 1024 * 1024, + custom_response_download_from_url=CustomResponse(code=500, only_invocation=1), + ), + PresignedUrlDownloadTestCase( + name="Intermittent error should succeed after retry: Presigned URL expires with 403 when downloading from URL", + file_size=100 * 1024 * 1024, + custom_response_download_from_url=CustomResponse( + code=403, + first_invocation=2, + last_invocation=4, + body=PresignedUrlDownloadTestCase.expired_url_aws_response, + ), + ), + # Test fallback to Files API + PresignedUrlDownloadTestCase( + name="Presigned URL is disabled, should fallback to Files API", + file_size=100 * 1024 * 1024, + expected_download_api="files_api", + custom_response_create_presigned_url=CustomResponse( + code=403, only_invocation=1, body=PresignedUrlDownloadTestCase.presigned_url_disabled_response + ), + ), + PresignedUrlDownloadTestCase( + name="Presigned URL fails with 403 when downloading, should fallback to Files API", + file_size=100 * 1024 * 1024, + expected_download_api="files_api", + custom_response_download_from_url=CustomResponse(code=403, only_invocation=1), + ), + ], + ids=PresignedUrlDownloadTestCase.to_string, +) +def test_presigned_url_download(config: Config, test_case: PresignedUrlDownloadTestCase, monkeypatch) -> None: + test_case.run(config, monkeypatch) class FileContent: @@ -435,183 +1138,88 @@ class MultipartUploadServerState: upload_part_url_prefix = "https://cloud_provider.com/upload-part/" abort_upload_url_prefix = "https://cloud_provider.com/abort-upload/" - def __init__(self): + def __init__(self, expected_part_size: Optional[int] = None): self.issued_multipart_urls = {} # part_number -> expiration_time self.uploaded_parts = {} # part_number -> [part file path, etag] self.session_token = "token-" + MultipartUploadServerState.randomstr() self.file_content = None self.issued_abort_url_expire_time = None self.aborted = False + self.expected_part_size = expected_part_size + self.global_lock = Lock() def create_upload_part_url(self, path: str, part_number: int, expire_time: datetime) -> str: - assert not self.aborted - # client may have requested a URL for the same part if retrying on network error - self.issued_multipart_urls[part_number] = expire_time - return f"{self.upload_part_url_prefix}{path}/{part_number}" 
- - def create_abort_url(self, path: str, expire_time: datetime) -> str: - assert not self.aborted - self.issued_abort_url_expire_time = expire_time - return f"{self.abort_upload_url_prefix}{path}" - - def save_part(self, part_number: int, part_content: bytes, etag: str) -> None: - assert not self.aborted - assert len(part_content) > 0 - - logger.info(f"Saving part {part_number} of size {len(part_content)}") - - # part might already have been uploaded - existing_part = self.uploaded_parts.get(part_number) - if existing_part: - part_file = existing_part[0] - with open(part_file, "wb") as f: # overwrite - f.write(part_content) - else: - fd, part_file = mkstemp() - with open(fd, "wb") as f: - f.write(part_content) - - self.uploaded_parts[part_number] = [part_file, etag] - - def cleanup(self) -> None: - for [file, _] in self.uploaded_parts.values(): - os.remove(file) - - def get_file_content(self) -> Optional[FileContent]: - if self.aborted: - assert not self.file_content - - # content may be None even for a non-aborted upload, - # in case single-shot upload was used due to small stream size. - return self.file_content - - def upload_complete(self, etags: dict) -> None: - assert not self.aborted - # validate etags - expected_etags = {} - for part_number in self.uploaded_parts.keys(): - expected_etags[part_number] = self.uploaded_parts[part_number][1] - assert etags == expected_etags - - size = 0 - sha256 = hashlib.sha256() - - sorted_parts = sorted(self.uploaded_parts.keys()) - for part_number in sorted_parts: - [part_path, _] = self.uploaded_parts[part_number] - size += os.path.getsize(part_path) - with open(part_path, "rb") as f: - part_content = f.read() - sha256.update(part_content) - - self.file_content = FileContent(size, sha256.hexdigest()) - - def abort_upload(self) -> None: - self.aborted = True - - @staticmethod - def randomstr() -> str: - return f"{random.randrange(10000)}-{int(time.time())}" - - -class CustomResponse: - """Custom response allows to override the "default" response generated by the server - with the "custom" response to simulate failure error code, unexpected response body or - network error. - - The server is represented by the `processor` parameter in `generate_response()` call. - """ - - def __init__( - self, - # If False, default response is always returned. 
- # If True, response is defined by the current invocation count - # with respect to first_invocation / last_invocation / only_invocation - enabled: bool = True, - # Custom code to return - code: Optional[int] = 200, - # Custom body to return - body: Optional[str] = None, - # Custom exception to raise - exception: Optional[Type[BaseException]] = None, - # Whether exception should be raised before calling processor() - # (so changing server state) - exception_happened_before_processing: bool = False, - # First invocation (1-based) at which return custom response - first_invocation: Optional[int] = None, - # Last invocation (1-based) at which return custom response - last_invocation: Optional[int] = None, - # Only invocation (1-based) at which return custom response - only_invocation: Optional[int] = None, - ): - self.enabled = enabled - self.code = code - self.body = body - self.exception = exception - self.exception_happened_before_processing = exception_happened_before_processing - self.first_invocation = first_invocation - self.last_invocation = last_invocation - self.only_invocation = only_invocation - - if self.only_invocation and (self.first_invocation or self.last_invocation): - raise ValueError("Cannot set both only invocation and first/last invocation") - - if self.exception_happened_before_processing and not self.exception: - raise ValueError("Exception is not defined") - - self.invocation_count = 0 + assert not self.aborted + # client may have requested a URL for the same part if retrying on network error + self.issued_multipart_urls[part_number] = expire_time + return f"{self.upload_part_url_prefix}{path}/{part_number}" - def invocation_matches(self) -> bool: - if not self.enabled: - return False + def create_abort_url(self, path: str, expire_time: datetime) -> str: + assert not self.aborted + self.issued_abort_url_expire_time = expire_time + return f"{self.abort_upload_url_prefix}{path}" - self.invocation_count += 1 + def save_part(self, part_number: int, part_content: bytes, etag: str) -> None: + assert not self.aborted + assert len(part_content) > 0 + if self.expected_part_size is not None: + assert len(part_content) <= self.expected_part_size - if self.only_invocation: - return self.invocation_count == self.only_invocation + logger.info(f"Saving part {part_number} of size {len(part_content)}") - if self.first_invocation and self.invocation_count < self.first_invocation: - return False - if self.last_invocation and self.invocation_count > self.last_invocation: - return False - return True + # part might already have been uploaded + with self.global_lock: + if part_number not in self.uploaded_parts: + with NamedTemporaryFile(mode="wb", delete=False) as f: + part_file = f.name + self.uploaded_parts[part_number] = [part_file, etag, Lock()] + existing_part = self.uploaded_parts[part_number] + with existing_part[2]: # lock per part + part_file = existing_part[0] + with open(part_file, "wb") as f: # overwrite + f.write(part_content) + existing_part[1] = etag # update etag - def generate_response(self, request: requests.Request, processor: Callable[[], list]) -> requests.Response: - activate_for_current_invocation = self.invocation_matches() + def cleanup(self) -> None: + for [file, _, _] in self.uploaded_parts.values(): + os.remove(file) - if activate_for_current_invocation and self.exception and self.exception_happened_before_processing: - # if network exception is thrown while processing a request, it's not defined - # if server actually processed the request (and so changed its 
state) - raise self.exception + def get_file_content(self) -> Optional[FileContent]: + if self.aborted: + assert not self.file_content, "File content should not be set if upload was aborted" - custom_response = [self.code, self.body or "", {}] + # content may be None even for a non-aborted upload, + # in case single-shot upload was used due to small stream size. + return self.file_content - if activate_for_current_invocation: - if self.code and 400 <= self.code < 500: - # if server returns client error, it's not supposed to change its state, - # so we're not calling processor() - [code, body, headers] = custom_response - else: - # we're calling processor() but override its response with the custom one - processor() - [code, body, headers] = custom_response - else: - [code, body, headers] = processor() + def upload_complete(self, etags: dict) -> None: + assert not self.aborted + # validate etags + expected_etags = {} + with self.global_lock: + for part_number in self.uploaded_parts.keys(): + expected_etags[part_number] = self.uploaded_parts[part_number][1] + assert etags == expected_etags - if activate_for_current_invocation and self.exception: - # self.exception_happened_before_processing is False - raise self.exception + size = 0 + sha256 = hashlib.sha256() - resp = requests.Response() + sorted_parts = sorted(self.uploaded_parts.keys()) + for part_number in sorted_parts: + part_path = self.uploaded_parts[part_number][0] + size += os.path.getsize(part_path) + with open(part_path, "rb") as f: + part_content = f.read() + sha256.update(part_content) - resp.request = request - resp.status_code = code - resp._content = body.encode() + self.file_content = FileContent(size, sha256.hexdigest()) - for key in headers: - resp.headers[key] = headers[key] + def abort_upload(self) -> None: + self.aborted = True - return resp + @staticmethod + def randomstr() -> str: + return f"{random.randrange(10000)}-{int(time.time())}" class SingleShotUploadServerState: @@ -637,9 +1245,13 @@ def __init__( self, name: str, stream_size: int, + cloud: Cloud, overwrite: bool, + source_type: List["UploadSourceType"], + use_parallel: List[bool], + parallelism: Optional[int], multipart_upload_min_stream_size: int, - multipart_upload_chunk_size: Optional[int], + multipart_upload_part_size: Optional[int], sdk_retry_timeout_seconds: Optional[int], multipart_upload_max_retries: Optional[int], custom_response_on_single_shot_upload: CustomResponse, @@ -651,9 +1263,13 @@ def __init__( ): self.name = name self.stream_size = stream_size + self.cloud = cloud self.overwrite = overwrite + self.source_type = source_type + self.use_parallel = use_parallel + self.parallelism = parallelism self.multipart_upload_min_stream_size = multipart_upload_min_stream_size - self.multipart_upload_chunk_size = multipart_upload_chunk_size + self.multipart_upload_part_size = multipart_upload_part_size self.sdk_retry_timeout_seconds = sdk_retry_timeout_seconds self.multipart_upload_max_retries = multipart_upload_max_retries self.custom_response_on_single_shot_upload = custom_response_on_single_shot_upload @@ -662,6 +1278,7 @@ def __init__( self.expected_single_shot_upload = expected_single_shot_upload self.path = "/test.txt" + self.created_temp_files = [] def customize_config(self, config: Config) -> None: pass @@ -669,23 +1286,54 @@ def customize_config(self, config: Config) -> None: def create_multipart_upload_server_state(self) -> Union[MultipartUploadServerState, "ResumableUploadServerState"]: raise NotImplementedError + def clear_state(self) -> None: 
+ for file_path in self.created_temp_files: + try: + os.remove(file_path) + except OSError: + logger.warning("Failed to remove temp file: %s", file_path) + self.created_temp_files = [] + + def get_upload_file(self, content: bytes, source_type: "UploadSourceType") -> Union[str, io.BytesIO]: + """Returns a file or stream to upload based on the source type.""" + if source_type == UploadSourceType.FILE: + with NamedTemporaryFile(mode="wb", delete=False) as f: + f.write(content) + file_path = f.name + self.created_temp_files.append(file_path) + return file_path + elif source_type == UploadSourceType.STREAM: + return io.BytesIO(content) + else: + raise ValueError(f"Unknown source type: {source_type}") + def match_request_to_response( self, request: requests.Request, server_state: Union[MultipartUploadServerState, "ResumableUploadServerState"] ) -> Optional[requests.Response]: raise NotImplementedError def run(self, config: Config) -> None: + for source_type in self.source_type: + for use_parallel in self.use_parallel: + self.run_one_case(config, use_parallel, source_type) + + def run_one_case(self, config: Config, use_parallel: bool, source_type: "UploadSourceType") -> None: + + logger.debug(f"Running test case: {self.name}, source_type={source_type}, use_parallel={use_parallel}") config = config.copy() - config.enable_experimental_files_api_client = True + config._clock = FakeClock() + + if self.cloud: + config.databricks_environment = DatabricksEnvironment(self.cloud, "") if self.sdk_retry_timeout_seconds: config.retry_timeout_seconds = self.sdk_retry_timeout_seconds - if self.multipart_upload_chunk_size: - config.multipart_upload_chunk_size = self.multipart_upload_chunk_size + if self.multipart_upload_part_size: + config.multipart_upload_part_size = self.multipart_upload_part_size if self.multipart_upload_max_retries: - config.multipart_upload_max_retries = self.multipart_upload_max_retries + config.files_ext_multipart_upload_max_retries = self.multipart_upload_max_retries - config.multipart_upload_min_stream_size = self.multipart_upload_min_stream_size + config.files_ext_multipart_upload_min_stream_size = self.multipart_upload_min_stream_size pat_token = "some_pat_token" config._header_factory = lambda: {"Authorization": f"Bearer {pat_token}"} @@ -695,7 +1343,8 @@ def run(self, config: Config) -> None: multipart_server_state = self.create_multipart_upload_server_state() single_shot_server_state = SingleShotUploadServerState() - file_content = os.urandom(self.stream_size) + file_content = fast_random_bytes(self.stream_size) + content_or_source = self.get_upload_file(file_content, source_type) w = WorkspaceClient(config=config) try: @@ -724,32 +1373,69 @@ def processor() -> list: session_mock.add_matcher(matcher=custom_matcher) def upload() -> None: - w.files.upload(self.path, io.BytesIO(file_content), overwrite=self.overwrite) + if source_type == UploadSourceType.FILE: + w.files.upload_from( + self.path, + content_or_source, + overwrite=self.overwrite, + part_size=self.multipart_upload_part_size, + use_parallel=use_parallel, + parallelism=self.parallelism, + ) + else: + w.files.upload( + self.path, + content_or_source, + overwrite=self.overwrite, + part_size=self.multipart_upload_part_size, + use_parallel=use_parallel, + parallelism=self.parallelism, + ) if self.expected_exception_type is not None: with pytest.raises(self.expected_exception_type): upload() - assert not single_shot_server_state.get_file_content() - assert not multipart_server_state.get_file_content() + assert ( + not 
single_shot_server_state.get_file_content() + ), "Single-shot upload should not have succeeded" + assert not multipart_server_state.get_file_content(), "Multipart upload should not have succeeded" else: upload() if self.expected_single_shot_upload: - assert single_shot_server_state.get_file_content() == FileContent.from_bytes(file_content) - assert not multipart_server_state.get_file_content() + assert single_shot_server_state.get_file_content() == FileContent.from_bytes( + file_content + ), "Single-shot upload should have succeeded" + assert ( + not multipart_server_state.get_file_content() + ), "Multipart upload should not have succeeded" else: - assert multipart_server_state.get_file_content() == FileContent.from_bytes(file_content) - assert not single_shot_server_state.get_file_content() + assert multipart_server_state.get_file_content() == FileContent.from_bytes( + file_content + ), "Multipart upload should have succeeded" + assert ( + not single_shot_server_state.get_file_content() + ), "Single-shot upload should not have succeeded" - assert multipart_server_state.aborted == self.expected_multipart_upload_aborted + assert ( + multipart_server_state.aborted == self.expected_multipart_upload_aborted + ), "Multipart upload aborted state mismatch" finally: multipart_server_state.cleanup() + self.clear_state() @staticmethod def is_auth_header_present(r: requests.Request) -> bool: return r.headers.get("Authorization") is not None +class UploadSourceType(Enum): + """Source type for the upload. Used to determine how to upload the file.""" + + FILE = "file" # upload from a file on disk + STREAM = "stream" # upload from a stream (e.g. BytesIO) + + class MultipartUploadTestCase(UploadTestCase): """Test case for multipart upload of a file. Multipart uploads are used on AWS and Azure. @@ -767,7 +1453,7 @@ class MultipartUploadTestCase(UploadTestCase): Response of each call can be modified by parameterising a respective `CustomResponse` object. """ - expired_url_aws_response: str = ( + expired_url_aws_response = ( '' "AuthenticationFailedServer failed to authenticate " "the request. 
Make sure the value of Authorization header is formed " @@ -788,13 +1474,39 @@ class MultipartUploadTestCase(UploadTestCase): "" ) + presigned_url_disabled_response = """ + { + "error_code": "PERMISSION_DENIED", + "message": "Presigned URLs API is not enabled", + "details": [ + { + "@type": "type.googleapis.com/google.rpc.ErrorInfo", + "reason": "FILES_API_API_IS_NOT_ENABLED", + "domain": "filesystem.databricks.com", + "metadata": { + "api_name": "Presigned URLs" + } + }, + { + "@type": "type.googleapis.com/google.rpc.RequestInfo", + "request_id": "9ccb2aa8-621e-42f7-a815-828b70653bf6", + "serving_data": "" + } + ] + } + """ + def __init__( self, name: str, - stream_size: int, # size of uploaded file or, technically, stream - overwrite: bool = True, # TODO test for overwrite = false + content_size: int, # size of uploaded file or, technically, stream + cloud: Cloud = Cloud.AWS, + overwrite: bool = True, multipart_upload_min_stream_size: int = 0, # disable single-shot uploads by default - multipart_upload_chunk_size: Optional[int] = None, + source_type: Optional[List[UploadSourceType]] = None, + use_parallel: Optional[List[bool]] = None, + parallelism: Optional[int] = None, + multipart_upload_part_size: Optional[int] = None, sdk_retry_timeout_seconds: Optional[int] = None, multipart_upload_max_retries: Optional[int] = None, multipart_upload_batch_url_count: Optional[int] = None, @@ -808,15 +1520,21 @@ def __init__( # exception which is expected to be thrown (so upload is expected to have failed) expected_exception_type: Optional[Type[BaseException]] = None, # if abort is expected to be called + # expected part size + expected_part_size: Optional[int] = None, expected_multipart_upload_aborted: bool = False, expected_single_shot_upload: bool = False, ): super().__init__( name, - stream_size, + content_size, + cloud, overwrite, + source_type or [UploadSourceType.FILE, UploadSourceType.STREAM], + use_parallel or [False, True], + parallelism, multipart_upload_min_stream_size, - multipart_upload_chunk_size, + multipart_upload_part_size, sdk_retry_timeout_seconds, multipart_upload_max_retries, custom_response_on_single_shot_upload, @@ -832,13 +1550,24 @@ def __init__( self.custom_response_on_complete = copy.deepcopy(custom_response_on_complete) self.custom_response_on_create_abort_url = copy.deepcopy(custom_response_on_create_abort_url) self.custom_response_on_abort = copy.deepcopy(custom_response_on_abort) + self.expected_exception_type = expected_exception_type + self.expected_part_size = expected_part_size def customize_config(self, config: Config) -> None: if self.multipart_upload_batch_url_count: - config.multipart_upload_batch_url_count = self.multipart_upload_batch_url_count + config.files_ext_multipart_upload_batch_url_count = self.multipart_upload_batch_url_count def create_multipart_upload_server_state(self) -> MultipartUploadServerState: - return MultipartUploadServerState() + return MultipartUploadServerState(self.expected_part_size) + + def clear_state(self) -> None: + super().clear_state() + self.custom_response_on_initiate.clear_state() + self.custom_response_on_create_multipart_url.clear_state() + self.custom_response_on_upload.clear_state() + self.custom_response_on_complete.clear_state() + self.custom_response_on_create_abort_url.clear_state() + self.custom_response_on_abort.clear_state() def match_request_to_response( self, request: requests.Request, server_state: MultipartUploadServerState @@ -970,6 +1699,21 @@ def processor() -> list: return 
self.custom_response_on_abort.generate_response(request, processor) + # direct upload (single-shot upload) + elif ( + request_url.hostname == "localhost" + and request_url.path == f"/api/2.0/fs/files{self.path}" + and request.method == "PUT" + ): + assert MultipartUploadTestCase.is_auth_header_present(request) + assert request.content is not None + + def processor(): + server_state.file_content = FileContent.from_bytes(request.content) + return [200, "", {}] + + return self.custom_response_on_upload.generate_response(request, processor) + return None @staticmethod @@ -992,10 +1736,50 @@ def to_string(test_case: "MultipartUploadTestCase") -> str: @pytest.mark.parametrize( "test_case", [ + # -------------------------- happy cases -------------------------- + MultipartUploadTestCase( + "Multipart upload successful: single part", + content_size=1024 * 1024, # less than part size + multipart_upload_part_size=10 * 1024 * 1024, + expected_part_size=1024 * 1024, # chunk size is used + ), + MultipartUploadTestCase( + "Multipart upload successful: multiple parts (aligned)", + content_size=100 * 1024 * 1024, # 10 parts + multipart_upload_part_size=10 * 1024 * 1024, + expected_part_size=10 * 1024 * 1024, # chunk size is used + ), + MultipartUploadTestCase( + "Multipart upload successful: multiple parts (aligned), upload urls by 3", + multipart_upload_batch_url_count=3, + content_size=100 * 1024 * 1024, # 10 parts + multipart_upload_part_size=10 * 1024 * 1024, + expected_part_size=10 * 1024 * 1024, # chunk size is used + ), + MultipartUploadTestCase( + "Multipart upload successful: multiple chunks (not aligned), upload urls by 1", + content_size=100 * 1024 * 1024 + 1566, # 14 full chunks + remainder + multipart_upload_part_size=7 * 1024 * 1024 - 17, + expected_part_size=7 * 1024 * 1024 - 17, # chunk size is used + ), + MultipartUploadTestCase( + "Multipart upload successful: multiple parts (not aligned), upload urls by 5", + multipart_upload_batch_url_count=5, + content_size=100 * 1024 * 1024 + 1566, # 14 full parts + remainder + multipart_upload_part_size=7 * 1024 * 1024 - 17, + ), + MultipartUploadTestCase( + "Small stream, single-shot upload used", + content_size=1024 * 1024, + multipart_upload_min_stream_size=1024 * 1024 + 1, + expected_multipart_upload_aborted=False, + expected_single_shot_upload=True, + use_parallel=[False], + ), # -------------------------- failures on "initiate upload" -------------------------- MultipartUploadTestCase( "Initiate: 400 response is not retried", - stream_size=1024 * 1024, + content_size=1024 * 1024, multipart_upload_min_stream_size=1024 * 1024, # still multipart upload is used custom_response_on_initiate=CustomResponse(code=400, only_invocation=1), expected_exception_type=BadRequest, @@ -1003,35 +1787,35 @@ def to_string(test_case: "MultipartUploadTestCase") -> str: ), MultipartUploadTestCase( "Initiate: 403 response is not retried", - stream_size=1024 * 1024, + content_size=1024 * 1024, custom_response_on_initiate=CustomResponse(code=403, only_invocation=1), expected_exception_type=PermissionDenied, expected_multipart_upload_aborted=False, # upload didn't start ), MultipartUploadTestCase( "Initiate: 500 response is not retried", - stream_size=1024 * 1024, + content_size=1024 * 1024, custom_response_on_initiate=CustomResponse(code=500, only_invocation=1), expected_exception_type=InternalError, expected_multipart_upload_aborted=False, # upload didn't start ), MultipartUploadTestCase( "Initiate: non-JSON response is not retried", - stream_size=1024 * 1024, + 
content_size=1024 * 1024, custom_response_on_initiate=CustomResponse(body="this is not a JSON", only_invocation=1), expected_exception_type=requests.exceptions.JSONDecodeError, expected_multipart_upload_aborted=False, # upload didn't start ), MultipartUploadTestCase( "Initiate: meaningless JSON response is not retried", - stream_size=1024 * 1024, + content_size=1024 * 1024, custom_response_on_initiate=CustomResponse(body='{"foo": 123}', only_invocation=1), expected_exception_type=ValueError, expected_multipart_upload_aborted=False, # upload didn't start ), MultipartUploadTestCase( "Initiate: no session token in response is not retried", - stream_size=1024 * 1024, + content_size=1024 * 1024, custom_response_on_initiate=CustomResponse( body='{"multipart_upload":{"session_token1": "token123"}}', only_invocation=1 ), @@ -1040,15 +1824,14 @@ def to_string(test_case: "MultipartUploadTestCase") -> str: ), MultipartUploadTestCase( "Initiate: permanent retryable exception", - stream_size=1024 * 1024, + content_size=1024 * 1024, custom_response_on_initiate=CustomResponse(exception=requests.ConnectionError), - sdk_retry_timeout_seconds=30, # let's not wait 5 min (SDK default timeout) expected_exception_type=TimeoutError, # SDK throws this if retries are taking too long expected_multipart_upload_aborted=False, # upload didn't start ), MultipartUploadTestCase( "Initiate: intermittent retryable exception", - stream_size=1024 * 1024, + content_size=1024 * 1024, custom_response_on_initiate=CustomResponse( exception=requests.ConnectionError, # 3 calls fail, but request is successfully retried @@ -1059,7 +1842,7 @@ def to_string(test_case: "MultipartUploadTestCase") -> str: ), MultipartUploadTestCase( "Initiate: intermittent retryable status code", - stream_size=1024 * 1024, + content_size=1024 * 1024, custom_response_on_initiate=CustomResponse( code=429, # 3 calls fail, then retry succeeds @@ -1071,7 +1854,7 @@ def to_string(test_case: "MultipartUploadTestCase") -> str: # -------------------------- failures on "create upload URL" -------------------------- MultipartUploadTestCase( "Create upload URL: 400 response is not retried", - stream_size=1024 * 1024, + content_size=1024 * 1024, custom_response_on_create_multipart_url=CustomResponse( code=400, # 1 failure is enough @@ -1082,52 +1865,62 @@ def to_string(test_case: "MultipartUploadTestCase") -> str: ), MultipartUploadTestCase( "Create upload URL: 403 response is not retried", - stream_size=1024 * 1024, + content_size=1024 * 1024, + custom_response_on_create_multipart_url=CustomResponse( + code=403, + # 1 failure is enough + only_invocation=1, + ), + expected_exception_type=PermissionDenied, + expected_multipart_upload_aborted=True, + ), + MultipartUploadTestCase( + "Create upload URL: internal error is not retried", + content_size=1024 * 1024, custom_response_on_create_multipart_url=CustomResponse(code=500, only_invocation=1), expected_exception_type=InternalError, expected_multipart_upload_aborted=True, ), MultipartUploadTestCase( "Create upload URL: non-JSON response is not retried", - stream_size=1024 * 1024, + content_size=1024 * 1024, custom_response_on_create_multipart_url=CustomResponse(body="this is not a JSON", only_invocation=1), expected_exception_type=requests.exceptions.JSONDecodeError, expected_multipart_upload_aborted=True, ), MultipartUploadTestCase( "Create upload URL: meaningless JSON response is not retried", - stream_size=1024 * 1024, + content_size=1024 * 1024, 
custom_response_on_create_multipart_url=CustomResponse(body='{"foo":123}', only_invocation=1), expected_exception_type=ValueError, expected_multipart_upload_aborted=True, ), MultipartUploadTestCase( "Create upload URL: meaningless JSON response is not retried 2", - stream_size=1024 * 1024, + content_size=1024 * 1024, custom_response_on_create_multipart_url=CustomResponse(body='{"upload_part_urls":[]}', only_invocation=1), expected_exception_type=ValueError, expected_multipart_upload_aborted=True, ), MultipartUploadTestCase( "Create upload URL: meaningless JSON response is not retried 3", - stream_size=1024 * 1024, + content_size=1024 * 1024, custom_response_on_create_multipart_url=CustomResponse( body='{"upload_part_urls":[{"url":""}]}', only_invocation=1 ), - expected_exception_type=KeyError, # TODO we might want to make JSON parsing more reliable + expected_exception_type=KeyError, expected_multipart_upload_aborted=True, ), MultipartUploadTestCase( "Create upload URL: permanent retryable exception", - stream_size=1024 * 1024, + content_size=1024 * 1024, custom_response_on_create_multipart_url=CustomResponse(exception=requests.ConnectionError), - sdk_retry_timeout_seconds=30, # don't wait for 5 min (SDK default timeout) expected_exception_type=TimeoutError, expected_multipart_upload_aborted=True, ), MultipartUploadTestCase( "Create upload URL: intermittent retryable exception", - stream_size=1024 * 1024, + content_size=1024 * 1024, custom_response_on_create_multipart_url=CustomResponse( exception=requests.Timeout, # happens only once, retry succeeds @@ -1137,8 +1930,8 @@ def to_string(test_case: "MultipartUploadTestCase") -> str: ), MultipartUploadTestCase( "Create upload URL: intermittent retryable exception 2", - stream_size=100 * 1024 * 1024, # 10 parts - multipart_upload_chunk_size=10 * 1024 * 1024, + content_size=100 * 1024 * 1024, # 10 parts + multipart_upload_part_size=10 * 1024 * 1024, custom_response_on_create_multipart_url=CustomResponse( exception=requests.Timeout, # 4th request for multipart URLs fails 3 times, then retry succeeds @@ -1149,8 +1942,8 @@ def to_string(test_case: "MultipartUploadTestCase") -> str: ), MultipartUploadTestCase( "Create upload URL: intermittent retryable exception 3", - stream_size=1024 * 1024, - multipart_upload_chunk_size=10 * 1024 * 1024, + content_size=1024 * 1024, + multipart_upload_part_size=10 * 1024 * 1024, custom_response_on_create_multipart_url=CustomResponse( code=500, first_invocation=4, @@ -1158,23 +1951,75 @@ def to_string(test_case: "MultipartUploadTestCase") -> str: ), expected_multipart_upload_aborted=False, ), + MultipartUploadTestCase( + "Create upload URL: fallback to single-shot upload when presigned URLs are disabled", + content_size=1024 * 1024, + custom_response_on_create_multipart_url=CustomResponse( + code=403, + body=MultipartUploadTestCase.presigned_url_disabled_response, + # 1 failure is enough + only_invocation=1, + ), + expected_multipart_upload_aborted=True, + expected_single_shot_upload=True, + ), # -------------------------- failures on part upload -------------------------- MultipartUploadTestCase( - "Upload part: 403 response is not retried", - stream_size=100 * 1024 * 1024, # 10 parts - multipart_upload_chunk_size=10 * 1024 * 1024, + "Upload part: 403 response will trigger fallback to single-shot upload on Azure", + content_size=100 * 1024 * 1024, # 10 parts + multipart_upload_part_size=10 * 1024 * 1024, custom_response_on_upload=CustomResponse( code=403, # fail only once only_invocation=1, ), + 
expected_multipart_upload_aborted=True, + expected_single_shot_upload=True, + ), + MultipartUploadTestCase( + "Upload part: 403 response will trigger fallback to single-shot upload on AWS", + cloud=Cloud.AWS, + content_size=100 * 1024 * 1024, # 10 parts + multipart_upload_part_size=10 * 1024 * 1024, + custom_response_on_upload=CustomResponse( + code=403, + # fail only once on the first part + only_invocation=1, + ), + expected_multipart_upload_aborted=True, + expected_single_shot_upload=True, + ), + MultipartUploadTestCase( + "Upload part: fallback to single-shot upload when Azure Firewall denies first part upload", + cloud=Cloud.AZURE, + content_size=100 * 1024 * 1024, # 10 parts + multipart_upload_part_size=10 * 1024 * 1024, + custom_response_on_upload=CustomResponse( + code=403, + # fail only once on the first part + only_invocation=1, + ), + expected_multipart_upload_aborted=True, + expected_single_shot_upload=True, + ), + MultipartUploadTestCase( + "Upload part: 403 response on the second part on Azure causes permission denied", + cloud=Cloud.AZURE, + content_size=100 * 1024 * 1024, # 10 parts + multipart_upload_part_size=10 * 1024 * 1024, + custom_response_on_upload=CustomResponse( + code=403, + # fail only once on the second part + only_invocation=2, + ), expected_exception_type=PermissionDenied, expected_multipart_upload_aborted=True, + use_parallel=[False], # "Second part" is not well-defined when using parallel upload ), MultipartUploadTestCase( "Upload part: 400 response is not retried", - stream_size=100 * 1024 * 1024, # 10 parts - multipart_upload_chunk_size=10 * 1024 * 1024, + content_size=100 * 1024 * 1024, # 10 chunks + multipart_upload_part_size=10 * 1024 * 1024, custom_response_on_upload=CustomResponse( code=400, # fail once, but not on the first part @@ -1185,8 +2030,8 @@ def to_string(test_case: "MultipartUploadTestCase") -> str: ), MultipartUploadTestCase( "Upload part: expired URL is retried on AWS", - stream_size=100 * 1024 * 1024, # 10 parts - multipart_upload_chunk_size=10 * 1024 * 1024, + content_size=100 * 1024 * 1024, # 10 parts + multipart_upload_part_size=10 * 1024 * 1024, custom_response_on_upload=CustomResponse( code=403, body=MultipartUploadTestCase.expired_url_aws_response, only_invocation=2 ), @@ -1195,8 +2040,8 @@ def to_string(test_case: "MultipartUploadTestCase") -> str: MultipartUploadTestCase( "Upload part: expired URL is retried on Azure", multipart_upload_max_retries=3, - stream_size=100 * 1024 * 1024, # 10 parts - multipart_upload_chunk_size=10 * 1024 * 1024, + content_size=100 * 1024 * 1024, # 10 parts + multipart_upload_part_size=10 * 1024 * 1024, custom_response_on_upload=CustomResponse( code=403, body=MultipartUploadTestCase.expired_url_azure_response, @@ -1210,8 +2055,8 @@ def to_string(test_case: "MultipartUploadTestCase") -> str: "Upload part: expired URL is retried on Azure, requesting urls by 6", multipart_upload_max_retries=3, multipart_upload_batch_url_count=6, - stream_size=100 * 1024 * 1024, # 100 chunks - multipart_upload_chunk_size=1 * 1024 * 1024, + content_size=100 * 1024 * 1024, # 100 chunks + multipart_upload_part_size=1 * 1024 * 1024, custom_response_on_upload=CustomResponse( code=403, body=MultipartUploadTestCase.expired_url_azure_response, @@ -1224,8 +2069,8 @@ def to_string(test_case: "MultipartUploadTestCase") -> str: MultipartUploadTestCase( "Upload part: expired URL retry is exhausted", multipart_upload_max_retries=3, - stream_size=100 * 1024 * 1024, # 10 parts - multipart_upload_chunk_size=10 * 1024 * 1024, + 
content_size=100 * 1024 * 1024, # 10 parts + multipart_upload_part_size=10 * 1024 * 1024, custom_response_on_upload=CustomResponse( code=403, body=MultipartUploadTestCase.expired_url_azure_response, @@ -1235,29 +2080,42 @@ def to_string(test_case: "MultipartUploadTestCase") -> str: ), expected_exception_type=ValueError, expected_multipart_upload_aborted=True, + use_parallel=[False], # to make "retry is exhausted" well-defined + ), + MultipartUploadTestCase( + "Upload part in parallel: expired URL retry is exhausted", + multipart_upload_max_retries=3, + content_size=100 * 1024 * 1024, # 10 parts + multipart_upload_part_size=10 * 1024 * 1024, + custom_response_on_upload=CustomResponse( + code=403, + body=MultipartUploadTestCase.expired_url_azure_response, + first_invocation=2, # to exhaust retries for parallel uploading, failure must happen infinitely + ), + expected_exception_type=ValueError, + expected_multipart_upload_aborted=True, + use_parallel=[True], # to make "retry is exhausted" well-defined ), MultipartUploadTestCase( "Upload part: permanent retryable error", - stream_size=100 * 1024 * 1024, # 10 parts - multipart_upload_chunk_size=10 * 1024 * 1024, - sdk_retry_timeout_seconds=30, # don't wait for 5 min (SDK default timeout) + content_size=100 * 1024 * 1024, # 10 parts + multipart_upload_part_size=10 * 1024 * 1024, custom_response_on_upload=CustomResponse(exception=requests.ConnectionError, first_invocation=8), expected_exception_type=TimeoutError, expected_multipart_upload_aborted=True, ), MultipartUploadTestCase( "Upload part: permanent retryable status code", - stream_size=100 * 1024 * 1024, # 10 parts - multipart_upload_chunk_size=10 * 1024 * 1024, - sdk_retry_timeout_seconds=30, # don't wait for 5 min (SDK default timeout) + content_size=100 * 1024 * 1024, # 10 parts + multipart_upload_part_size=10 * 1024 * 1024, custom_response_on_upload=CustomResponse(code=429, first_invocation=8), expected_exception_type=TimeoutError, expected_multipart_upload_aborted=True, ), MultipartUploadTestCase( "Upload part: intermittent retryable error", - stream_size=100 * 1024 * 1024, # 10 parts - multipart_upload_chunk_size=10 * 1024 * 1024, + content_size=100 * 1024 * 1024, # 10 parts + multipart_upload_part_size=10 * 1024 * 1024, custom_response_on_upload=CustomResponse( exception=requests.ConnectionError, first_invocation=2, last_invocation=5 ), @@ -1265,50 +2123,30 @@ def to_string(test_case: "MultipartUploadTestCase") -> str: ), MultipartUploadTestCase( "Upload part: intermittent retryable status code 429", - stream_size=100 * 1024 * 1024, # 10 parts - multipart_upload_chunk_size=10 * 1024 * 1024, + content_size=100 * 1024 * 1024, # 10 parts + multipart_upload_part_size=10 * 1024 * 1024, custom_response_on_upload=CustomResponse(code=429, first_invocation=2, last_invocation=4), expected_multipart_upload_aborted=False, ), MultipartUploadTestCase( "Upload chunk: intermittent retryable status code 500", - stream_size=100 * 1024 * 1024, # 10 chunks - multipart_upload_chunk_size=10 * 1024 * 1024, + content_size=100 * 1024 * 1024, # 10 parts + multipart_upload_part_size=10 * 1024 * 1024, custom_response_on_upload=CustomResponse(code=500, first_invocation=2, last_invocation=4), expected_multipart_upload_aborted=False, ), # -------------------------- failures on abort -------------------------- - MultipartUploadTestCase( - "Abort URL: 403 response", - stream_size=1024 * 1024, - custom_response_on_upload=CustomResponse(code=403, only_invocation=1), - 
custom_response_on_create_abort_url=CustomResponse(code=403), - expected_exception_type=PermissionDenied, # original error - expected_multipart_upload_aborted=False, # server state didn't change to record abort - ), MultipartUploadTestCase( "Abort URL: intermittent retryable error", - stream_size=1024 * 1024, + content_size=1024 * 1024, custom_response_on_create_multipart_url=CustomResponse(code=403, only_invocation=1), custom_response_on_create_abort_url=CustomResponse(code=429, first_invocation=1, last_invocation=3), expected_exception_type=PermissionDenied, # original error expected_multipart_upload_aborted=True, # abort successfully called after abort URL creation is retried ), - MultipartUploadTestCase( - "Abort URL: intermittent retryable error 2", - stream_size=1024 * 1024, - custom_response_on_create_multipart_url=CustomResponse(code=403, only_invocation=1), - custom_response_on_create_abort_url=CustomResponse( - exception=requests.Timeout, first_invocation=1, last_invocation=3 - ), - expected_exception_type=PermissionDenied, # original error - expected_multipart_upload_aborted=True, # abort successfully called after abort URL creation is retried - ), MultipartUploadTestCase( "Abort: exception", - stream_size=1024 * 1024, - # don't wait for 5 min (SDK default timeout) - sdk_retry_timeout_seconds=30, + content_size=1024 * 1024, custom_response_on_create_multipart_url=CustomResponse(code=403, only_invocation=1), custom_response_on_abort=CustomResponse( exception=requests.Timeout, @@ -1318,40 +2156,22 @@ def to_string(test_case: "MultipartUploadTestCase") -> str: expected_exception_type=PermissionDenied, # original error is reported expected_multipart_upload_aborted=True, ), - # -------------------------- happy cases -------------------------- - MultipartUploadTestCase( - "Multipart upload successful: single part", - stream_size=1024 * 1024, # less than part size - multipart_upload_chunk_size=10 * 1024 * 1024, - ), - MultipartUploadTestCase( - "Multipart upload successful: multiple parts (aligned)", - stream_size=100 * 1024 * 1024, # 10 parts - multipart_upload_chunk_size=10 * 1024 * 1024, - ), + # -------------------------- Parallel Upload for Streams -------------------------- MultipartUploadTestCase( - "Multipart upload successful: multiple parts (aligned), upload urls by 3", - multipart_upload_batch_url_count=3, - stream_size=100 * 1024 * 1024, # 10 parts - multipart_upload_chunk_size=10 * 1024 * 1024, - ), - MultipartUploadTestCase( - "Multipart upload successful: multiple parts (not aligned), upload urls by 1", - stream_size=100 * 1024 * 1024 + 1566, # 14 full parts + remainder - multipart_upload_chunk_size=7 * 1024 * 1024 - 17, - ), - MultipartUploadTestCase( - "Multipart upload successful: multiple parts (not aligned), upload urls by 5", - multipart_upload_batch_url_count=5, - stream_size=100 * 1024 * 1024 + 1566, # 14 full parts + remainder - multipart_upload_chunk_size=7 * 1024 * 1024 - 17, - ), - MultipartUploadTestCase( - "Small stream, single-shot upload used", - stream_size=1024 * 1024, - multipart_upload_min_stream_size=1024 * 1024 + 1, - expected_multipart_upload_aborted=False, - expected_single_shot_upload=True, + "Multipart parallel upload for stream: Upload errors are not retried", + content_size=10 * 1024 * 1024, + multipart_upload_part_size=1024 * 1024, + source_type=[UploadSourceType.STREAM], + use_parallel=[True], + parallelism=1, + custom_response_on_upload=CustomResponse( + code=501, + first_invocation=2, + last_invocation=4, + delayed_response_seconds=0.1, + 
), + expected_exception_type=NotImplemented, + expected_multipart_upload_aborted=True, ), ], ids=MultipartUploadTestCase.to_string, @@ -1366,18 +2186,22 @@ class ResumableUploadServerState: resumable_upload_url_prefix = "https://cloud_provider.com/resumable-upload/" abort_upload_url_prefix = "https://cloud_provider.com/abort-upload/" - def __init__(self, unconfirmed_delta: Union[int, list]): + def __init__(self, unconfirmed_delta: Union[int, list], expected_part_size: Optional[int]): self.unconfirmed_delta = unconfirmed_delta self.confirmed_last_byte: Optional[int] = None # inclusive self.uploaded_parts = [] self.session_token = "token-" + MultipartUploadServerState.randomstr() self.file_content: Optional[FileContent] = None self.aborted = False + self.expected_part_size = expected_part_size def save_part(self, start_offset: int, end_offset_incl: int, part_content: bytes, file_size_s: str) -> None: assert not self.aborted assert len(part_content) > 0 + if self.expected_part_size is not None: + assert len(part_content) <= self.expected_part_size + if self.confirmed_last_byte: assert start_offset == self.confirmed_last_byte + 1 else: @@ -1408,8 +2232,8 @@ def save_part(self, start_offset: int, end_offset_incl: int, part_content: bytes if unconfirmed_delta > 0: part_content = part_content[:-unconfirmed_delta] - fd, part_file = mkstemp() - with open(fd, "wb") as f: + with NamedTemporaryFile(mode="wb", delete=False) as f: + part_file = f.name f.write(part_content) self.uploaded_parts.append(part_file) @@ -1468,9 +2292,13 @@ def __init__( self, name: str, stream_size: int, + cloud: Cloud = None, overwrite: bool = True, + source_type: Optional[List[UploadSourceType]] = None, + use_parallel: Optional[List[bool]] = None, + parallelism: Optional[int] = None, multipart_upload_min_stream_size: int = 0, # disable single-shot uploads by default - multipart_upload_chunk_size: Optional[int] = None, + multipart_upload_part_size: Optional[int] = None, sdk_retry_timeout_seconds: Optional[int] = None, multipart_upload_max_retries: Optional[int] = None, # In resumable upload, when replying to part upload request, server returns @@ -1490,13 +2318,19 @@ def __init__( # if abort is expected to be called expected_multipart_upload_aborted: bool = False, expected_single_shot_upload: bool = False, + expected_part_size: Optional[int] = None, ): super().__init__( name, stream_size, + cloud, overwrite, + source_type or [UploadSourceType.FILE, UploadSourceType.STREAM], + use_parallel + or [True, False], # Resumable Upload doesn't support parallel uploading of parts, but fallback should work + parallelism, multipart_upload_min_stream_size, - multipart_upload_chunk_size, + multipart_upload_part_size, sdk_retry_timeout_seconds, multipart_upload_max_retries, custom_response_on_single_shot_upload, @@ -1510,9 +2344,18 @@ def __init__( self.custom_response_on_upload = copy.deepcopy(custom_response_on_upload) self.custom_response_on_status_check = copy.deepcopy(custom_response_on_status_check) self.custom_response_on_abort = copy.deepcopy(custom_response_on_abort) + self.expected_exception_type = expected_exception_type + self.expected_part_size = expected_part_size def create_multipart_upload_server_state(self) -> ResumableUploadServerState: - return ResumableUploadServerState(self.unconfirmed_delta) + return ResumableUploadServerState(self.unconfirmed_delta, self.expected_part_size) + + def clear_state(self) -> None: + super().clear_state() + self.custom_response_on_create_resumable_url.clear_state() + 
self.custom_response_on_upload.clear_state() + self.custom_response_on_status_check.clear_state() + self.custom_response_on_abort.clear_state() def match_request_to_response( self, request: requests.Request, server_state: ResumableUploadServerState @@ -1650,6 +2493,15 @@ def to_string(test_case: "ResumableUploadTestCase") -> str: expected_exception_type=PermissionDenied, expected_multipart_upload_aborted=False, # upload didn't start ), + ResumableUploadTestCase( + "Create resumable URL: fallback to single-shot upload when presigned URLs are disabled", + stream_size=1024 * 1024, + custom_response_on_create_resumable_url=CustomResponse( + code=403, body=MultipartUploadTestCase.presigned_url_disabled_response, only_invocation=1 + ), + expected_multipart_upload_aborted=False, # upload didn't start + expected_single_shot_upload=True, + ), ResumableUploadTestCase( "Create resumable URL: 500 response is not retried", stream_size=1024 * 1024, @@ -1677,7 +2529,6 @@ def to_string(test_case: "ResumableUploadTestCase") -> str: "Create resumable URL: permanent retryable status code", stream_size=1024 * 1024, custom_response_on_create_resumable_url=CustomResponse(code=429), - sdk_retry_timeout_seconds=30, # don't wait for 5 min (SDK default timeout) expected_exception_type=TimeoutError, expected_multipart_upload_aborted=False, # upload didn't start ), @@ -1739,7 +2590,7 @@ def to_string(test_case: "ResumableUploadTestCase") -> str: ResumableUploadTestCase( "Upload: intermittent 429 response: retried", stream_size=100 * 1024 * 1024, - multipart_upload_chunk_size=7 * 1024 * 1024, + multipart_upload_part_size=7 * 1024 * 1024, multipart_upload_max_retries=3, custom_response_on_upload=CustomResponse( code=429, @@ -1752,7 +2603,7 @@ def to_string(test_case: "ResumableUploadTestCase") -> str: ResumableUploadTestCase( "Upload: intermittent 429 response: retry exhausted", stream_size=100 * 1024 * 1024, - multipart_upload_chunk_size=1 * 1024 * 1024, + multipart_upload_part_size=1 * 1024 * 1024, multipart_upload_max_retries=3, custom_response_on_upload=CustomResponse( code=429, @@ -1770,7 +2621,7 @@ def to_string(test_case: "ResumableUploadTestCase") -> str: # prevent part from being uploaded custom_response_on_upload=CustomResponse(code=403), # internal server error does not prevent server state change - custom_response_on_abort=CustomResponse(code=500), + custom_response_on_abort=CustomResponse(code=501), expected_exception_type=PermissionDenied, # abort returned error but was actually processed expected_multipart_upload_aborted=True, @@ -1788,31 +2639,34 @@ def to_string(test_case: "ResumableUploadTestCase") -> str: ResumableUploadTestCase( "Multiple parts, zero unconfirmed delta", stream_size=100 * 1024 * 1024, - multipart_upload_chunk_size=7 * 1024 * 1024 + 566, + multipart_upload_part_size=7 * 1024 * 1024 + 566, # server accepts all the parts in full unconfirmed_delta=0, expected_multipart_upload_aborted=False, + expected_part_size=7 * 1024 * 1024 + 566, # chunk size is used ), ResumableUploadTestCase( "Multiple small parts, zero unconfirmed delta", stream_size=100 * 1024 * 1024, - multipart_upload_chunk_size=100 * 1024, + multipart_upload_part_size=100 * 1024, # server accepts all the parts in full unconfirmed_delta=0, expected_multipart_upload_aborted=False, + expected_part_size=100 * 1024, # chunk size is used ), ResumableUploadTestCase( "Multiple parts, non-zero unconfirmed delta", stream_size=100 * 1024 * 1024, - multipart_upload_chunk_size=7 * 1024 * 1024 + 566, + multipart_upload_part_size=7 * 1024 * 
1024 + 566, # for every part, server accepts all except last 239 bytes unconfirmed_delta=239, expected_multipart_upload_aborted=False, + expected_part_size=7 * 1024 * 1024 + 566, # chunk size is used ), ResumableUploadTestCase( "Multiple parts, variable unconfirmed delta", stream_size=100 * 1024 * 1024, - multipart_upload_chunk_size=7 * 1024 * 1024 + 566, + multipart_upload_part_size=7 * 1024 * 1024 + 566, # for the first part, server accepts all except last 15Kib # for the second part, server accepts it all # for the 3rd part, server accepts all except last 25000 bytes @@ -1820,6 +2674,7 @@ def to_string(test_case: "ResumableUploadTestCase") -> str: # for the 5th part onwards server accepts all except last 5 bytes unconfirmed_delta=[15 * 1024, 0, 25000, 7 * 1024 * 1024, 5], expected_multipart_upload_aborted=False, + expected_part_size=7 * 1024 * 1024 + 566, # chunk size is used ), ResumableUploadTestCase( "Small stream, single-shot upload used", @@ -1833,3 +2688,70 @@ def to_string(test_case: "ResumableUploadTestCase") -> str: ) def test_resumable_upload(config: Config, test_case: ResumableUploadTestCase) -> None: test_case.run(config) + + +@dataclass +class CreateDownloadUrlResponseTestCase: + data: Dict[str, Any] + expected_parsed_url: Optional[str] = None + expected_parsed_headers: Optional[Dict[str, str]] = None + expected_exception: Optional[Type[BaseException]] = None + + def run(self) -> None: + + if self.expected_exception: + with pytest.raises(self.expected_exception): + CreateDownloadUrlResponse.from_dict(self.data) + else: + response = CreateDownloadUrlResponse.from_dict(self.data) + assert response.url == self.expected_parsed_url + assert response.headers == (self.expected_parsed_headers or {}) + + def __str__(self) -> str: + return f"CreateDownloadUrlResponseTestCase(data={self.data})" + + +@pytest.mark.parametrize( + "test_case", + [ + CreateDownloadUrlResponseTestCase( + data={ + "url": "https://example.com/download", + "headers": [ + {"name": "Authorization", "value": "Bearer token123"}, + {"name": "Content-Type", "value": "application/octet-stream"}, + ], + }, + expected_parsed_url="https://example.com/download", + expected_parsed_headers={ + "Authorization": "Bearer token123", + "Content-Type": "application/octet-stream", + }, + ), + CreateDownloadUrlResponseTestCase( + data={"url": "https://example.com/download"}, + expected_parsed_url="https://example.com/download", + expected_parsed_headers={}, + ), + CreateDownloadUrlResponseTestCase( + data={"url": "https://example.com/download", "headers": []}, + expected_parsed_url="https://example.com/download", + expected_parsed_headers={}, + ), + CreateDownloadUrlResponseTestCase( + data={"headers": [{"name": "Content-Type", "value": "application/octet-stream"}]}, + expected_exception=ValueError, + ), + ], + ids=str, +) +def test_create_download_url_response(test_case: CreateDownloadUrlResponseTestCase) -> None: + """Run a test case for CreateDownloadUrlResponse.""" + test_case.run() + + +def fast_random_bytes(n: int, chunk_size: int = 1024) -> bytes: + # Generate a small random chunk + chunk = os.urandom(chunk_size) + # Repeat it until we reach n bytes + return (chunk * (n // chunk_size + 1))[:n] diff --git a/tests/test_files_utils.py b/tests/test_files_utils.py new file mode 100644 index 000000000..d11d91dbc --- /dev/null +++ b/tests/test_files_utils.py @@ -0,0 +1,316 @@ +import logging +import os +from abc import ABC, abstractmethod +from io import BytesIO, RawIOBase, UnsupportedOperation +from typing import BinaryIO, 
Any, Callable, List, Optional, Tuple
+
+import pytest
+
+from databricks.sdk.mixins.files_utils import (_ConcatenatedInputStream,
+                                               _PresignedUrlDistributor)
+
+logger = logging.getLogger(__name__)
+
+
+class Utils:
+    @staticmethod
+    def parse_range_header(range_header: str, content_length: Optional[int] = None) -> Tuple[int, Optional[int]]:
+        """
+        Parses a Range header string and returns the start and end byte positions.
+        The end position is None for open-ended ranges such as "bytes=100-".
+        Example input: "bytes=0-499"
+        Example output: (0, 499)
+        """
+        if not range_header.startswith("bytes="):
+            raise ValueError("Invalid Range header format")
+        byte_range = range_header[len("bytes=") :]
+        start_str, end_str = byte_range.split("-")
+        start = int(start_str) if start_str else 0
+        end = int(end_str) if end_str else None
+
+        if content_length is not None:
+            if start >= content_length:
+                raise ValueError(f"Start byte {start} exceeds content length {content_length}")
+            if end is not None and end >= content_length:
+                raise ValueError(f"End byte {end} exceeds content length {content_length}")
+            if end is not None and start > end:
+                raise ValueError(f"Start byte {start} is greater than end byte {end}")
+
+        return start, end
+
+
+class NonSeekableBuffer(RawIOBase, BinaryIO):
+    """
+    A non-seekable buffer that wraps a bytes object. Used for unit tests only.
+    This class implements the BinaryIO interface but does not support seeking.
+    It is used to simulate a non-seekable stream for testing purposes.
+    """
+
+    def __init__(self, data: bytes):
+        self._stream = BytesIO(data)
+
+    def read(self, size: int = -1) -> bytes:
+        return self._stream.read(size)
+
+    def readline(self, size: int = -1) -> bytes:
+        return self._stream.readline(size)
+
+    def readlines(self, size: int = -1) -> List[bytes]:
+        return self._stream.readlines(size)
+
+    def readable(self) -> bool:
+        return True
+
+    def seekable(self) -> bool:
+        return False
+
+    def seek(self, *args, **kwargs) -> int:
+        raise UnsupportedOperation("seek not supported")
+
+    def tell(self) -> int:
+        raise UnsupportedOperation("tell not supported")
+
+
+class ConcatenatedInputStreamTestCaseBase(ABC):
+    """Abstract interface for producing a (reference bytes, stream under test) pair."""
+
+    @abstractmethod
+    def generate(self) -> Tuple[bytes, BinaryIO]:
+        pass
+
+
+class ConcatenatedInputStreamTestCase(ConcatenatedInputStreamTestCaseBase):
+    def __init__(self, head: bytes, tail: bytes, is_seekable: bool = True):
+        self._head = head
+        self._tail = tail
+        self._is_seekable = is_seekable
+
+    def generate(self) -> Tuple[bytes, BinaryIO]:
+        """
+        Generate a pair of:
+        (a) the full concatenated byte content (used to build the reference implementation)
+        (b) the _ConcatenatedInputStream implementation under test
+        """
+        full_stream = self._head + self._tail
+        if self._is_seekable:
+            concatenated_stream = _ConcatenatedInputStream(BytesIO(self._head), BytesIO(self._tail))
+        else:
+            concatenated_stream = _ConcatenatedInputStream(NonSeekableBuffer(self._head), NonSeekableBuffer(self._tail))
+        return full_stream, concatenated_stream
+
+    def test_to_string(self) -> str:
+        head = self._head.decode("utf-8")
+        tail = self._tail.decode("utf-8")
+        seekable = "seekable" if self._is_seekable else "non-seekable"
+        return f"{head}-{tail}-{seekable}"
+
+    @staticmethod
+    def to_string(test_case) -> str:
+        return test_case.test_to_string()
+
+
+test_cases = [
+    ConcatenatedInputStreamTestCase(b"", b"zzzz"),
+    ConcatenatedInputStreamTestCase(b"", b""),
+    ConcatenatedInputStreamTestCase(b"", b"", is_seekable=False),
+    ConcatenatedInputStreamTestCase(b"foo", b"bar"),
+    ConcatenatedInputStreamTestCase(b"foo", b"bar", is_seekable=False),
+    ConcatenatedInputStreamTestCase(b"", b"zzzz", is_seekable=False),
+    ConcatenatedInputStreamTestCase(b"non_empty", b""),
+    ConcatenatedInputStreamTestCase(b"non_empty", b"", is_seekable=False),
+    ConcatenatedInputStreamTestCase(b"\n\n\n", b"\n\n"),
+    ConcatenatedInputStreamTestCase(b"\n\n\n", b"\n\n", is_seekable=False),
+    ConcatenatedInputStreamTestCase(b"aa\nbb\nccc\n", b"dd\nee\nff"),
+    ConcatenatedInputStreamTestCase(b"aa\nbb\nccc\n", b"dd\nee\nff", is_seekable=False),
+    ConcatenatedInputStreamTestCase(b"First line\nsecond line", b"first line with line \nbreak"),
+    ConcatenatedInputStreamTestCase(b"First line\nsecond line", b"first line with line \nbreak", is_seekable=False),
+    ConcatenatedInputStreamTestCase(b"First line\n", b"\nsecond line"),
+    ConcatenatedInputStreamTestCase(b"First line\n", b"\nsecond line", is_seekable=False),
+    ConcatenatedInputStreamTestCase(b"First line\n", b"\n"),
+    ConcatenatedInputStreamTestCase(b"First line\n", b"\n", is_seekable=False),
+    ConcatenatedInputStreamTestCase(b"First line\n", b""),
+    ConcatenatedInputStreamTestCase(b"First line\n", b"", is_seekable=False),
+    ConcatenatedInputStreamTestCase(b"", b"\nA line"),
+    ConcatenatedInputStreamTestCase(b"", b"\nA line", is_seekable=False),
+    ConcatenatedInputStreamTestCase(b"\n", b"\nA line"),
+    ConcatenatedInputStreamTestCase(b"\n", b"\nA line", is_seekable=False),
+]
+
+
+def verify(test_case: ConcatenatedInputStreamTestCase, apply: Callable[[BinaryIO], Tuple[Any, bool]]):
+    """
+    Applies the given function iteratively to both the implementation under test
+    and the reference implementation of the stream, and verifies that the result
+    of each step is identical.
+    """
+    result_bytes, implementation_under_test = test_case.generate()
+    reference_implementation = BytesIO(result_bytes)
+
+    while True:
+        expected = apply(reference_implementation)
+        actual = apply(implementation_under_test)
+
+        assert actual == expected
+
+        should_stop = actual[1]
+        if should_stop:
+            break
+
+    if len(result_bytes) == reference_implementation.tell():
+        verify_eof(implementation_under_test)
+        verify_eof(reference_implementation)
+
+
+def verify_eof(buffer: BinaryIO):
+    assert len(buffer.read()) == 0
+    assert len(buffer.read(100)) == 0
+    assert len(buffer.readline()) == 0
+    assert len(buffer.readline(100)) == 0
+    assert len(buffer.readlines()) == 0
+    assert len(buffer.readlines(100)) == 0
+
+
+@pytest.mark.parametrize("test_case", test_cases, ids=ConcatenatedInputStreamTestCase.to_string)
+@pytest.mark.parametrize("limit", [-1, 0, 1, 3, 4, 5, 6, 10, 100, 1000])
+def test_read(config, test_case: ConcatenatedInputStreamTestCase, limit: int):
+    def apply(buffer: BinaryIO):
+        value = buffer.read(limit)
+
+        if limit > 0:
+            assert len(value) <= limit
+
+        should_stop = (limit > 0 and len(value) < limit) or len(value) == 0
+        return value, should_stop
+
+    verify(test_case, apply)
+
+
+@pytest.mark.parametrize("test_case", test_cases, ids=ConcatenatedInputStreamTestCase.to_string)
+@pytest.mark.parametrize("limit", [-1, 0, 1, 2, 3, 4, 5, 6, 9, 10, 11, 12, 100, 1000])
+def test_read_line(config, test_case: ConcatenatedInputStreamTestCase, limit: int):
+    def apply(buffer: BinaryIO):
+        value = buffer.readline(limit)
+        should_stop = len(value) == 0
+        return value, should_stop
+
+    verify(test_case, apply)
+
+
+@pytest.mark.parametrize("test_case", test_cases, ids=ConcatenatedInputStreamTestCase.to_string)
+@pytest.mark.parametrize("limit", [-1, 0, 1, 2, 3, 4, 5, 6, 9, 10, 11, 12, 100, 1000])
+def test_read_lines(config, test_case: ConcatenatedInputStreamTestCase, limit: int):
+    def apply(buffer: BinaryIO):
+        value = buffer.readlines(limit)
+        should_stop = len(value) == 0
+        return value, should_stop
+
+    verify(test_case, apply)
+
+
+@pytest.mark.parametrize("test_case", test_cases, ids=ConcatenatedInputStreamTestCase.to_string)
+def test_iterator(config, test_case: ConcatenatedInputStreamTestCase):
+    def apply(buffer: BinaryIO):
+        try:
+            value = buffer.__next__()
+            return value, False
+        except StopIteration:
+            return None, True
+
+    verify(test_case, apply)
+
+
+def seeks_to_string(seeks: List[Tuple[int, int]]) -> str:
+    return ", ".join(list(map(lambda seek: f"Seek: offset={seek[0]}, whence={seek[1]}", seeks)))
+
+
+@pytest.mark.parametrize("test_case", test_cases, ids=ConcatenatedInputStreamTestCase.to_string)
+@pytest.mark.parametrize(
+    "seeks",
+    [
+        [(0, os.SEEK_SET)],
+        [(-10, os.SEEK_SET), (1, os.SEEK_SET)],
+        [(10000, os.SEEK_SET)],
+        [(0, os.SEEK_END)],
+        [(10000, os.SEEK_END)],
+        [(-10000, os.SEEK_END)],
+        [(1, os.SEEK_SET)],
+        [(5, os.SEEK_SET)],
+        [(-1, os.SEEK_END)],
+        [(-1, os.SEEK_CUR)],
+        [(-100, os.SEEK_CUR), (105, os.SEEK_CUR), (2, os.SEEK_CUR), (-2, os.SEEK_CUR)],
+    ],
+    ids=seeks_to_string,
+)
+def test_seek(config, test_case: ConcatenatedInputStreamTestCase, seeks: List[Tuple[int, int]]):
+    def read_and_restore(buf: BinaryIO) -> bytes:
+        pos = buf.tell()
+        result = buf.read()
+        buf.seek(pos)
+        return result
+
+    def safe_call(buf: BinaryIO, call: Callable[[BinaryIO], Any]) -> Tuple[Any, bool]:
+        """
+        Calls the provided function on the buffer and returns the result.
+        It is a wrapper to handle exceptions gracefully.
+        If an exception occurs, it returns None and False.
+        :param buf: The buffer to operate on.
+        :param call: The function to call with the buffer.
+        :return: A tuple of (result, success), where success is True if the call succeeded, False otherwise.
+        """
+        try:
+            result = call(buf)
+            return result, True
+        except Exception:
+            return None, False
+
+    underlying, buffer = test_case.generate()
+    native_buffer = BytesIO(underlying)
+    if not buffer.seekable():
+        return
+
+    assert buffer.tell() == native_buffer.tell()
+    for seek in seeks:
+        do_seek = lambda buf: buf.seek(seek[0], seek[1])
+        assert safe_call(buffer, do_seek) == safe_call(native_buffer, do_seek)
+        assert buffer.tell() == native_buffer.tell()
+        assert read_and_restore(buffer) == read_and_restore(native_buffer)
+
+
+class DummyResponse:
+    def __init__(self, value):
+        self.value = value
+
+
+def test_get_url_returns_url_and_version():
+    distributor = _PresignedUrlDistributor(lambda: DummyResponse("url1"))
+    url, version = distributor.get_url()
+    assert isinstance(url, DummyResponse)
+    assert url.value == "url1"
+    assert version == 0
+
+
+def test_get_url_caches_url():
+    calls = []
+    distributor = _PresignedUrlDistributor(lambda: calls.append(1) or DummyResponse("url2"))
+    url1, version1 = distributor.get_url()
+    url2, version2 = distributor.get_url()
+    assert url1 is url2
+    assert version1 == version2
+    assert calls.count(1) == 1  # Only called once
+
+
+def test_invalidate_url_changes_url_and_version():
+    responses = [DummyResponse("urlA"), DummyResponse("urlB")]
+    distributor = _PresignedUrlDistributor(lambda: responses.pop(0))
+    url1, version1 = distributor.get_url()
+    distributor.invalidate_url(version1)
+    url2, version2 = distributor.get_url()
+    assert url1.value == "urlA"
+    assert url2.value == "urlB"
+    assert version2 == version1 + 1
+
+
+def test_invalidate_url_wrong_version_does_not_invalidate():
+    distributor = _PresignedUrlDistributor(lambda: DummyResponse("urlX"))
+    url1, version1 = distributor.get_url()
+    
distributor.invalidate_url(version1 + 1) # Wrong version + url2, version2 = distributor.get_url() + assert url1 is url2 + assert version2 == version1
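+
+
+def test_invalidate_url_stale_version_is_ignored_after_refresh():
+    # Illustrative sketch only: it assumes the behaviour exercised by the
+    # preceding tests, namely that invalidate_url() refreshes only when called
+    # with the current version and ignores stale versions. Two consumers that
+    # received the same (url, version) both report expiry; only the first
+    # report should trigger a refresh, so the provider runs exactly twice.
+    calls = []
+    responses = [DummyResponse("url-initial"), DummyResponse("url-refreshed")]
+    distributor = _PresignedUrlDistributor(lambda: calls.append(1) or responses.pop(0))
+
+    url, version = distributor.get_url()
+    url_again, version_again = distributor.get_url()
+    assert url_again is url and version_again == version  # both consumers share the cached URL
+
+    distributor.invalidate_url(version)  # first consumer reports an expired URL
+    distributor.invalidate_url(version)  # second consumer reports the same, now stale, version
+
+    new_url, new_version = distributor.get_url()
+    assert new_url.value == "url-refreshed"
+    assert new_version == version + 1
+    assert calls.count(1) == 2  # one initial fetch plus a single refresh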