11import io
22import random
33from http .server import BaseHTTPRequestHandler
4- from typing import Callable , Iterator , List , Optional
4+ from typing import Callable , Iterator , List , Optional , Tuple , Type
55from unittest .mock import Mock
66
77import pytest
@@ -360,12 +360,12 @@ def tell(self):
360360
361361class RetryTestCase :
362362
363- def __init__ (self , data_provider : Callable , offset : Optional [int ], expected_exception : bool ,
363+ def __init__ (self , data_provider : Callable , offset : Optional [int ], expected_failure : bool ,
364364 expected_result : bytes ):
365365 self ._data_provider = data_provider
366366 self ._offset = offset
367367 self ._expected_result = expected_result
368- self ._expected_exception = expected_exception
368+ self ._expected_failure = expected_failure
369369
370370 def get_data (self ):
371371 data = self ._data_provider ()
@@ -380,62 +380,6 @@ def create_non_seekable_stream(cls, data: bytes):
380380 return result
381381
382382
383- retry_test_cases = [
384- # bytes -> BytesIO
385- RetryTestCase (lambda : b"0123456789" , None , False , b"0123456789" ),
386- # str -> BytesIO
387- RetryTestCase (lambda : "0123456789" , None , False , b"0123456789" ),
388- # BytesIO directly
389- RetryTestCase (lambda : io .BytesIO (b"0123456789" ), None , False , b"0123456789" ),
390- # BytesIO directly with offset
391- RetryTestCase (lambda : io .BytesIO (b"0123456789" ), 4 , False , b"456789" ),
392- # StringIO
393- RetryTestCase (lambda : io .StringIO ("0123456789" ), None , False , b"0123456789" ),
394- # Non-seekable
395- RetryTestCase (lambda : RetryTestCase .create_non_seekable_stream (b"0123456789" ), None , True , b"0123456789" )
396- ]
397-
398-
399- @pytest .mark .parametrize ('test_case' , retry_test_cases )
400- def test_rewind_seekable_stream_on_retryable_error_response (test_case : RetryTestCase ):
401- received_requests = []
402-
403- data = test_case .get_data ()
404-
405- failure_count = 2
406-
407- # Retry two times before succeeding.
408- def inner (h : BaseHTTPRequestHandler ):
409- if len (received_requests ) == failure_count :
410- h .send_response (200 )
411- h .end_headers ()
412- else :
413- h .send_response (429 )
414- h .send_header ('Retry-After' , '1' ) # avoid a warning in log
415- h .end_headers ()
416-
417- content_length = int (h .headers .get ('Content-Length' , 0 ))
418- if content_length > 0 :
419- received_requests .append (h .rfile .read (content_length ))
420-
421- with http_fixture_server (inner ) as host :
422- client = _BaseClient ()
423-
424- def do ():
425- # Retries should reset the stream.
426- client .do ('POST' , f'{ host } /foo' , data = data )
427-
428- if test_case ._expected_exception :
429- expected_attempts_made = 1
430- with pytest .raises (DatabricksError ):
431- do ()
432- else :
433- expected_attempts_made = failure_count + 1
434- do ()
435-
436- assert received_requests == [test_case ._expected_result for _ in range (expected_attempts_made )]
437-
438-
439383class MockSession :
440384
441385 def __init__ (self , failure_count : int , failure_provider : Callable [[], Response ]):
@@ -444,7 +388,7 @@ def __init__(self, failure_count: int, failure_provider: Callable[[], Response])
444388 self ._failure_provider = failure_provider
445389
446390 @classmethod
447- def raise_retryable_exception (cls ):
391+ def raise_timeout_exception (cls ):
448392 raise Timeout ("Fake timeout" )
449393
450394 @classmethod
@@ -453,14 +397,16 @@ def return_retryable_response(cls):
453397 response = Response ()
454398 response ._content = b''
455399 response .status_code = 429
456- response .reason = 'OK'
400+ response .headers = {'Retry-After' : '1' }
401+ # response.reason = 'Too Many Requests'
457402 response .url = 'http://test.com/'
458403
459404 response .request = PreparedRequest ()
460405 response .request .url = response .url
461406 response .request .method = 'POST'
462407 response .request .headers = None
463408 response .request .body = b''
409+ return response
464410
465411 # following the signature of Session.request()
466412 def request (self ,
@@ -506,25 +452,41 @@ def request(self,
506452 return response
507453
508454
509- @pytest .mark .parametrize ('test_case' , retry_test_cases )
510- @pytest .mark .parametrize ('failure_provider' ,
511- [MockSession .raise_retryable_exception , MockSession .return_retryable_response ])
512- def test_rewind_seekable_stream (test_case : RetryTestCase , failure_provider : Callable [[], Response ]):
455+ @pytest .mark .parametrize (
456+ 'test_case' ,
457+ [
458+ # bytes -> BytesIO
459+ RetryTestCase (lambda : b"0123456789" , None , False , b"0123456789" ),
460+ # str -> BytesIO
461+ RetryTestCase (lambda : "0123456789" , None , False , b"0123456789" ),
462+ # BytesIO directly
463+ RetryTestCase (lambda : io .BytesIO (b"0123456789" ), None , False , b"0123456789" ),
464+ # BytesIO directly with offset
465+ RetryTestCase (lambda : io .BytesIO (b"0123456789" ), 4 , False , b"456789" ),
466+ # StringIO
467+ RetryTestCase (lambda : io .StringIO ("0123456789" ), None , False , b"0123456789" ),
468+ # Non-seekable
469+ RetryTestCase (lambda : RetryTestCase .create_non_seekable_stream (b"0123456789" ), None , True ,
470+ b"0123456789" )
471+ ])
472+ @pytest .mark .parametrize ('failure' , [[MockSession .raise_timeout_exception , Timeout ],
473+ [MockSession .return_retryable_response , errors .TooManyRequests ]])
474+ def test_rewind_seekable_stream (test_case : RetryTestCase , failure : Tuple [Callable [[], Response ], Type ]):
513475 failure_count = 2
514476
515477 data = test_case .get_data ()
516478
517- session = MockSession (failure_count , failure_provider )
479+ session = MockSession (failure_count , failure [ 0 ] )
518480 client = _BaseClient ()
519481 client ._session = session
520482
521483 def do ():
522- # Retries should reset the stream.
523484 client .do ('POST' , f'test.com/foo' , data = data )
524485
525- if test_case ._expected_exception :
486+ if test_case ._expected_failure :
526487 expected_attempts_made = 1
527- with pytest .raises (Timeout ):
488+ exception_class = failure [1 ]
489+ with pytest .raises (exception_class ):
528490 do ()
529491 else :
530492 expected_attempts_made = failure_count + 1
0 commit comments