Skip to content

Commit bf741e8

Browse files
Fix FilesExt upload failing when the content size is zero
1 parent 49eb17b commit bf741e8

File tree

3 files changed

+49
-26
lines changed

3 files changed

+49
-26
lines changed

NEXT_CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
### Bug Fixes
88

9+
- Fix `FilesExt.upload` and `FilesExt.upload_from` failing when the source content is empty and `use_parallel=True`.
10+
911
### Documentation
1012

1113
### Internal Changes

databricks/sdk/mixins/files.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,7 @@
2222
from tempfile import mkstemp
2323
from threading import Event, Thread
2424
from types import TracebackType
25-
from typing import (TYPE_CHECKING, AnyStr, BinaryIO, Callable, Generator,
26-
Iterable, Optional, Type, Union)
25+
from typing import TYPE_CHECKING, AnyStr, BinaryIO, Callable, Generator, Iterable, Optional, Type, Union
2726
from urllib import parse
2827

2928
import requests
@@ -39,8 +38,7 @@
3938
from ..service import files
4039
from ..service._internal import _escape_multi_segment_path_parameter
4140
from ..service.files import DownloadResponse
42-
from .files_utils import (CreateDownloadUrlResponse, _ConcatenatedInputStream,
43-
_PresignedUrlDistributor)
41+
from .files_utils import CreateDownloadUrlResponse, _ConcatenatedInputStream, _PresignedUrlDistributor
4442

4543
if TYPE_CHECKING:
4644
from _typeshed import Self
@@ -1134,7 +1132,10 @@ def upload(
11341132
f"Upload context: part_size={ctx.part_size}, batch_size={ctx.batch_size}, content_length={ctx.content_length}"
11351133
)
11361134

