diff --git a/databricks/sdk/mixins/files.py b/databricks/sdk/mixins/files.py index 819e14d4b..63260a59d 100644 --- a/databricks/sdk/mixins/files.py +++ b/databricks/sdk/mixins/files.py @@ -964,6 +964,8 @@ def _parallel_download_presigned_url(self, remote_path: str, destination: str, p cloud_session = self._create_cloud_provider_session() url_distributor = _PresignedUrlDistributor(lambda: self._create_download_url(remote_path)) + # An event to indicate if any download chunk has succeeded. If any chunk succeeds, we do not fall back to Files API. + any_success = Event() def download_chunk(additional_headers: dict[str, str]) -> BinaryIO: retry_count = 0 @@ -982,10 +984,14 @@ def get_content() -> requests.Response: url_distributor.invalidate_url(version) retry_count += 1 continue - elif raw_resp.status_code == 403: + elif raw_resp.status_code == 403 and not any_success.is_set(): raise FallbackToDownloadUsingFilesApi("Received 403 Forbidden from presigned URL") + elif not any_success.is_set(): + # For other errors, we raise a retryable exception to trigger retry logic. + raise FallbackToDownloadUsingFilesApi(f"Received {raw_resp.status_code} from presigned URL") raw_resp.raise_for_status() + any_success.set() return BytesIO(raw_resp.content) raise ValueError("Exceeded maximum retries for downloading with presigned URL: URL expired too many times") @@ -1404,20 +1410,47 @@ def _parallel_multipart_upload_from_file( part_size = ctx.part_size num_parts = (file_size + part_size - 1) // part_size _LOG.debug(f"Uploading file of size {file_size} bytes in {num_parts} parts using {ctx.parallelism} threads") + cloud_provider_session = self._create_cloud_provider_session() + + # Upload one part to verify the upload can proceed. + with open(ctx.source_file_path, "rb") as f: + f.seek(0) + first_part_size = min(part_size, file_size) + first_part_buffer = f.read(first_part_size) + try: + etag = self._do_upload_one_part( + ctx, + cloud_provider_session, + 1, + 0, + first_part_size, + session_token, + BytesIO(first_part_buffer), + is_first_part=True, + ) + except FallbackToUploadUsingFilesApi as e: + raise FallbackToUploadUsingFilesApi(None, "Falling back to single-shot upload with Files API") from e + if num_parts == 1: + self._complete_multipart_upload(ctx, {1: etag}, session_token) + return # Create queues and worker threads. task_queue = Queue() etags_result_queue = Queue() + etags_result_queue.put_nowait((1, etag)) exception_queue = Queue() aborted = Event() workers = [ - Thread(target=self._upload_file_consumer, args=(task_queue, etags_result_queue, exception_queue, aborted)) + Thread( + target=self._upload_file_consumer, + args=(cloud_provider_session, task_queue, etags_result_queue, exception_queue, aborted), + ) for _ in range(ctx.parallelism) ] _LOG.debug(f"Starting {len(workers)} worker threads for parallel upload") # Enqueue all parts. Since the task queue is populated before starting the workers, we don't need to signal completion. - for part_index in range(1, num_parts + 1): + for part_index in range(2, num_parts + 1): part_offset = (part_index - 1) * part_size part_size = min(part_size, file_size - part_offset) part = self._MultipartUploadPart(ctx, part_index, part_offset, part_size, session_token) @@ -1466,7 +1499,14 @@ def _parallel_multipart_upload_from_stream( ) try: etag = self._do_upload_one_part( - ctx, cloud_provider_session, 1, 0, len(pre_read_buffer), session_token, BytesIO(pre_read_buffer) + ctx, + cloud_provider_session, + 1, + 0, + len(pre_read_buffer), + session_token, + BytesIO(pre_read_buffer), + is_first_part=True, ) etags_result_queue.put((1, etag)) except FallbackToUploadUsingFilesApi as e: @@ -1551,12 +1591,12 @@ def _complete_multipart_upload(self, ctx, etags, session_token): def _upload_file_consumer( self, + cloud_provider_session: requests.Session, task_queue: Queue[FilesExt._MultipartUploadPart], etags_queue: Queue[tuple[int, str]], exception_queue: Queue[Exception], aborted: Event, ) -> None: - cloud_provider_session = self._create_cloud_provider_session() while not aborted.is_set(): try: part = task_queue.get(block=False) @@ -1627,6 +1667,7 @@ def _do_upload_one_part( part_size: int, session_token: str, part_content: BinaryIO, + is_first_part: bool = False, ) -> str: retry_count = 0 @@ -1648,18 +1689,14 @@ def _do_upload_one_part( upload_part_urls_response = self._api.do( "POST", "/api/2.0/fs/create-upload-part-urls", headers=headers, body=body ) - except PermissionDenied as e: - if self._is_presigned_urls_disabled_error(e): - raise FallbackToUploadUsingFilesApi(None, "Presigned URLs are disabled") - else: - raise e from None - except InternalError as e: - if self._is_presigned_urls_network_zone_error(e): + except Exception as e: + if is_first_part: raise FallbackToUploadUsingFilesApi( - None, "Presigned URLs are not supported in the current network zone" + None, + f"Failed to obtain upload URL for part {part_index}: {e}, falling back to single shot upload", ) else: - raise e from None + raise e upload_part_urls = upload_part_urls_response.get("upload_part_urls", []) if len(upload_part_urls) == 0: @@ -1699,8 +1736,11 @@ def perform_upload() -> requests.Response: continue else: raise ValueError(f"Unsuccessful chunk upload: upload URL expired after {retry_count} retries") - elif upload_response.status_code == 403: + elif upload_response.status_code == 403 and is_first_part: raise FallbackToUploadUsingFilesApi(None, f"Direct upload forbidden: {upload_response.content}") + elif is_first_part: + message = f"Unsuccessful chunk upload. Response status: {upload_response.status_code}, body: {upload_response.content}" + raise FallbackToUploadUsingFilesApi(None, message) else: message = f"Unsuccessful chunk upload. Response status: {upload_response.status_code}, body: {upload_response.content}" _LOG.warning(message) @@ -1765,18 +1805,13 @@ def _perform_multipart_upload( upload_part_urls_response = self._api.do( "POST", "/api/2.0/fs/create-upload-part-urls", headers=headers, body=body ) - except PermissionDenied as e: - if chunk_offset == 0 and self._is_presigned_urls_disabled_error(e): - raise FallbackToUploadUsingFilesApi(buffer, "Presigned URLs are disabled") - else: - raise e from None - except InternalError as e: - if chunk_offset == 0 and self._is_presigned_urls_network_zone_error(e): + except Exception as e: + if chunk_offset == 0: raise FallbackToUploadUsingFilesApi( - buffer, "Presigned URLs are not supported in the current network zone" - ) + buffer, f"Failed to obtain upload URLs: {e}, falling back to single shot upload" + ) from e else: - raise e from None + raise e upload_part_urls = upload_part_urls_response.get("upload_part_urls", []) if len(upload_part_urls) == 0: @@ -1847,7 +1882,14 @@ def perform(): # Let's fallback to using Files API which might be allowlisted to upload, passing # currently buffered (but not yet uploaded) part of the stream. raise FallbackToUploadUsingFilesApi(buffer, f"Direct upload forbidden: {upload_response.content}") - + elif chunk_offset == 0: + # We got an upload failure when uploading the very first chunk. + # Let's fallback to using Files API which might be more reliable in this case, + # passing currently buffered (but not yet uploaded) part of the stream. + raise FallbackToUploadUsingFilesApi( + buffer, + f"Unsuccessful chunk upload: {upload_response.status_code}, falling back to single shot upload", + ) else: message = f"Unsuccessful chunk upload. Response status: {upload_response.status_code}, body: {upload_response.content}" _LOG.warning(message) @@ -1985,18 +2027,10 @@ def _perform_resumable_upload( resumable_upload_url_response = self._api.do( "POST", "/api/2.0/fs/create-resumable-upload-url", headers=headers, body=body ) - except PermissionDenied as e: - if self._is_presigned_urls_disabled_error(e): - raise FallbackToUploadUsingFilesApi(pre_read_buffer, "Presigned URLs are disabled") - else: - raise e from None - except InternalError as e: - if self._is_presigned_urls_network_zone_error(e): - raise FallbackToUploadUsingFilesApi( - pre_read_buffer, "Presigned URLs are not supported in the current network zone" - ) - else: - raise e from None + except Exception as e: + raise FallbackToUploadUsingFilesApi( + pre_read_buffer, f"Failed to obtain resumable upload URL: {e}, falling back to single shot upload" + ) from e resumable_upload_url_node = resumable_upload_url_response.get("resumable_upload_url") if not resumable_upload_url_node: @@ -2376,16 +2410,8 @@ def _create_download_url(self, file_path: str) -> CreateDownloadUrlResponse: ) return CreateDownloadUrlResponse.from_dict(raw_response) - except PermissionDenied as e: - if self._is_presigned_urls_disabled_error(e): - raise FallbackToDownloadUsingFilesApi(f"Presigned URLs are disabled") - else: - raise e from None - except InternalError as e: - if self._is_presigned_urls_network_zone_error(e): - raise FallbackToDownloadUsingFilesApi("Presigned URLs are not supported in the current network zone") - else: - raise e from None + except Exception as e: + raise FallbackToDownloadUsingFilesApi(f"Failed to create download URL: {e}") from e def _init_download_response_presigned_api(self, file_path: str, added_headers: dict[str, str]) -> DownloadResponse: """ @@ -2428,17 +2454,11 @@ def perform() -> requests.Response: contents=_StreamingResponse(csp_response, self._config.files_ext_client_download_streaming_chunk_size), ) return resp - elif csp_response.status_code == 403: - # We got 403 failure when downloading the file. This might happen due to Azure firewall enabled for the customer bucket. - # Let's fallback to using Files API which might be allowlisted to download. - raise FallbackToDownloadUsingFilesApi(f"Direct download forbidden: {csp_response.content}") else: message = ( f"Unsuccessful download. Response status: {csp_response.status_code}, body: {csp_response.content}" ) - _LOG.warning(message) - mapped_error = _error_mapper(csp_response, {}) - raise mapped_error or ValueError(message) + raise FallbackToDownloadUsingFilesApi(message) def _init_download_response_mode_csp_with_fallback( self, file_path: str, headers: dict[str, str], response_headers: list[str] diff --git a/tests/test_files.py b/tests/test_files.py index 64947e9b8..7ad6c21c4 100644 --- a/tests/test_files.py +++ b/tests/test_files.py @@ -1066,14 +1066,14 @@ def run(self, config: Config, monkeypatch) -> None: PresignedUrlDownloadTestCase( name="Presigned URL download fails with 403", file_size=100 * 1024 * 1024, - expected_exception_type=PermissionDenied, custom_response_create_presigned_url=CustomResponse(code=403, only_invocation=1), + expected_download_api="files_api", ), PresignedUrlDownloadTestCase( name="Presigned URL download fails with 500 when creating presigned URL", file_size=100 * 1024 * 1024, - expected_exception_type=InternalError, custom_response_create_presigned_url=CustomResponse(code=500, only_invocation=1), + expected_download_api="files_api", ), PresignedUrlDownloadTestCase( name="Presigned URL download fails with 500 when downloading from URL", @@ -1921,40 +1921,40 @@ def to_string(test_case: "MultipartUploadTestCase") -> str: ), # -------------------------- failures on "create upload URL" -------------------------- MultipartUploadTestCase( - "Create upload URL: 400 response is not retried", + "Create upload URL: 400 response should fallback", content_size=1024 * 1024, custom_response_on_create_multipart_url=CustomResponse( code=400, # 1 failure is enough only_invocation=1, ), - expected_exception_type=BadRequest, expected_multipart_upload_aborted=True, + expected_single_shot_upload=True, ), MultipartUploadTestCase( - "Create upload URL: 403 response is not retried", + "Create upload URL: 403 response should fallback", content_size=1024 * 1024, custom_response_on_create_multipart_url=CustomResponse( code=403, # 1 failure is enough only_invocation=1, ), - expected_exception_type=PermissionDenied, expected_multipart_upload_aborted=True, + expected_single_shot_upload=True, ), MultipartUploadTestCase( - "Create upload URL: internal error is not retried", + "Create upload URL: internal error should fallback", content_size=1024 * 1024, custom_response_on_create_multipart_url=CustomResponse(code=500, only_invocation=1), - expected_exception_type=InternalError, expected_multipart_upload_aborted=True, + expected_single_shot_upload=True, ), MultipartUploadTestCase( - "Create upload URL: non-JSON response is not retried", + "Create upload URL: non-JSON response should fallback", content_size=1024 * 1024, custom_response_on_create_multipart_url=CustomResponse(body="this is not a JSON", only_invocation=1), - expected_exception_type=requests.exceptions.JSONDecodeError, expected_multipart_upload_aborted=True, + expected_single_shot_upload=True, ), MultipartUploadTestCase( "Create upload URL: meaningless JSON response is not retried", @@ -1980,11 +1980,11 @@ def to_string(test_case: "MultipartUploadTestCase") -> str: expected_multipart_upload_aborted=True, ), MultipartUploadTestCase( - "Create upload URL: permanent retryable exception", + "Create upload URL: permanent retryable exception should fallback", content_size=1024 * 1024, custom_response_on_create_multipart_url=CustomResponse(exception=requests.ConnectionError), - expected_exception_type=TimeoutError, expected_multipart_upload_aborted=True, + expected_single_shot_upload=True, ), MultipartUploadTestCase( "Create upload URL: intermittent retryable exception", @@ -2221,8 +2221,8 @@ def to_string(test_case: "MultipartUploadTestCase") -> str: content_size=1024 * 1024, custom_response_on_create_multipart_url=CustomResponse(code=403, only_invocation=1), custom_response_on_create_abort_url=CustomResponse(code=429, first_invocation=1, last_invocation=3), - expected_exception_type=PermissionDenied, # original error expected_multipart_upload_aborted=True, # abort successfully called after abort URL creation is retried + expected_single_shot_upload=True, ), MultipartUploadTestCase( "Abort: exception", @@ -2233,12 +2233,12 @@ def to_string(test_case: "MultipartUploadTestCase") -> str: # this allows to change the server state to "aborted" exception_happened_before_processing=False, ), - expected_exception_type=PermissionDenied, # original error is reported expected_multipart_upload_aborted=True, + expected_single_shot_upload=True, ), # -------------------------- Parallel Upload for Streams -------------------------- MultipartUploadTestCase( - "Multipart parallel upload for stream: Upload errors are not retried", + "Multipart parallel upload for stream: Upload errors are not retried but fallback", content_size=10 * 1024 * 1024, multipart_upload_part_size=1024 * 1024, source_type=[UploadSourceType.SEEKABLE_STREAM], @@ -2557,22 +2557,22 @@ def to_string(test_case: "ResumableUploadTestCase") -> str: [ # ------------------ failures on creating resumable upload URL ------------------ ResumableUploadTestCase( - "Create resumable URL: 400 response is not retried", + "Create resumable URL: 400 response is not retried and should fallback", stream_size=1024 * 1024, custom_response_on_create_resumable_url=CustomResponse( code=400, # 1 failure is enough only_invocation=1, ), - expected_exception_type=BadRequest, expected_multipart_upload_aborted=False, # upload didn't start + expected_single_shot_upload=True, ), ResumableUploadTestCase( - "Create resumable URL: 403 response is not retried", + "Create resumable URL: 403 response is not retried and should fallback", stream_size=1024 * 1024, custom_response_on_create_resumable_url=CustomResponse(code=403, only_invocation=1), - expected_exception_type=PermissionDenied, expected_multipart_upload_aborted=False, # upload didn't start + expected_single_shot_upload=True, ), ResumableUploadTestCase( "Create resumable URL: fallback to single-shot upload when presigned URLs are disabled", @@ -2595,18 +2595,18 @@ def to_string(test_case: "ResumableUploadTestCase") -> str: expected_single_shot_upload=True, ), ResumableUploadTestCase( - "Create resumable URL: 500 response is not retried", + "Create resumable URL: 500 response is not retried and should fallback", stream_size=1024 * 1024, custom_response_on_create_resumable_url=CustomResponse(code=500, only_invocation=1), - expected_exception_type=InternalError, expected_multipart_upload_aborted=False, # upload didn't start + expected_single_shot_upload=True, ), ResumableUploadTestCase( "Create resumable URL: non-JSON response is not retried", stream_size=1024 * 1024, custom_response_on_create_resumable_url=CustomResponse(body="Foo bar", only_invocation=1), - expected_exception_type=requests.exceptions.JSONDecodeError, expected_multipart_upload_aborted=False, # upload didn't start + expected_single_shot_upload=True, ), ResumableUploadTestCase( "Create resumable URL: meaningless JSON response is not retried", @@ -2621,8 +2621,8 @@ def to_string(test_case: "ResumableUploadTestCase") -> str: "Create resumable URL: permanent retryable status code", stream_size=1024 * 1024, custom_response_on_create_resumable_url=CustomResponse(code=429), - expected_exception_type=TimeoutError, expected_multipart_upload_aborted=False, # upload didn't start + expected_single_shot_upload=True, ), ResumableUploadTestCase( "Create resumable URL: intermittent retryable exception is retried",