Skip to content

Commit 3a443d9

Browse files
committed
Single-shot upload reports content length if known
1 parent e83da7c commit 3a443d9

File tree

2 files changed

+131
-2
lines changed

2 files changed

+131
-2
lines changed

databricks/sdk/service/files.py

Lines changed: 6 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/test_files.py

Lines changed: 125 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1279,8 +1279,6 @@ def run(self, config: Config):
12791279

12801280
session = requests.Session()
12811281
with requests_mock.Mocker(session=session) as session_mock:
1282-
session_mock.get(f"http://localhost/api/2.0/fs/files{MultipartUploadTestCase.path}", status_code=200)
1283-
12841282
upload_state = SingleShotUploadState()
12851283

12861284
def custom_matcher(request):
@@ -1855,3 +1853,128 @@ def to_string(test_case):
18551853
)
18561854
def test_resumable_upload(config: Config, test_case: ResumableUploadTestCase):
18571855
test_case.run(config)
1856+
1857+
1858+
class SingleShotUploadContentLengthTestCase:
1859+
1860+
def __init__(
1861+
self,
1862+
name: str,
1863+
contents: Callable[[], io.IOBase],
1864+
expected_content_length: Optional[int],
1865+
cleanup: Callable[[io.IOBase], None] = None,
1866+
):
1867+
super().__init__()
1868+
self.name = name
1869+
self.contents = contents
1870+
self.expected_content_length = expected_content_length
1871+
self.cleanup = cleanup
1872+
1873+
def __str__(self):
1874+
return self.name
1875+
1876+
@staticmethod
1877+
def to_string(test_case):
1878+
return str(test_case)
1879+
1880+
def run(self, config: Config):
1881+
config = config.copy()
1882+
config.enable_experimental_files_api_client = False # enforce single-shot upload
1883+
1884+
file_path = "/test.txt"
1885+
contents = self.contents()
1886+
1887+
try:
1888+
with requests_mock.Mocker() as session_mock:
1889+
1890+
def custom_matcher(request):
1891+
request_url = urlparse(request.url)
1892+
1893+
if (
1894+
request_url.hostname == "localhost"
1895+
and request_url.path == f"/api/2.0/fs/files{file_path}"
1896+
and request.method == "PUT"
1897+
):
1898+
body = request.body.read()
1899+
1900+
if self.expected_content_length:
1901+
content_length = request.headers["Content-Length"]
1902+
assert self.expected_content_length == int(content_length)
1903+
assert len(body) == int(content_length)
1904+
else:
1905+
assert request.headers.get("Content-Length") is None
1906+
1907+
resp = requests.Response()
1908+
resp.status_code = 204
1909+
resp.request = request
1910+
resp._content = b""
1911+
return resp
1912+
return None
1913+
1914+
session_mock.add_matcher(matcher=custom_matcher)
1915+
1916+
w = WorkspaceClient(config=config)
1917+
w.files.upload(file_path, contents)
1918+
finally:
1919+
if self.cleanup:
1920+
self.cleanup(contents)
1921+
1922+
1923+
def make_non_seekable(stream: io.IOBase, disable_seek: bool = False, disable_tell: bool = False):
1924+
def raise_(ex):
1925+
raise ex
1926+
1927+
stream.seekable = lambda: False # checked by BaseClient._is_seekable_stream()
1928+
1929+
# requests.super_len() does not check seekable(), it calls seek() and tell() directly
1930+
if disable_seek:
1931+
stream.seek = lambda offset, whence: raise_(OSError())
1932+
if disable_tell:
1933+
stream.tell = lambda: raise_(OSError())
1934+
return stream
1935+
1936+
1937+
def create_file_stream(length: int) -> io.IOBase:
1938+
fd, temp_file = mkstemp()
1939+
with open(fd, "wb") as f:
1940+
f.write(os.urandom(length))
1941+
1942+
stream = open(temp_file, "rb")
1943+
stream.delete_temp_file = lambda: os.remove(temp_file)
1944+
return stream
1945+
1946+
1947+
@pytest.mark.parametrize(
1948+
"test_case",
1949+
[
1950+
SingleShotUploadContentLengthTestCase(
1951+
"Empty contents treated as unknown length", contents=lambda: io.BytesIO(b""), expected_content_length=None
1952+
),
1953+
SingleShotUploadContentLengthTestCase("Bytes", contents=lambda: io.BytesIO(b"abc"), expected_content_length=3),
1954+
SingleShotUploadContentLengthTestCase(
1955+
"seek disabled: length unknown",
1956+
contents=lambda: make_non_seekable(io.BytesIO(b"abc"), disable_seek=True),
1957+
expected_content_length=None,
1958+
),
1959+
SingleShotUploadContentLengthTestCase(
1960+
"tell disabled: length unknown",
1961+
contents=lambda: make_non_seekable(io.BytesIO(b"abc"), disable_tell=True),
1962+
expected_content_length=None,
1963+
),
1964+
SingleShotUploadContentLengthTestCase(
1965+
"File stream: length reported",
1966+
contents=lambda: create_file_stream(566),
1967+
expected_content_length=566,
1968+
cleanup=lambda stream: stream.delete_temp_file(),
1969+
),
1970+
SingleShotUploadContentLengthTestCase(
1971+
"File stream with tell disabled: length unknown",
1972+
contents=lambda: make_non_seekable(create_file_stream(239), disable_tell=True),
1973+
expected_content_length=None,
1974+
cleanup=lambda stream: stream.delete_temp_file(),
1975+
),
1976+
],
1977+
ids=SingleShotUploadContentLengthTestCase.to_string,
1978+
)
1979+
def test_content_length(config: Config, test_case: SingleShotUploadContentLengthTestCase):
1980+
test_case.run(config)

0 commit comments

Comments
 (0)