Skip to content

Commit f508348

Browse files
committed
Add tests for retriable requests
1 parent 5339396 commit f508348

File tree

1 file changed

+129
-65
lines changed

1 file changed

+129
-65
lines changed

tests/test_base_client.py

Lines changed: 129 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import io
22
import random
33
from http.server import BaseHTTPRequestHandler
4-
from typing import Iterator, List
4+
from typing import Callable, Iterator, List, Optional
55
from unittest.mock import Mock
66

77
import pytest
8+
from requests import PreparedRequest, Response, Timeout
89

910
from databricks.sdk import errors, useragent
1011
from databricks.sdk._base_client import (_BaseClient, _RawResponse,
@@ -62,7 +63,7 @@ def test_streaming_response_read_closes(config):
6263
@pytest.mark.parametrize('status_code,headers,body,expected_error', [
6364
(400, {}, {
6465
"message":
65-
"errorMessage",
66+
"errorMessage",
6667
"details": [{
6768
"type": DatabricksError._error_info_type,
6869
"reason": "error reason",
@@ -103,9 +104,9 @@ def test_streaming_response_read_closes(config):
103104
(429, {
104105
'Retry-After': '100'
105106
}, {
106-
'error_code': 'TOO_MANY_REQUESTS',
107-
'message': 'errorMessage',
108-
}, errors.TooManyRequests('errorMessage', error_code='TOO_MANY_REQUESTS', retry_after_secs=100)),
107+
'error_code': 'TOO_MANY_REQUESTS',
108+
'message': 'errorMessage',
109+
}, errors.TooManyRequests('errorMessage', error_code='TOO_MANY_REQUESTS', retry_after_secs=100)),
109110
(503, {}, {
110111
'error_code': 'TEMPORARILY_UNAVAILABLE',
111112
'message': 'errorMessage',
@@ -114,9 +115,9 @@ def test_streaming_response_read_closes(config):
114115
(503, {
115116
'Retry-After': '100'
116117
}, {
117-
'error_code': 'TEMPORARILY_UNAVAILABLE',
118-
'message': 'errorMessage',
119-
},
118+
'error_code': 'TEMPORARILY_UNAVAILABLE',
119+
'message': 'errorMessage',
120+
},
120121
errors.TemporarilyUnavailable('errorMessage', error_code='TEMPORARILY_UNAVAILABLE',
121122
retry_after_secs=100)),
122123
(404, {}, {
@@ -357,91 +358,154 @@ def tell(self):
357358
assert client._is_seekable_stream(CustomSeekableStream())
358359

359360

360-
@pytest.mark.parametrize(
361-
'input_data',
362-
[
363-
b"0123456789", # bytes -> BytesIO
364-
"0123456789", # str -> BytesIO
365-
io.BytesIO(b"0123456789"), # BytesIO directly
366-
io.StringIO("0123456789"), # StringIO
367-
])
368-
def test_reset_seekable_stream_on_retry(input_data):
369-
received_data = []
361+
class RetryTestCase:
370362

371-
# Retry two times before succeeding.
372-
def inner(h: BaseHTTPRequestHandler):
373-
if len(received_data) == 2:
374-
h.send_response(200)
375-
h.end_headers()
376-
else:
377-
h.send_response(429)
378-
h.end_headers()
363+
def __init__(self, data_provider: Callable, offset: Optional[int], expected_exception: bool,
364+
expected_result: bytes):
365+
self._data_provider = data_provider
366+
self._offset = offset
367+
self._expected_result = expected_result
368+
self._expected_exception = expected_exception
379369

380-
content_length = int(h.headers.get('Content-Length', 0))
381-
if content_length > 0:
382-
received_data.append(h.rfile.read(content_length))
370+
def get_data(self):
371+
data = self._data_provider()
372+
if self._offset is not None:
373+
data.seek(self._offset)
374+
return data
383375

384-
with http_fixture_server(inner) as host:
385-
client = _BaseClient()
376+
@classmethod
377+
def create_non_seekable_stream(cls, data: bytes):
378+
result = io.BytesIO(data)
379+
result.seekable = lambda: False # makes the stream appear non-seekable
380+
return result
381+
382+
383+
retry_test_cases = [
384+
# bytes -> BytesIO
385+
RetryTestCase(lambda: b"0123456789", None, False, b"0123456789"),
386+
# str -> BytesIO
387+
RetryTestCase(lambda: "0123456789", None, False, b"0123456789"),
388+
# BytesIO directly
389+
RetryTestCase(lambda: io.BytesIO(b"0123456789"), None, False, b"0123456789"),
390+
# BytesIO directly with offset
391+
RetryTestCase(lambda: io.BytesIO(b"0123456789"), 4, False, b"456789"),
392+
# StringIO
393+
RetryTestCase(lambda: io.StringIO("0123456789"), None, False, b"0123456789"),
394+
# Non-seekable
395+
RetryTestCase(lambda: RetryTestCase.create_non_seekable_stream(b"0123456789"),
396+
None, True, b"0123456789")
397+
]
386398

387-
# Retries should reset the stream.
388-
client.do('POST', f'{host}/foo', data=input_data)
389399

390-
assert received_data == [b"0123456789", b"0123456789", b"0123456789"]
400+
@pytest.mark.parametrize('test_case', retry_test_cases)
401+
def test_rewind_seekable_stream_on_retryable_error_response(test_case: RetryTestCase):
402+
received_requests = []
391403

404+
data = test_case.get_data()
392405

393-
def test_reset_seekable_stream_to_their_initial_position_on_retry():
394-
received_data = []
406+
failure_count = 2
395407

396408
# Retry two times before succeeding.
397409
def inner(h: BaseHTTPRequestHandler):
398-
if len(received_data) == 2:
410+
if len(received_requests) == failure_count:
399411
h.send_response(200)
400412
h.end_headers()
401413
else:
402414
h.send_response(429)
415+
h.send_header('Retry-After', '1') # avoid a warning in log
403416
h.end_headers()
404417

405418
content_length = int(h.headers.get('Content-Length', 0))
406419
if content_length > 0:
407-
received_data.append(h.rfile.read(content_length))
408-
409-
input_data = io.BytesIO(b"0123456789")
410-
input_data.seek(4)
420+
received_requests.append(h.rfile.read(content_length))
411421

412422
with http_fixture_server(inner) as host:
413423
client = _BaseClient()
414424

415-
# Retries should reset the stream.
416-
client.do('POST', f'{host}/foo', data=input_data)
425+
def do():
426+
# Retries should reset the stream.
427+
client.do('POST', f'{host}/foo', data=data)
417428

418-
assert received_data == [b"456789", b"456789", b"456789"]
419-
assert input_data.tell() == 10 # EOF
429+
if test_case._expected_exception:
430+
expected_attempts_made = 1
431+
with pytest.raises(DatabricksError):
432+
do()
433+
else:
434+
expected_attempts_made = failure_count + 1
435+
do()
436+
437+
assert received_requests == [test_case._expected_result for _ in range(expected_attempts_made)]
438+
439+
440+
class MockSession:
441+
442+
def __init__(self, failure_count: int):
443+
self._failure_count = failure_count
444+
self._received_requests: List[bytes] = []
445+
446+
# following the signature of Session.request()
447+
def request(self,
448+
method,
449+
url,
450+
params=None,
451+
data=None,
452+
headers=None,
453+
cookies=None,
454+
files=None,
455+
auth=None,
456+
timeout=None,
457+
allow_redirects=True,
458+
proxies=None,
459+
hooks=None,
460+
stream=None,
461+
verify=None,
462+
cert=None,
463+
json=None):
464+
request_body = data.read()
465+
466+
if isinstance(request_body, str):
467+
request_body = request_body.encode('utf-8') # to be able to compare with expected bytes
468+
469+
self._received_requests.append(request_body)
470+
if self._failure_count > 0:
471+
self._failure_count -= 1
472+
raise Timeout("Fake timeout") # retryable error
473+
else:
474+
# fill response fields so that logging does not fail
475+
response = Response()
476+
response._content = b''
477+
response.status_code = 200
478+
response.reason = 'OK'
479+
response.url = url
420480

481+
response.request = PreparedRequest()
482+
response.request.url = url
483+
response.request.method = method
484+
response.request.headers = headers
485+
response.request.body = data
486+
return response
421487

422-
def test_no_retry_or_reset_on_non_seekable_stream():
423-
requests = []
424488

425-
# Always respond with a response that triggers a retry.
426-
def inner(h: BaseHTTPRequestHandler):
427-
content_length = int(h.headers.get('Content-Length', 0))
428-
if content_length > 0:
429-
requests.append(h.rfile.read(content_length))
489+
@pytest.mark.parametrize('test_case', retry_test_cases)
490+
def test_rewind_seekable_stream_on_retryable_exception(test_case: RetryTestCase):
491+
failure_count = 2
430492

431-
h.send_response(429)
432-
h.send_header('Retry-After', '1')
433-
h.end_headers()
493+
data = test_case.get_data()
434494

435-
input_data = io.BytesIO(b"0123456789")
436-
input_data.seekable = lambda: False # makes the stream appear non-seekable
495+
session = MockSession(failure_count)
496+
client = _BaseClient()
497+
client._session = session
437498

438-
with http_fixture_server(inner) as host:
439-
client = _BaseClient()
499+
def do():
500+
# Retries should reset the stream.
501+
client.do('POST', f'test.com/foo', data=data)
440502

441-
# Should raise error immediately without retry.
442-
with pytest.raises(DatabricksError):
443-
client.do('POST', f'{host}/foo', data=input_data)
503+
if test_case._expected_exception:
504+
expected_attempts_made = 1
505+
with pytest.raises(Timeout):
506+
do()
507+
else:
508+
expected_attempts_made = failure_count + 1
509+
do()
444510

445-
# Verify that only one request was made (no retries).
446-
assert requests == [b"0123456789"]
447-
assert input_data.tell() == 10 # EOF
511+
assert session._received_requests == [test_case._expected_result for _ in range(expected_attempts_made)]

0 commit comments

Comments
 (0)