@@ -63,7 +63,7 @@ def test_streaming_response_read_closes(config):
6363@pytest .mark .parametrize ('status_code,headers,body,expected_error' , [
6464 (400 , {}, {
6565 "message" :
66- "errorMessage" ,
66+ "errorMessage" ,
6767 "details" : [{
6868 "type" : DatabricksError ._error_info_type ,
6969 "reason" : "error reason" ,
@@ -104,9 +104,9 @@ def test_streaming_response_read_closes(config):
104104 (429 , {
105105 'Retry-After' : '100'
106106 }, {
107- 'error_code' : 'TOO_MANY_REQUESTS' ,
108- 'message' : 'errorMessage' ,
109- }, errors .TooManyRequests ('errorMessage' , error_code = 'TOO_MANY_REQUESTS' , retry_after_secs = 100 )),
107+ 'error_code' : 'TOO_MANY_REQUESTS' ,
108+ 'message' : 'errorMessage' ,
109+ }, errors .TooManyRequests ('errorMessage' , error_code = 'TOO_MANY_REQUESTS' , retry_after_secs = 100 )),
110110 (503 , {}, {
111111 'error_code' : 'TEMPORARILY_UNAVAILABLE' ,
112112 'message' : 'errorMessage' ,
@@ -115,9 +115,9 @@ def test_streaming_response_read_closes(config):
115115 (503 , {
116116 'Retry-After' : '100'
117117 }, {
118- 'error_code' : 'TEMPORARILY_UNAVAILABLE' ,
119- 'message' : 'errorMessage' ,
120- },
118+ 'error_code' : 'TEMPORARILY_UNAVAILABLE' ,
119+ 'message' : 'errorMessage' ,
120+ },
121121 errors .TemporarilyUnavailable ('errorMessage' , error_code = 'TEMPORARILY_UNAVAILABLE' ,
122122 retry_after_secs = 100 )),
123123 (404 , {}, {
@@ -392,8 +392,7 @@ def create_non_seekable_stream(cls, data: bytes):
392392 # StringIO
393393 RetryTestCase (lambda : io .StringIO ("0123456789" ), None , False , b"0123456789" ),
394394 # Non-seekable
395- RetryTestCase (lambda : RetryTestCase .create_non_seekable_stream (b"0123456789" ),
396- None , True , b"0123456789" )
395+ RetryTestCase (lambda : RetryTestCase .create_non_seekable_stream (b"0123456789" ), None , True , b"0123456789" )
397396]
398397
399398
@@ -439,9 +438,29 @@ def do():
439438
440439class MockSession :
441440
442- def __init__ (self , failure_count : int ):
441+ def __init__ (self , failure_count : int , failure_provider : Callable [[], Response ] ):
443442 self ._failure_count = failure_count
444443 self ._received_requests : List [bytes ] = []
444+ self ._failure_provider = failure_provider
445+
446+ @classmethod
447+ def raise_retryable_exception (cls ):
448+ raise Timeout ("Fake timeout" )
449+
450+ @classmethod
451+ def return_retryable_response (cls ):
452+ # fill response fields so that logging does not fail
453+ response = Response ()
454+ response ._content = b''
455+ response .status_code = 429
456+ response .reason = 'OK'
457+ response .url = 'http://test.com/'
458+
459+ response .request = PreparedRequest ()
460+ response .request .url = response .url
461+ response .request .method = 'POST'
462+ response .request .headers = None
463+ response .request .body = b''
445464
446465 # following the signature of Session.request()
447466 def request (self ,
@@ -469,7 +488,8 @@ def request(self,
469488 self ._received_requests .append (request_body )
470489 if self ._failure_count > 0 :
471490 self ._failure_count -= 1
472- raise Timeout ("Fake timeout" ) # retryable error
491+ return self ._failure_provider ()
492+ #
473493 else :
474494 # fill response fields so that logging does not fail
475495 response = Response ()
@@ -487,12 +507,14 @@ def request(self,
487507
488508
489509@pytest .mark .parametrize ('test_case' , retry_test_cases )
490- def test_rewind_seekable_stream_on_retryable_exception (test_case : RetryTestCase ):
510+ @pytest .mark .parametrize ('failure_provider' ,
511+ [MockSession .raise_retryable_exception , MockSession .return_retryable_response ])
512+ def test_rewind_seekable_stream (test_case : RetryTestCase , failure_provider : Callable [[], Response ]):
491513 failure_count = 2
492514
493515 data = test_case .get_data ()
494516
495- session = MockSession (failure_count )
517+ session = MockSession (failure_count , failure_provider )
496518 client = _BaseClient ()
497519 client ._session = session
498520
0 commit comments