11import io
22import random
33from http .server import BaseHTTPRequestHandler
4- from typing import Iterator , List
4+ from typing import Callable , Iterator , List , Optional
55from unittest .mock import Mock
66
77import pytest
8+ from requests import PreparedRequest , Response , Timeout
89
910from databricks .sdk import errors , useragent
1011from databricks .sdk ._base_client import (_BaseClient , _RawResponse ,
@@ -62,7 +63,7 @@ def test_streaming_response_read_closes(config):
6263@pytest .mark .parametrize ('status_code,headers,body,expected_error' , [
6364 (400 , {}, {
6465 "message" :
65- "errorMessage" ,
66+ "errorMessage" ,
6667 "details" : [{
6768 "type" : DatabricksError ._error_info_type ,
6869 "reason" : "error reason" ,
@@ -103,9 +104,9 @@ def test_streaming_response_read_closes(config):
103104 (429 , {
104105 'Retry-After' : '100'
105106 }, {
106- 'error_code' : 'TOO_MANY_REQUESTS' ,
107- 'message' : 'errorMessage' ,
108- }, errors .TooManyRequests ('errorMessage' , error_code = 'TOO_MANY_REQUESTS' , retry_after_secs = 100 )),
107+ 'error_code' : 'TOO_MANY_REQUESTS' ,
108+ 'message' : 'errorMessage' ,
109+ }, errors .TooManyRequests ('errorMessage' , error_code = 'TOO_MANY_REQUESTS' , retry_after_secs = 100 )),
109110 (503 , {}, {
110111 'error_code' : 'TEMPORARILY_UNAVAILABLE' ,
111112 'message' : 'errorMessage' ,
@@ -114,9 +115,9 @@ def test_streaming_response_read_closes(config):
114115 (503 , {
115116 'Retry-After' : '100'
116117 }, {
117- 'error_code' : 'TEMPORARILY_UNAVAILABLE' ,
118- 'message' : 'errorMessage' ,
119- },
118+ 'error_code' : 'TEMPORARILY_UNAVAILABLE' ,
119+ 'message' : 'errorMessage' ,
120+ },
120121 errors .TemporarilyUnavailable ('errorMessage' , error_code = 'TEMPORARILY_UNAVAILABLE' ,
121122 retry_after_secs = 100 )),
122123 (404 , {}, {
@@ -357,91 +358,154 @@ def tell(self):
357358 assert client ._is_seekable_stream (CustomSeekableStream ())
358359
359360
360- @pytest .mark .parametrize (
361- 'input_data' ,
362- [
363- b"0123456789" , # bytes -> BytesIO
364- "0123456789" , # str -> BytesIO
365- io .BytesIO (b"0123456789" ), # BytesIO directly
366- io .StringIO ("0123456789" ), # StringIO
367- ])
368- def test_reset_seekable_stream_on_retry (input_data ):
369- received_data = []
361+ class RetryTestCase :
370362
371- # Retry two times before succeeding.
372- def inner (h : BaseHTTPRequestHandler ):
373- if len (received_data ) == 2 :
374- h .send_response (200 )
375- h .end_headers ()
376- else :
377- h .send_response (429 )
378- h .end_headers ()
363+ def __init__ (self , data_provider : Callable , offset : Optional [int ], expected_exception : bool ,
364+ expected_result : bytes ):
365+ self ._data_provider = data_provider
366+ self ._offset = offset
367+ self ._expected_result = expected_result
368+ self ._expected_exception = expected_exception
379369
380- content_length = int (h .headers .get ('Content-Length' , 0 ))
381- if content_length > 0 :
382- received_data .append (h .rfile .read (content_length ))
370+ def get_data (self ):
371+ data = self ._data_provider ()
372+ if self ._offset is not None :
373+ data .seek (self ._offset )
374+ return data
383375
384- with http_fixture_server (inner ) as host :
385- client = _BaseClient ()
376+ @classmethod
377+ def create_non_seekable_stream (cls , data : bytes ):
378+ result = io .BytesIO (data )
379+ result .seekable = lambda : False # makes the stream appear non-seekable
380+ return result
381+
382+
383+ retry_test_cases = [
384+ # bytes -> BytesIO
385+ RetryTestCase (lambda : b"0123456789" , None , False , b"0123456789" ),
386+ # str -> BytesIO
387+ RetryTestCase (lambda : "0123456789" , None , False , b"0123456789" ),
388+ # BytesIO directly
389+ RetryTestCase (lambda : io .BytesIO (b"0123456789" ), None , False , b"0123456789" ),
390+ # BytesIO directly with offset
391+ RetryTestCase (lambda : io .BytesIO (b"0123456789" ), 4 , False , b"456789" ),
392+ # StringIO
393+ RetryTestCase (lambda : io .StringIO ("0123456789" ), None , False , b"0123456789" ),
394+ # Non-seekable
395+ RetryTestCase (lambda : RetryTestCase .create_non_seekable_stream (b"0123456789" ),
396+ None , True , b"0123456789" )
397+ ]
386398
387- # Retries should reset the stream.
388- client .do ('POST' , f'{ host } /foo' , data = input_data )
389399
390- assert received_data == [b"0123456789" , b"0123456789" , b"0123456789" ]
400+ @pytest .mark .parametrize ('test_case' , retry_test_cases )
401+ def test_rewind_seekable_stream_on_retryable_error_response (test_case : RetryTestCase ):
402+ received_requests = []
391403
404+ data = test_case .get_data ()
392405
393- def test_reset_seekable_stream_to_their_initial_position_on_retry ():
394- received_data = []
406+ failure_count = 2
395407
396408 # Retry two times before succeeding.
397409 def inner (h : BaseHTTPRequestHandler ):
398- if len (received_data ) == 2 :
410+ if len (received_requests ) == failure_count :
399411 h .send_response (200 )
400412 h .end_headers ()
401413 else :
402414 h .send_response (429 )
415+ h .send_header ('Retry-After' , '1' ) # avoid a warning in log
403416 h .end_headers ()
404417
405418 content_length = int (h .headers .get ('Content-Length' , 0 ))
406419 if content_length > 0 :
407- received_data .append (h .rfile .read (content_length ))
408-
409- input_data = io .BytesIO (b"0123456789" )
410- input_data .seek (4 )
420+ received_requests .append (h .rfile .read (content_length ))
411421
412422 with http_fixture_server (inner ) as host :
413423 client = _BaseClient ()
414424
415- # Retries should reset the stream.
416- client .do ('POST' , f'{ host } /foo' , data = input_data )
425+ def do ():
426+ # Retries should reset the stream.
427+ client .do ('POST' , f'{ host } /foo' , data = data )
417428
418- assert received_data == [b"456789" , b"456789" , b"456789" ]
419- assert input_data .tell () == 10 # EOF
429+ if test_case ._expected_exception :
430+ expected_attempts_made = 1
431+ with pytest .raises (DatabricksError ):
432+ do ()
433+ else :
434+ expected_attempts_made = failure_count + 1
435+ do ()
436+
437+ assert received_requests == [test_case ._expected_result for _ in range (expected_attempts_made )]
438+
439+
440+ class MockSession :
441+
442+ def __init__ (self , failure_count : int ):
443+ self ._failure_count = failure_count
444+ self ._received_requests : List [bytes ] = []
445+
446+ # following the signature of Session.request()
447+ def request (self ,
448+ method ,
449+ url ,
450+ params = None ,
451+ data = None ,
452+ headers = None ,
453+ cookies = None ,
454+ files = None ,
455+ auth = None ,
456+ timeout = None ,
457+ allow_redirects = True ,
458+ proxies = None ,
459+ hooks = None ,
460+ stream = None ,
461+ verify = None ,
462+ cert = None ,
463+ json = None ):
464+ request_body = data .read ()
465+
466+ if isinstance (request_body , str ):
467+ request_body = request_body .encode ('utf-8' ) # to be able to compare with expected bytes
468+
469+ self ._received_requests .append (request_body )
470+ if self ._failure_count > 0 :
471+ self ._failure_count -= 1
472+ raise Timeout ("Fake timeout" ) # retryable error
473+ else :
474+ # fill response fields so that logging does not fail
475+ response = Response ()
476+ response ._content = b''
477+ response .status_code = 200
478+ response .reason = 'OK'
479+ response .url = url
420480
481+ response .request = PreparedRequest ()
482+ response .request .url = url
483+ response .request .method = method
484+ response .request .headers = headers
485+ response .request .body = data
486+ return response
421487
422- def test_no_retry_or_reset_on_non_seekable_stream ():
423- requests = []
424488
425- # Always respond with a response that triggers a retry.
426- def inner (h : BaseHTTPRequestHandler ):
427- content_length = int (h .headers .get ('Content-Length' , 0 ))
428- if content_length > 0 :
429- requests .append (h .rfile .read (content_length ))
489+ @pytest .mark .parametrize ('test_case' , retry_test_cases )
490+ def test_rewind_seekable_stream_on_retryable_exception (test_case : RetryTestCase ):
491+ failure_count = 2
430492
431- h .send_response (429 )
432- h .send_header ('Retry-After' , '1' )
433- h .end_headers ()
493+ data = test_case .get_data ()
434494
435- input_data = io .BytesIO (b"0123456789" )
436- input_data .seekable = lambda : False # makes the stream appear non-seekable
495+ session = MockSession (failure_count )
496+ client = _BaseClient ()
497+ client ._session = session
437498
438- with http_fixture_server (inner ) as host :
439- client = _BaseClient ()
499+ def do ():
500+ # Retries should reset the stream.
501+ client .do ('POST' , f'test.com/foo' , data = data )
440502
441- # Should raise error immediately without retry.
442- with pytest .raises (DatabricksError ):
443- client .do ('POST' , f'{ host } /foo' , data = input_data )
503+ if test_case ._expected_exception :
504+ expected_attempts_made = 1
505+ with pytest .raises (Timeout ):
506+ do ()
507+ else :
508+ expected_attempts_made = failure_count + 1
509+ do ()
444510
445- # Verify that only one request was made (no retries).
446- assert requests == [b"0123456789" ]
447- assert input_data .tell () == 10 # EOF
511+ assert session ._received_requests == [test_case ._expected_result for _ in range (expected_attempts_made )]
0 commit comments