Merged
2 changes: 2 additions & 0 deletions NEXT_CHANGELOG.md
@@ -7,6 +7,8 @@
### Bug Fixes
- Fix `FilesExt` failing to upload and download data when presigned URLs are not available in certain environments (e.g. Serverless GPU clusters).

- Fix `FilesExt.upload` and `FilesExt.upload_from` failing when the source content is empty and `use_parallel=True`.

### Documentation

### Internal Changes
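For context, a minimal sketch of the call that used to fail (the volume path is hypothetical, and this assumes the `use_parallel` flag handled in this PR is exposed on `FilesExt.upload`):

import io

from databricks.sdk import WorkspaceClient

w = WorkspaceClient()

# An empty source combined with use_parallel=True previously raised inside
# the parallel multipart path; with this change it falls back to a
# single-shot upload via the Files API.
w.files.upload(
    "/Volumes/main/default/my_volume/test.txt",
    io.BytesIO(b""),
    overwrite=True,
    use_parallel=True,
)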
11 changes: 7 additions & 4 deletions databricks/sdk/mixins/files.py
@@ -1134,7 +1134,9 @@ def upload(
f"Upload context: part_size={ctx.part_size}, batch_size={ctx.batch_size}, content_length={ctx.content_length}"
)

if ctx.use_parallel:
if ctx.use_parallel and (
ctx.content_length is None or ctx.content_length >= self._config.files_ext_multipart_upload_min_stream_size
):
self._parallel_upload_from_stream(ctx, contents)
return UploadStreamResult()
elif ctx.content_length is not None:
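In effect, `upload` now takes the parallel path only when the stream length is unknown or at least the configured minimum; smaller known payloads fall through to the single-shot branches. A standalone distillation of the new guard (hypothetical helper; `min_stream_size` stands in for `self._config.files_ext_multipart_upload_min_stream_size`):

from typing import Optional

def should_upload_in_parallel(use_parallel: bool, content_length: Optional[int], min_stream_size: int) -> bool:
    # Unknown length: the stream may be large, so take the parallel path.
    # Known length: go parallel only at or above the configured minimum.
    return use_parallel and (content_length is None or content_length >= min_stream_size)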
@@ -1206,7 +1208,7 @@ def upload_from(
use_parallel=use_parallel,
parallelism=parallelism,
)
if ctx.use_parallel:
if ctx.use_parallel and ctx.content_length >= self._config.files_ext_multipart_upload_min_stream_size:
Contributor:

Can content_length be None here?

Collaborator (Author):

No, in the upload_from function, content_length will always be set using os.path.getsize(source_path), which only returns a number.

Contributor:

Maybe, but not in this PR. Let's avoid using the same uploadContext in both methods, as they may differ for different functions.
FWIW, I am planning to add a static type checker in the Python SDK, which will block such scenarios.
self._parallel_upload_from_file(ctx)
return UploadFileResult()
else:
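A quick standard-library check of the author's claim above (no SDK code involved):

import os

# os.path.getsize returns an int, or raises OSError if the path does not
# exist, so content_length is never None on the upload_from code path.
size: int = os.path.getsize(__file__)
assert isinstance(size, int)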
@@ -1459,8 +1461,9 @@ def _parallel_multipart_upload_from_stream(
# Do the first part read ahead
pre_read_buffer = content.read(ctx.part_size)
if not pre_read_buffer:
self._complete_multipart_upload(ctx, {}, session_token)
return
raise FallbackToUploadUsingFilesApi(
b"", "Falling back to single-shot upload with Files API due to empty input stream"
)
try:
etag = self._do_upload_one_part(
ctx, cloud_provider_session, 1, 0, len(pre_read_buffer), session_token, BytesIO(pre_read_buffer)
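The hunk above swaps a zero-part "complete" call for an explicit fallback. A sketch of the resulting control flow, assuming `FallbackToUploadUsingFilesApi(content, reason)` is the SDK-internal control-flow exception the diff suggests (the stand-in class below is hypothetical):

from io import BytesIO
from typing import BinaryIO

class FallbackToUploadUsingFilesApi(Exception):
    # Hypothetical stand-in; the real exception lives inside the SDK.
    def __init__(self, pre_read_content: bytes, reason: str):
        super().__init__(reason)
        self.pre_read_content = pre_read_content

def read_ahead_first_part(content: BinaryIO, part_size: int) -> BytesIO:
    # Read one part ahead; an empty first read means the stream is empty.
    # Rather than completing a multipart upload with zero parts, the code
    # now falls back to a single-shot Files API upload.
    pre_read_buffer = content.read(part_size)
    if not pre_read_buffer:
        raise FallbackToUploadUsingFilesApi(b"", "empty input stream")
    return BytesIO(pre_read_buffer)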
49 changes: 35 additions & 14 deletions tests/test_files.py
@@ -31,7 +31,7 @@
from databricks.sdk.mixins.files_utils import CreateDownloadUrlResponse
from tests.clock import FakeClock

from .test_files_utils import Utils
from .test_files_utils import NonSeekableBuffer, Utils

logger = logging.getLogger(__name__)

@@ -1286,8 +1286,8 @@ def __init__(
custom_response_on_single_shot_upload: CustomResponse,
# exception which is expected to be thrown (so upload is expected to have failed)
expected_exception_type: Optional[Type[BaseException]],
# if abort is expected to be called for multipart/resumable upload
expected_multipart_upload_aborted: bool,
# Whether abort is expected to be called for multipart/resumable upload; set to None if we don't care.
expected_multipart_upload_aborted: Optional[bool],
expected_single_shot_upload: bool,
):
self.name = name
@@ -1303,7 +1303,7 @@ def __init__(
self.multipart_upload_max_retries = multipart_upload_max_retries
self.custom_response_on_single_shot_upload = custom_response_on_single_shot_upload
self.expected_exception_type = expected_exception_type
self.expected_multipart_upload_aborted: bool = expected_multipart_upload_aborted
self.expected_multipart_upload_aborted: Optional[bool] = expected_multipart_upload_aborted
self.expected_single_shot_upload = expected_single_shot_upload

self.path = "/test.txt"
@@ -1323,16 +1323,20 @@ def clear_state(self) -> None:
logger.warning("Failed to remove temp file: %s", file_path)
self.created_temp_files = []

def get_upload_file(self, content: bytes, source_type: "UploadSourceType") -> Union[str, io.BytesIO]:
def get_upload_file(
self, content: bytes, source_type: "UploadSourceType"
) -> Union[str, io.BytesIO, NonSeekableBuffer]:
"""Returns a file or stream to upload based on the source type."""
if source_type == UploadSourceType.FILE:
with NamedTemporaryFile(mode="wb", delete=False) as f:
f.write(content)
file_path = f.name
self.created_temp_files.append(file_path)
return file_path
elif source_type == UploadSourceType.STREAM:
elif source_type == UploadSourceType.SEEKABLE_STREAM:
return io.BytesIO(content)
elif source_type == UploadSourceType.NONSEEKABLE_STREAM:
return NonSeekableBuffer(content)
else:
raise ValueError(f"Unknown source type: {source_type}")

@@ -1446,7 +1450,8 @@ def upload() -> None:
), "Single-shot upload should not have succeeded"

assert (
multipart_server_state.aborted == self.expected_multipart_upload_aborted
self.expected_multipart_upload_aborted is None
or multipart_server_state.aborted == self.expected_multipart_upload_aborted
), "Multipart upload aborted state mismatch"

finally:
@@ -1462,7 +1467,8 @@ class UploadSourceType(Enum):
"""Source type for the upload. Used to determine how to upload the file."""

FILE = "file" # upload from a file on disk
STREAM = "stream" # upload from a stream (e.g. BytesIO)
SEEKABLE_STREAM = "seekable_stream" # upload from a seekable stream (e.g. BytesIO)
NONSEEKABLE_STREAM = "nonseekable_stream" # upload from a non-seekable stream (e.g. network stream)
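The new `NONSEEKABLE_STREAM` source type relies on the `NonSeekableBuffer` helper imported from `test_files_utils`; its implementation is not part of this diff. A hypothetical minimal equivalent, wrapping bytes while refusing to seek the way a network stream would:

import io

class NonSeekableBuffer(io.RawIOBase):
    # Hypothetical stand-in for the helper in tests/test_files_utils.py.
    def __init__(self, data: bytes):
        self._inner = io.BytesIO(data)

    def readable(self) -> bool:
        return True

    def seekable(self) -> bool:
        return False  # forces the SDK to treat this like a network stream

    def read(self, size: int = -1) -> bytes:
        return self._inner.read(size)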


class MultipartUploadTestCase(UploadTestCase):
@@ -1570,15 +1576,16 @@ def __init__(
# if abort is expected to be called
# expected part size
expected_part_size: Optional[int] = None,
expected_multipart_upload_aborted: bool = False,
expected_multipart_upload_aborted: Optional[bool] = False,
expected_single_shot_upload: bool = False,
):
super().__init__(
name,
content_size,
cloud,
overwrite,
source_type or [UploadSourceType.FILE, UploadSourceType.STREAM],
source_type
or [UploadSourceType.FILE, UploadSourceType.SEEKABLE_STREAM, UploadSourceType.NONSEEKABLE_STREAM],
use_parallel or [False, True],
parallelism,
multipart_upload_min_stream_size,
@@ -1710,6 +1717,7 @@ def processor() -> list:
request_json = request.json()
etags = {}

assert len(request_json["parts"]) > 0
for part in request_json["parts"]:
etags[part["part_number"]] = part["etag"]

@@ -1786,10 +1794,22 @@ def to_string(test_case: "MultipartUploadTestCase") -> str:
[
# -------------------------- happy cases --------------------------
MultipartUploadTestCase(
"Multipart upload successful: single part",
"Multipart upload successful: single part because of small file",
content_size=1024 * 1024, # less than part size
multipart_upload_part_size=10 * 1024 * 1024,
multipart_upload_min_stream_size=10 * 1024 * 1024,
source_type=[
UploadSourceType.FILE,
UploadSourceType.SEEKABLE_STREAM,
], # non-seekable streams always use multipart upload
expected_part_size=1024 * 1024, # chunk size is used
expected_single_shot_upload=True,
),
MultipartUploadTestCase(
"Multipart upload successful: empty file or empty seekable stream",
content_size=0, # content with zero length
multipart_upload_min_stream_size=100 * 1024 * 1024, # all files smaller than 100M go to single-shot
expected_single_shot_upload=True,
expected_multipart_upload_aborted=None,
),
MultipartUploadTestCase(
"Multipart upload successful: multiple parts (aligned)",
@@ -2221,7 +2241,7 @@ def to_string(test_case: "MultipartUploadTestCase") -> str:
"Multipart parallel upload for stream: Upload errors are not retried",
content_size=10 * 1024 * 1024,
multipart_upload_part_size=1024 * 1024,
source_type=[UploadSourceType.STREAM],
source_type=[UploadSourceType.SEEKABLE_STREAM],
use_parallel=[True],
parallelism=1,
custom_response_on_upload=CustomResponse(
@@ -2385,7 +2405,8 @@ def __init__(
stream_size,
cloud,
overwrite,
source_type or [UploadSourceType.FILE, UploadSourceType.STREAM],
source_type
or [UploadSourceType.FILE, UploadSourceType.SEEKABLE_STREAM, UploadSourceType.NONSEEKABLE_STREAM],
use_parallel
or [True, False], # Resumable Upload doesn't support parallel uploading of parts, but fallback should work
parallelism,