1137-
if ctx.use_parallel:
1135+
if ctx.use_parallel and (
1136+
ctx.content_length is None or ctx.content_length >= self._config.files_ext_multipart_upload_min_stream_size
1137+
):
1138+
# if ctx.use_parallel:
11381139
self._parallel_upload_from_stream(ctx, content)
11391140
return UploadStreamResult()
11401141
elif ctx.content_length is not None:
@@ -1206,7 +1207,8 @@ def upload_from(
12061207
use_parallel=use_parallel,
12071208
parallelism=parallelism,
12081209
)
1209-
if ctx.use_parallel:
1210+
if ctx.use_parallel and ctx.content_length >= self._config.files_ext_multipart_upload_min_stream_size:
1211+
# if ctx.use_parallel:
12101212
self._parallel_upload_from_file(ctx)
12111213
return UploadFileResult()
12121214
else:
@@ -1459,8 +1461,9 @@ def _parallel_multipart_upload_from_stream(
14591461
# Do the first part read ahead
14601462
pre_read_buffer = content.read(ctx.part_size)
14611463
if not pre_read_buffer:
1462-
self._complete_multipart_upload(ctx, {}, session_token)
1463-
return
1464+
raise FallbackToUploadUsingFilesApi(
1465+
b"", "Falling back to single-shot upload with Files API due to empty input stream"
1466+
)
14641467
try:
14651468
etag = self._do_upload_one_part(
14661469
ctx, cloud_provider_session, 1, 0, len(pre_read_buffer), session_token, BytesIO(pre_read_buffer)

tests/test_files.py

Lines changed: 36 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from enum import Enum
1414
from tempfile import NamedTemporaryFile
1515
from threading import Lock
16-
from typing import Any, Callable, Dict, List, Optional, Type, Union
16+
from typing import Any, Callable, List, Optional, Type, Union, Dict
1717
from urllib.parse import parse_qs, urlparse
1818

1919
import pytest
@@ -24,14 +24,19 @@
2424
from databricks.sdk import WorkspaceClient
2525
from databricks.sdk.core import Config
2626
from databricks.sdk.environments import Cloud, DatabricksEnvironment
27-
from databricks.sdk.errors.platform import (AlreadyExists, BadRequest,
28-
InternalError, NotImplemented,
29-
PermissionDenied, TooManyRequests)
27+
from databricks.sdk.errors.platform import (
28+
AlreadyExists,
29+
BadRequest,
30+
InternalError,
31+
NotImplemented,
32+
PermissionDenied,
33+
TooManyRequests,
34+
)
3035
from databricks.sdk.mixins.files import FallbackToDownloadUsingFilesApi
3136
from databricks.sdk.mixins.files_utils import CreateDownloadUrlResponse
3237
from tests.clock import FakeClock
3338

34-
from .test_files_utils import Utils
39+
from .test_files_utils import Utils, NonSeekableBuffer
3540

3641
logger = logging.getLogger(__name__)
3742

@@ -1257,8 +1262,8 @@ def __init__(
12571262
custom_response_on_single_shot_upload: CustomResponse,
12581263
# exception which is expected to be thrown (so upload is expected to have failed)
12591264
expected_exception_type: Optional[Type[BaseException]],
1260-
# if abort is expected to be called for multipart/resumable upload
1261-
expected_multipart_upload_aborted: bool,
1265+
# Whether abort is expected to be called for multipart/resumable upload, set to None if we don't care.
1266+
expected_multipart_upload_aborted: Optional[bool],
12621267
expected_single_shot_upload: bool,
12631268
):
12641269
self.name = name
@@ -1274,7 +1279,7 @@ def __init__(
12741279
self.multipart_upload_max_retries = multipart_upload_max_retries
12751280
self.custom_response_on_single_shot_upload = custom_response_on_single_shot_upload
12761281
self.expected_exception_type = expected_exception_type
1277-
self.expected_multipart_upload_aborted: bool = expected_multipart_upload_aborted
1282+
self.expected_multipart_upload_aborted: Optional[bool] = expected_multipart_upload_aborted
12781283
self.expected_single_shot_upload = expected_single_shot_upload
12791284

12801285
self.path = "/test.txt"
@@ -1294,16 +1299,18 @@ def clear_state(self) -> None:
12941299
logger.warning("Failed to remove temp file: %s", file_path)
12951300
self.created_temp_files = []
12961301

1297-
def get_upload_file(self, content: bytes, source_type: "UploadSourceType") -> Union[str, io.BytesIO]:
1302+
def get_upload_file(self, content: bytes, source_type: "UploadSourceType") -> Union[str, io.BytesIO, NonSeekableBuffer]:
12981303
"""Returns a file or stream to upload based on the source type."""
12991304
if source_type == UploadSourceType.FILE:
13001305
with NamedTemporaryFile(mode="wb", delete=False) as f:
13011306
f.write(content)
13021307
file_path = f.name
13031308
self.created_temp_files.append(file_path)
13041309
return file_path
1305-
elif source_type == UploadSourceType.STREAM:
1310+
elif source_type == UploadSourceType.SEEKABLE_STREAM:
13061311
return io.BytesIO(content)
1312+
elif source_type == UploadSourceType.NONSEEKABLE_STREAM:
1313+
return NonSeekableBuffer(content)
13071314
else:
13081315
raise ValueError(f"Unknown source type: {source_type}")
13091316

@@ -1417,7 +1424,7 @@ def upload() -> None:
14171424
), "Single-shot upload should not have succeeded"
14181425

14191426
assert (
1420-
multipart_server_state.aborted == self.expected_multipart_upload_aborted
1427+
self.expected_multipart_upload_aborted is None or multipart_server_state.aborted == self.expected_multipart_upload_aborted
14211428
), "Multipart upload aborted state mismatch"
14221429

14231430
finally:
@@ -1433,7 +1440,8 @@ class UploadSourceType(Enum):
14331440
"""Source type for the upload. Used to determine how to upload the file."""
14341441

14351442
FILE = "file" # upload from a file on disk
1436-
STREAM = "stream" # upload from a stream (e.g. BytesIO)
1443+
SEEKABLE_STREAM = "seekable_stream" # upload from a seekable stream (e.g. BytesIO)
1444+
NONSEEKABLE_STREAM = "nonseekable_stream" # upload from a non-seekable stream (e.g. network stream)
14371445

14381446

14391447
class MultipartUploadTestCase(UploadTestCase):
@@ -1522,15 +1530,15 @@ def __init__(
15221530
# if abort is expected to be called
15231531
# expected part size
15241532
expected_part_size: Optional[int] = None,
1525-
expected_multipart_upload_aborted: bool = False,
1533+
expected_multipart_upload_aborted: Optional[bool] = False,
15261534
expected_single_shot_upload: bool = False,
15271535
):
15281536
super().__init__(
15291537
name,
15301538
content_size,
15311539
cloud,
15321540
overwrite,
1533-
source_type or [UploadSourceType.FILE, UploadSourceType.STREAM],
1541+
source_type or [UploadSourceType.FILE, UploadSourceType.SEEKABLE_STREAM, UploadSourceType.NONSEEKABLE_STREAM],
15341542
use_parallel or [False, True],
15351543
parallelism,
15361544
multipart_upload_min_stream_size,
@@ -1662,6 +1670,7 @@ def processor() -> list:
16621670
request_json = request.json()
16631671
etags = {}
16641672

1673+
assert len(request_json["parts"]) > 0
16651674
for part in request_json["parts"]:
16661675
etags[part["part_number"]] = part["etag"]
16671676

@@ -1738,10 +1747,19 @@ def to_string(test_case: "MultipartUploadTestCase") -> str:
17381747
[
17391748
# -------------------------- happy cases --------------------------
17401749
MultipartUploadTestCase(
1741-
"Multipart upload successful: single part",
1750+
"Multipart upload successful: single part because of small file",
17421751
content_size=1024 * 1024, # less than part size
1743-
multipart_upload_part_size=10 * 1024 * 1024,
1752+
multipart_upload_min_stream_size=10 * 1024 * 1024,
1753+
source_type=[UploadSourceType.FILE, UploadSourceType.SEEKABLE_STREAM], # non-seekable streams always use multipart upload
17441754
expected_part_size=1024 * 1024, # chunk size is used
1755+
expected_single_shot_upload=True,
1756+
),
1757+
MultipartUploadTestCase(
1758+
"Multipart upload successful: empty file or empty seekable stream",
1759+
content_size=0, # less than part size
1760+
multipart_upload_min_stream_size=100 * 1024 * 1024, # all files smaller than 100M goes to single-shot
1761+
expected_single_shot_upload=True,
1762+
expected_multipart_upload_aborted=None
17451763
),
17461764
MultipartUploadTestCase(
17471765
"Multipart upload successful: multiple parts (aligned)",
@@ -2161,7 +2179,7 @@ def to_string(test_case: "MultipartUploadTestCase") -> str:
21612179
"Multipart parallel upload for stream: Upload errors are not retried",
21622180
content_size=10 * 1024 * 1024,
21632181
multipart_upload_part_size=1024 * 1024,
2164-
source_type=[UploadSourceType.STREAM],
2182+
source_type=[UploadSourceType.SEEKABLE_STREAM],
21652183
use_parallel=[True],
21662184
parallelism=1,
21672185
custom_response_on_upload=CustomResponse(
@@ -2325,7 +2343,7 @@ def __init__(
23252343
stream_size,
23262344
cloud,
23272345
overwrite,
2328-
source_type or [UploadSourceType.FILE, UploadSourceType.STREAM],
2346+
source_type or [UploadSourceType.FILE, UploadSourceType.SEEKABLE_STREAM, UploadSourceType.NONSEEKABLE_STREAM],
23292347
use_parallel
23302348
or [True, False], # Resumable Upload doesn't support parallel uploading of parts, but fallback should work
23312349
parallelism,

0 commit comments

Comments
 (0)