Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 129 additions & 65 deletions tests/test_base_client.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import io
import random
from http.server import BaseHTTPRequestHandler
from typing import Iterator, List
from typing import Callable, Iterator, List, Optional
from unittest.mock import Mock

import pytest
from requests import PreparedRequest, Response, Timeout

from databricks.sdk import errors, useragent
from databricks.sdk._base_client import (_BaseClient, _RawResponse,
Expand Down Expand Up @@ -62,7 +63,7 @@ def test_streaming_response_read_closes(config):
@pytest.mark.parametrize('status_code,headers,body,expected_error', [
(400, {}, {
"message":
"errorMessage",
"errorMessage",
"details": [{
"type": DatabricksError._error_info_type,
"reason": "error reason",
Expand Down Expand Up @@ -103,9 +104,9 @@ def test_streaming_response_read_closes(config):
(429, {
'Retry-After': '100'
}, {
'error_code': 'TOO_MANY_REQUESTS',
'message': 'errorMessage',
}, errors.TooManyRequests('errorMessage', error_code='TOO_MANY_REQUESTS', retry_after_secs=100)),
'error_code': 'TOO_MANY_REQUESTS',
'message': 'errorMessage',
}, errors.TooManyRequests('errorMessage', error_code='TOO_MANY_REQUESTS', retry_after_secs=100)),
(503, {}, {
'error_code': 'TEMPORARILY_UNAVAILABLE',
'message': 'errorMessage',
Expand All @@ -114,9 +115,9 @@ def test_streaming_response_read_closes(config):
(503, {
'Retry-After': '100'
}, {
'error_code': 'TEMPORARILY_UNAVAILABLE',
'message': 'errorMessage',
},
'error_code': 'TEMPORARILY_UNAVAILABLE',
'message': 'errorMessage',
},
errors.TemporarilyUnavailable('errorMessage', error_code='TEMPORARILY_UNAVAILABLE',
retry_after_secs=100)),
(404, {}, {
Expand Down Expand Up @@ -357,91 +358,154 @@ def tell(self):
assert client._is_seekable_stream(CustomSeekableStream())


@pytest.mark.parametrize(
'input_data',
[
b"0123456789", # bytes -> BytesIO
"0123456789", # str -> BytesIO
io.BytesIO(b"0123456789"), # BytesIO directly
io.StringIO("0123456789"), # StringIO
])
def test_reset_seekable_stream_on_retry(input_data):
received_data = []
class RetryTestCase:

# Retry two times before succeeding.
def inner(h: BaseHTTPRequestHandler):
if len(received_data) == 2:
h.send_response(200)
h.end_headers()
else:
h.send_response(429)
h.end_headers()
def __init__(self, data_provider: Callable, offset: Optional[int], expected_exception: bool,
expected_result: bytes):
self._data_provider = data_provider
self._offset = offset
self._expected_result = expected_result
self._expected_exception = expected_exception

content_length = int(h.headers.get('Content-Length', 0))
if content_length > 0:
received_data.append(h.rfile.read(content_length))
def get_data(self):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test cases are reused, so we need to construct a fresh data object every time.

data = self._data_provider()
if self._offset is not None:
data.seek(self._offset)
return data

with http_fixture_server(inner) as host:
client = _BaseClient()
@classmethod
def create_non_seekable_stream(cls, data: bytes):
result = io.BytesIO(data)
result.seekable = lambda: False # makes the stream appear non-seekable
return result


retry_test_cases = [
# bytes -> BytesIO
RetryTestCase(lambda: b"0123456789", None, False, b"0123456789"),
# str -> BytesIO
RetryTestCase(lambda: "0123456789", None, False, b"0123456789"),
# BytesIO directly
RetryTestCase(lambda: io.BytesIO(b"0123456789"), None, False, b"0123456789"),
# BytesIO directly with offset
RetryTestCase(lambda: io.BytesIO(b"0123456789"), 4, False, b"456789"),
# StringIO
RetryTestCase(lambda: io.StringIO("0123456789"), None, False, b"0123456789"),
# Non-seekable
RetryTestCase(lambda: RetryTestCase.create_non_seekable_stream(b"0123456789"),
None, True, b"0123456789")
]

# Retries should reset the stream.
client.do('POST', f'{host}/foo', data=input_data)

assert received_data == [b"0123456789", b"0123456789", b"0123456789"]
@pytest.mark.parametrize('test_case', retry_test_cases)
def test_rewind_seekable_stream_on_retryable_error_response(test_case: RetryTestCase):
received_requests = []

data = test_case.get_data()

def test_reset_seekable_stream_to_their_initial_position_on_retry():
received_data = []
failure_count = 2

# Retry two times before succeeding.
def inner(h: BaseHTTPRequestHandler):
if len(received_data) == 2:
if len(received_requests) == failure_count:
h.send_response(200)
h.end_headers()
else:
h.send_response(429)
h.send_header('Retry-After', '1') # avoid a warning in log
h.end_headers()

content_length = int(h.headers.get('Content-Length', 0))
if content_length > 0:
received_data.append(h.rfile.read(content_length))

input_data = io.BytesIO(b"0123456789")
input_data.seek(4)
received_requests.append(h.rfile.read(content_length))

with http_fixture_server(inner) as host:
client = _BaseClient()

# Retries should reset the stream.
client.do('POST', f'{host}/foo', data=input_data)
def do():
# Retries should reset the stream.
client.do('POST', f'{host}/foo', data=data)

assert received_data == [b"456789", b"456789", b"456789"]
assert input_data.tell() == 10 # EOF
if test_case._expected_exception:
expected_attempts_made = 1
with pytest.raises(DatabricksError):
do()
else:
expected_attempts_made = failure_count + 1
do()

assert received_requests == [test_case._expected_result for _ in range(expected_attempts_made)]


class MockSession:
Copy link
Contributor Author

@ksafonov-db ksafonov-db Jan 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This mock session reads all the input stream before failing.


def __init__(self, failure_count: int):
self._failure_count = failure_count
self._received_requests: List[bytes] = []

# following the signature of Session.request()
def request(self,
method,
url,
params=None,
data=None,
headers=None,
cookies=None,
files=None,
auth=None,
timeout=None,
allow_redirects=True,
proxies=None,
hooks=None,
stream=None,
verify=None,
cert=None,
json=None):
request_body = data.read()

if isinstance(request_body, str):
request_body = request_body.encode('utf-8') # to be able to compare with expected bytes

self._received_requests.append(request_body)
if self._failure_count > 0:
self._failure_count -= 1
raise Timeout("Fake timeout") # retryable error
else:
# fill response fields so that logging does not fail
response = Response()
response._content = b''
response.status_code = 200
response.reason = 'OK'
response.url = url

response.request = PreparedRequest()
response.request.url = url
response.request.method = method
response.request.headers = headers
response.request.body = data
return response

def test_no_retry_or_reset_on_non_seekable_stream():
requests = []

# Always respond with a response that triggers a retry.
def inner(h: BaseHTTPRequestHandler):
content_length = int(h.headers.get('Content-Length', 0))
if content_length > 0:
requests.append(h.rfile.read(content_length))
@pytest.mark.parametrize('test_case', retry_test_cases)
def test_rewind_seekable_stream_on_retryable_exception(test_case: RetryTestCase):
failure_count = 2

h.send_response(429)
h.send_header('Retry-After', '1')
h.end_headers()
data = test_case.get_data()

input_data = io.BytesIO(b"0123456789")
input_data.seekable = lambda: False # makes the stream appear non-seekable
session = MockSession(failure_count)
client = _BaseClient()
client._session = session

with http_fixture_server(inner) as host:
client = _BaseClient()
def do():
# Retries should reset the stream.
client.do('POST', f'test.com/foo', data=data)

# Should raise error immediately without retry.
with pytest.raises(DatabricksError):
client.do('POST', f'{host}/foo', data=input_data)
if test_case._expected_exception:
expected_attempts_made = 1
with pytest.raises(Timeout):
do()
else:
expected_attempts_made = failure_count + 1
do()

# Verify that only one request was made (no retries).
assert requests == [b"0123456789"]
assert input_data.tell() == 10 # EOF
assert session._received_requests == [test_case._expected_result for _ in range(expected_attempts_made)]
Loading