Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 75 additions & 55 deletions databricks/sdk/mixins/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -964,6 +964,8 @@ def _parallel_download_presigned_url(self, remote_path: str, destination: str, p

cloud_session = self._create_cloud_provider_session()
url_distributor = _PresignedUrlDistributor(lambda: self._create_download_url(remote_path))
# An event to indicate if any download chunk has succeeded. If any chunk succeeds, we do not fall back to Files API.
any_success = Event()

def download_chunk(additional_headers: dict[str, str]) -> BinaryIO:
retry_count = 0
Expand All @@ -982,10 +984,14 @@ def get_content() -> requests.Response:
url_distributor.invalidate_url(version)
retry_count += 1
continue
elif raw_resp.status_code == 403:
elif raw_resp.status_code == 403 and not any_success.is_set():
raise FallbackToDownloadUsingFilesApi("Received 403 Forbidden from presigned URL")
elif not any_success.is_set():
# For other errors, we raise a retryable exception to trigger retry logic.
raise FallbackToDownloadUsingFilesApi(f"Received {raw_resp.status_code} from presigned URL")

raw_resp.raise_for_status()
any_success.set()
return BytesIO(raw_resp.content)
raise ValueError("Exceeded maximum retries for downloading with presigned URL: URL expired too many times")

Expand Down Expand Up @@ -1404,20 +1410,47 @@ def _parallel_multipart_upload_from_file(
part_size = ctx.part_size
num_parts = (file_size + part_size - 1) // part_size
_LOG.debug(f"Uploading file of size {file_size} bytes in {num_parts} parts using {ctx.parallelism} threads")
cloud_provider_session = self._create_cloud_provider_session()

# Upload one part to verify the upload can proceed.
with open(ctx.source_file_path, "rb") as f:
f.seek(0)
first_part_size = min(part_size, file_size)
first_part_buffer = f.read(first_part_size)
try:
etag = self._do_upload_one_part(
ctx,
cloud_provider_session,
1,
0,
first_part_size,
session_token,
BytesIO(first_part_buffer),
is_first_part=True,
)
except FallbackToUploadUsingFilesApi as e:
raise FallbackToUploadUsingFilesApi(None, "Falling back to single-shot upload with Files API") from e
if num_parts == 1:
self._complete_multipart_upload(ctx, {1: etag}, session_token)
return

# Create queues and worker threads.
task_queue = Queue()
etags_result_queue = Queue()
etags_result_queue.put_nowait((1, etag))
exception_queue = Queue()
aborted = Event()
workers = [
Thread(target=self._upload_file_consumer, args=(task_queue, etags_result_queue, exception_queue, aborted))
Thread(
target=self._upload_file_consumer,
args=(cloud_provider_session, task_queue, etags_result_queue, exception_queue, aborted),
)
for _ in range(ctx.parallelism)
]
_LOG.debug(f"Starting {len(workers)} worker threads for parallel upload")

# Enqueue all parts. Since the task queue is populated before starting the workers, we don't need to signal completion.
for part_index in range(1, num_parts + 1):
for part_index in range(2, num_parts + 1):
part_offset = (part_index - 1) * part_size
part_size = min(part_size, file_size - part_offset)
part = self._MultipartUploadPart(ctx, part_index, part_offset, part_size, session_token)
Expand Down Expand Up @@ -1466,7 +1499,14 @@ def _parallel_multipart_upload_from_stream(
)
try:
etag = self._do_upload_one_part(
ctx, cloud_provider_session, 1, 0, len(pre_read_buffer), session_token, BytesIO(pre_read_buffer)
ctx,
cloud_provider_session,
1,
0,
len(pre_read_buffer),
session_token,
BytesIO(pre_read_buffer),
is_first_part=True,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to pass this information? Every error returned by this will fall back to FilesAPI, right?

)
etags_result_queue.put((1, etag))
except FallbackToUploadUsingFilesApi as e:
Expand Down Expand Up @@ -1551,12 +1591,12 @@ def _complete_multipart_upload(self, ctx, etags, session_token):

def _upload_file_consumer(
self,
cloud_provider_session: requests.Session,
task_queue: Queue[FilesExt._MultipartUploadPart],
etags_queue: Queue[tuple[int, str]],
exception_queue: Queue[Exception],
aborted: Event,
) -> None:
cloud_provider_session = self._create_cloud_provider_session()
while not aborted.is_set():
try:
part = task_queue.get(block=False)
Expand Down Expand Up @@ -1627,6 +1667,7 @@ def _do_upload_one_part(
part_size: int,
session_token: str,
part_content: BinaryIO,
is_first_part: bool = False,
) -> str:
retry_count = 0

Expand All @@ -1648,18 +1689,14 @@ def _do_upload_one_part(
upload_part_urls_response = self._api.do(
"POST", "/api/2.0/fs/create-upload-part-urls", headers=headers, body=body
)
except PermissionDenied as e:
if self._is_presigned_urls_disabled_error(e):
raise FallbackToUploadUsingFilesApi(None, "Presigned URLs are disabled")
else:
raise e from None
except InternalError as e:
if self._is_presigned_urls_network_zone_error(e):
except Exception as e:
if is_first_part:
raise FallbackToUploadUsingFilesApi(
None, "Presigned URLs are not supported in the current network zone"
None,
f"Failed to obtain upload URL for part {part_index}: {e}, falling back to single shot upload",
)
else:
raise e from None
raise e

upload_part_urls = upload_part_urls_response.get("upload_part_urls", [])
if len(upload_part_urls) == 0:
Expand Down Expand Up @@ -1699,8 +1736,11 @@ def perform_upload() -> requests.Response:
continue
else:
raise ValueError(f"Unsuccessful chunk upload: upload URL expired after {retry_count} retries")
elif upload_response.status_code == 403:
elif upload_response.status_code == 403 and is_first_part:
raise FallbackToUploadUsingFilesApi(None, f"Direct upload forbidden: {upload_response.content}")
elif is_first_part:
message = f"Unsuccessful chunk upload. Response status: {upload_response.status_code}, body: {upload_response.content}"
raise FallbackToUploadUsingFilesApi(None, message)
else:
message = f"Unsuccessful chunk upload. Response status: {upload_response.status_code}, body: {upload_response.content}"
_LOG.warning(message)
Expand Down Expand Up @@ -1765,18 +1805,13 @@ def _perform_multipart_upload(
upload_part_urls_response = self._api.do(
"POST", "/api/2.0/fs/create-upload-part-urls", headers=headers, body=body
)
except PermissionDenied as e:
if chunk_offset == 0 and self._is_presigned_urls_disabled_error(e):
raise FallbackToUploadUsingFilesApi(buffer, "Presigned URLs are disabled")
else:
raise e from None
except InternalError as e:
if chunk_offset == 0 and self._is_presigned_urls_network_zone_error(e):
except Exception as e:
if chunk_offset == 0:
raise FallbackToUploadUsingFilesApi(
buffer, "Presigned URLs are not supported in the current network zone"
)
buffer, f"Failed to obtain upload URLs: {e}, falling back to single shot upload"
) from e
else:
raise e from None
raise e

upload_part_urls = upload_part_urls_response.get("upload_part_urls", [])
if len(upload_part_urls) == 0:
Expand Down Expand Up @@ -1847,7 +1882,14 @@ def perform():
# Let's fallback to using Files API which might be allowlisted to upload, passing
# currently buffered (but not yet uploaded) part of the stream.
raise FallbackToUploadUsingFilesApi(buffer, f"Direct upload forbidden: {upload_response.content}")

elif chunk_offset == 0:
# We got an upload failure when uploading the very first chunk.
# Let's fallback to using Files API which might be more reliable in this case,
# passing currently buffered (but not yet uploaded) part of the stream.
raise FallbackToUploadUsingFilesApi(
buffer,
f"Unsuccessful chunk upload: {upload_response.status_code}, falling back to single shot upload",
)
else:
message = f"Unsuccessful chunk upload. Response status: {upload_response.status_code}, body: {upload_response.content}"
_LOG.warning(message)
Expand Down Expand Up @@ -1985,18 +2027,10 @@ def _perform_resumable_upload(
resumable_upload_url_response = self._api.do(
"POST", "/api/2.0/fs/create-resumable-upload-url", headers=headers, body=body
)
except PermissionDenied as e:
if self._is_presigned_urls_disabled_error(e):
raise FallbackToUploadUsingFilesApi(pre_read_buffer, "Presigned URLs are disabled")
else:
raise e from None
except InternalError as e:
if self._is_presigned_urls_network_zone_error(e):
raise FallbackToUploadUsingFilesApi(
pre_read_buffer, "Presigned URLs are not supported in the current network zone"
)
else:
raise e from None
except Exception as e:
raise FallbackToUploadUsingFilesApi(
pre_read_buffer, f"Failed to obtain resumable upload URL: {e}, falling back to single shot upload"
) from e

resumable_upload_url_node = resumable_upload_url_response.get("resumable_upload_url")
if not resumable_upload_url_node:
Expand Down Expand Up @@ -2376,16 +2410,8 @@ def _create_download_url(self, file_path: str) -> CreateDownloadUrlResponse:
)

return CreateDownloadUrlResponse.from_dict(raw_response)
except PermissionDenied as e:
if self._is_presigned_urls_disabled_error(e):
raise FallbackToDownloadUsingFilesApi(f"Presigned URLs are disabled")
else:
raise e from None
except InternalError as e:
if self._is_presigned_urls_network_zone_error(e):
raise FallbackToDownloadUsingFilesApi("Presigned URLs are not supported in the current network zone")
else:
raise e from None
except Exception as e:
raise FallbackToDownloadUsingFilesApi(f"Failed to create download URL: {e}") from e

def _init_download_response_presigned_api(self, file_path: str, added_headers: dict[str, str]) -> DownloadResponse:
"""
Expand Down Expand Up @@ -2428,17 +2454,11 @@ def perform() -> requests.Response:
contents=_StreamingResponse(csp_response, self._config.files_ext_client_download_streaming_chunk_size),
)
return resp
elif csp_response.status_code == 403:
# We got 403 failure when downloading the file. This might happen due to Azure firewall enabled for the customer bucket.
# Let's fallback to using Files API which might be allowlisted to download.
raise FallbackToDownloadUsingFilesApi(f"Direct download forbidden: {csp_response.content}")
else:
message = (
f"Unsuccessful download. Response status: {csp_response.status_code}, body: {csp_response.content}"
)
_LOG.warning(message)
mapped_error = _error_mapper(csp_response, {})
raise mapped_error or ValueError(message)
raise FallbackToDownloadUsingFilesApi(message)

def _init_download_response_mode_csp_with_fallback(
self, file_path: str, headers: dict[str, str], response_headers: list[str]
Expand Down
Loading
Loading