1313from enum import Enum
1414from tempfile import NamedTemporaryFile
1515from threading import Lock
16- from typing import Any , Callable , Dict , List , Optional , Type , Union
16+ from typing import Any , Callable , List , Optional , Type , Union , Dict
1717from urllib .parse import parse_qs , urlparse
1818
1919import pytest
2424from databricks .sdk import WorkspaceClient
2525from databricks .sdk .core import Config
2626from databricks .sdk .environments import Cloud , DatabricksEnvironment
27- from databricks .sdk .errors .platform import (AlreadyExists , BadRequest ,
28- InternalError , NotImplemented ,
29- PermissionDenied , TooManyRequests )
27+ from databricks .sdk .errors .platform import (
28+ AlreadyExists ,
29+ BadRequest ,
30+ InternalError ,
31+ NotImplemented ,
32+ PermissionDenied ,
33+ TooManyRequests ,
34+ )
3035from databricks .sdk .mixins .files import FallbackToDownloadUsingFilesApi
3136from databricks .sdk .mixins .files_utils import CreateDownloadUrlResponse
3237from tests .clock import FakeClock
3338
34- from .test_files_utils import Utils
39+ from .test_files_utils import Utils , NonSeekableBuffer
3540
3641logger = logging .getLogger (__name__ )
3742
@@ -1257,8 +1262,8 @@ def __init__(
12571262 custom_response_on_single_shot_upload : CustomResponse ,
12581263 # exception which is expected to be thrown (so upload is expected to have failed)
12591264 expected_exception_type : Optional [Type [BaseException ]],
1260- # if abort is expected to be called for multipart/resumable upload
1261- expected_multipart_upload_aborted : bool ,
1265+ # Whether abort is expected to be called for multipart/resumable upload, set to None if we don't care.
1266+ expected_multipart_upload_aborted : Optional [ bool ] ,
12621267 expected_single_shot_upload : bool ,
12631268 ):
12641269 self .name = name
@@ -1274,7 +1279,7 @@ def __init__(
12741279 self .multipart_upload_max_retries = multipart_upload_max_retries
12751280 self .custom_response_on_single_shot_upload = custom_response_on_single_shot_upload
12761281 self .expected_exception_type = expected_exception_type
1277- self .expected_multipart_upload_aborted : bool = expected_multipart_upload_aborted
1282+ self .expected_multipart_upload_aborted : Optional [ bool ] = expected_multipart_upload_aborted
12781283 self .expected_single_shot_upload = expected_single_shot_upload
12791284
12801285 self .path = "/test.txt"
@@ -1294,16 +1299,18 @@ def clear_state(self) -> None:
12941299 logger .warning ("Failed to remove temp file: %s" , file_path )
12951300 self .created_temp_files = []
12961301
1297- def get_upload_file (self , content : bytes , source_type : "UploadSourceType" ) -> Union [str , io .BytesIO ]:
1302+ def get_upload_file (self , content : bytes , source_type : "UploadSourceType" ) -> Union [str , io .BytesIO , NonSeekableBuffer ]:
12981303 """Returns a file or stream to upload based on the source type."""
12991304 if source_type == UploadSourceType .FILE :
13001305 with NamedTemporaryFile (mode = "wb" , delete = False ) as f :
13011306 f .write (content )
13021307 file_path = f .name
13031308 self .created_temp_files .append (file_path )
13041309 return file_path
1305- elif source_type == UploadSourceType .STREAM :
1310+ elif source_type == UploadSourceType .SEEKABLE_STREAM :
13061311 return io .BytesIO (content )
1312+ elif source_type == UploadSourceType .NONSEEKABLE_STREAM :
1313+ return NonSeekableBuffer (content )
13071314 else :
13081315 raise ValueError (f"Unknown source type: { source_type } " )
13091316
@@ -1417,7 +1424,7 @@ def upload() -> None:
14171424 ), "Single-shot upload should not have succeeded"
14181425
14191426 assert (
1420- multipart_server_state .aborted == self .expected_multipart_upload_aborted
1427+ self . expected_multipart_upload_aborted is None or multipart_server_state .aborted == self .expected_multipart_upload_aborted
14211428 ), "Multipart upload aborted state mismatch"
14221429
14231430 finally :
@@ -1433,7 +1440,8 @@ class UploadSourceType(Enum):
14331440 """Source type for the upload. Used to determine how to upload the file."""
14341441
14351442 FILE = "file" # upload from a file on disk
1436- STREAM = "stream" # upload from a stream (e.g. BytesIO)
1443+ SEEKABLE_STREAM = "seekable_stream" # upload from a seekable stream (e.g. BytesIO)
1444+ NONSEEKABLE_STREAM = "nonseekable_stream" # upload from a non-seekable stream (e.g. network stream)
14371445
14381446
14391447class MultipartUploadTestCase (UploadTestCase ):
@@ -1522,15 +1530,15 @@ def __init__(
15221530 # if abort is expected to be called
15231531 # expected part size
15241532 expected_part_size : Optional [int ] = None ,
1525- expected_multipart_upload_aborted : bool = False ,
1533+ expected_multipart_upload_aborted : Optional [ bool ] = False ,
15261534 expected_single_shot_upload : bool = False ,
15271535 ):
15281536 super ().__init__ (
15291537 name ,
15301538 content_size ,
15311539 cloud ,
15321540 overwrite ,
1533- source_type or [UploadSourceType .FILE , UploadSourceType .STREAM ],
1541+ source_type or [UploadSourceType .FILE , UploadSourceType .SEEKABLE_STREAM , UploadSourceType . NONSEEKABLE_STREAM ],
15341542 use_parallel or [False , True ],
15351543 parallelism ,
15361544 multipart_upload_min_stream_size ,
@@ -1662,6 +1670,7 @@ def processor() -> list:
16621670 request_json = request .json ()
16631671 etags = {}
16641672
1673+ assert len (request_json ["parts" ]) > 0
16651674 for part in request_json ["parts" ]:
16661675 etags [part ["part_number" ]] = part ["etag" ]
16671676
@@ -1738,10 +1747,19 @@ def to_string(test_case: "MultipartUploadTestCase") -> str:
17381747 [
17391748 # -------------------------- happy cases --------------------------
17401749 MultipartUploadTestCase (
1741- "Multipart upload successful: single part" ,
1750+ "Multipart upload successful: single part because of small file " ,
17421751 content_size = 1024 * 1024 , # less than part size
1743- multipart_upload_part_size = 10 * 1024 * 1024 ,
1752+ multipart_upload_min_stream_size = 10 * 1024 * 1024 ,
1753+ source_type = [UploadSourceType .FILE , UploadSourceType .SEEKABLE_STREAM ], # non-seekable streams always use multipart upload
17441754 expected_part_size = 1024 * 1024 , # chunk size is used
1755+ expected_single_shot_upload = True ,
1756+ ),
1757+ MultipartUploadTestCase (
1758+ "Multipart upload successful: empty file or empty seekable stream" ,
1759+ content_size = 0 , # less than part size
1760+ multipart_upload_min_stream_size = 100 * 1024 * 1024 , # all files smaller than 100M goes to single-shot
1761+ expected_single_shot_upload = True ,
1762+ expected_multipart_upload_aborted = None
17451763 ),
17461764 MultipartUploadTestCase (
17471765 "Multipart upload successful: multiple parts (aligned)" ,
@@ -2161,7 +2179,7 @@ def to_string(test_case: "MultipartUploadTestCase") -> str:
21612179 "Multipart parallel upload for stream: Upload errors are not retried" ,
21622180 content_size = 10 * 1024 * 1024 ,
21632181 multipart_upload_part_size = 1024 * 1024 ,
2164- source_type = [UploadSourceType .STREAM ],
2182+ source_type = [UploadSourceType .SEEKABLE_STREAM ],
21652183 use_parallel = [True ],
21662184 parallelism = 1 ,
21672185 custom_response_on_upload = CustomResponse (
@@ -2325,7 +2343,7 @@ def __init__(
23252343 stream_size ,
23262344 cloud ,
23272345 overwrite ,
2328- source_type or [UploadSourceType .FILE , UploadSourceType .STREAM ],
2346+ source_type or [UploadSourceType .FILE , UploadSourceType .SEEKABLE_STREAM , UploadSourceType . NONSEEKABLE_STREAM ],
23292347 use_parallel
23302348 or [True , False ], # Resumable Upload doesn't support parallel uploading of parts, but fallback should work
23312349 parallelism ,
0 commit comments