Skip to content

Commit c17a5ae

Browse files
committed
Unify/simplify test
1 parent 978ba95 commit c17a5ae

File tree

1 file changed

+31
-69
lines changed

1 file changed

+31
-69
lines changed

tests/test_base_client.py

Lines changed: 31 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import io
22
import random
33
from http.server import BaseHTTPRequestHandler
4-
from typing import Callable, Iterator, List, Optional
4+
from typing import Callable, Iterator, List, Optional, Tuple, Type
55
from unittest.mock import Mock
66

77
import pytest
@@ -360,12 +360,12 @@ def tell(self):
360360

361361
class RetryTestCase:
362362

363-
def __init__(self, data_provider: Callable, offset: Optional[int], expected_exception: bool,
363+
def __init__(self, data_provider: Callable, offset: Optional[int], expected_failure: bool,
364364
expected_result: bytes):
365365
self._data_provider = data_provider
366366
self._offset = offset
367367
self._expected_result = expected_result
368-
self._expected_exception = expected_exception
368+
self._expected_failure = expected_failure
369369

370370
def get_data(self):
371371
data = self._data_provider()
@@ -380,62 +380,6 @@ def create_non_seekable_stream(cls, data: bytes):
380380
return result
381381

382382

383-
retry_test_cases = [
384-
# bytes -> BytesIO
385-
RetryTestCase(lambda: b"0123456789", None, False, b"0123456789"),
386-
# str -> BytesIO
387-
RetryTestCase(lambda: "0123456789", None, False, b"0123456789"),
388-
# BytesIO directly
389-
RetryTestCase(lambda: io.BytesIO(b"0123456789"), None, False, b"0123456789"),
390-
# BytesIO directly with offset
391-
RetryTestCase(lambda: io.BytesIO(b"0123456789"), 4, False, b"456789"),
392-
# StringIO
393-
RetryTestCase(lambda: io.StringIO("0123456789"), None, False, b"0123456789"),
394-
# Non-seekable
395-
RetryTestCase(lambda: RetryTestCase.create_non_seekable_stream(b"0123456789"), None, True, b"0123456789")
396-
]
397-
398-
399-
@pytest.mark.parametrize('test_case', retry_test_cases)
400-
def test_rewind_seekable_stream_on_retryable_error_response(test_case: RetryTestCase):
401-
received_requests = []
402-
403-
data = test_case.get_data()
404-
405-
failure_count = 2
406-
407-
# Retry two times before succeeding.
408-
def inner(h: BaseHTTPRequestHandler):
409-
if len(received_requests) == failure_count:
410-
h.send_response(200)
411-
h.end_headers()
412-
else:
413-
h.send_response(429)
414-
h.send_header('Retry-After', '1') # avoid a warning in log
415-
h.end_headers()
416-
417-
content_length = int(h.headers.get('Content-Length', 0))
418-
if content_length > 0:
419-
received_requests.append(h.rfile.read(content_length))
420-
421-
with http_fixture_server(inner) as host:
422-
client = _BaseClient()
423-
424-
def do():
425-
# Retries should reset the stream.
426-
client.do('POST', f'{host}/foo', data=data)
427-
428-
if test_case._expected_exception:
429-
expected_attempts_made = 1
430-
with pytest.raises(DatabricksError):
431-
do()
432-
else:
433-
expected_attempts_made = failure_count + 1
434-
do()
435-
436-
assert received_requests == [test_case._expected_result for _ in range(expected_attempts_made)]
437-
438-
439383
class MockSession:
440384

441385
def __init__(self, failure_count: int, failure_provider: Callable[[], Response]):
@@ -444,7 +388,7 @@ def __init__(self, failure_count: int, failure_provider: Callable[[], Response])
444388
self._failure_provider = failure_provider
445389

446390
@classmethod
447-
def raise_retryable_exception(cls):
391+
def raise_timeout_exception(cls):
448392
raise Timeout("Fake timeout")
449393

450394
@classmethod
@@ -453,14 +397,16 @@ def return_retryable_response(cls):
453397
response = Response()
454398
response._content = b''
455399
response.status_code = 429
456-
response.reason = 'OK'
400+
response.headers = {'Retry-After': '1'}
401+
# response.reason = 'Too Many Requests'
457402
response.url = 'http://test.com/'
458403

459404
response.request = PreparedRequest()
460405
response.request.url = response.url
461406
response.request.method = 'POST'
462407
response.request.headers = None
463408
response.request.body = b''
409+
return response
464410

465411
# following the signature of Session.request()
466412
def request(self,
@@ -506,25 +452,41 @@ def request(self,
506452
return response
507453

508454

509-
@pytest.mark.parametrize('test_case', retry_test_cases)
510-
@pytest.mark.parametrize('failure_provider',
511-
[MockSession.raise_retryable_exception, MockSession.return_retryable_response])
512-
def test_rewind_seekable_stream(test_case: RetryTestCase, failure_provider: Callable[[], Response]):
455+
@pytest.mark.parametrize(
456+
'test_case',
457+
[
458+
# bytes -> BytesIO
459+
RetryTestCase(lambda: b"0123456789", None, False, b"0123456789"),
460+
# str -> BytesIO
461+
RetryTestCase(lambda: "0123456789", None, False, b"0123456789"),
462+
# BytesIO directly
463+
RetryTestCase(lambda: io.BytesIO(b"0123456789"), None, False, b"0123456789"),
464+
# BytesIO directly with offset
465+
RetryTestCase(lambda: io.BytesIO(b"0123456789"), 4, False, b"456789"),
466+
# StringIO
467+
RetryTestCase(lambda: io.StringIO("0123456789"), None, False, b"0123456789"),
468+
# Non-seekable
469+
RetryTestCase(lambda: RetryTestCase.create_non_seekable_stream(b"0123456789"), None, True,
470+
b"0123456789")
471+
])
472+
@pytest.mark.parametrize('failure', [[MockSession.raise_timeout_exception, Timeout],
473+
[MockSession.return_retryable_response, errors.TooManyRequests]])
474+
def test_rewind_seekable_stream(test_case: RetryTestCase, failure: Tuple[Callable[[], Response], Type]):
513475
failure_count = 2
514476

515477
data = test_case.get_data()
516478

517-
session = MockSession(failure_count, failure_provider)
479+
session = MockSession(failure_count, failure[0])
518480
client = _BaseClient()
519481
client._session = session
520482

521483
def do():
522-
# Retries should reset the stream.
523484
client.do('POST', f'test.com/foo', data=data)
524485

525-
if test_case._expected_exception:
486+
if test_case._expected_failure:
526487
expected_attempts_made = 1
527-
with pytest.raises(Timeout):
488+
exception_class = failure[1]
489+
with pytest.raises(exception_class):
528490
do()
529491
else:
530492
expected_attempts_made = failure_count + 1

0 commit comments

Comments
 (0)