11import pytest
22from unittest .mock import patch , MagicMock
33import io
4+ import time
45
56from databricks .sql .telemetry .telemetry_client import TelemetryClientFactory
6- from databricks .sql .telemetry .models .event import DriverConnectionParameters , HostDetails , DatabricksClientType
7- from databricks .sql .telemetry .models .enums import AuthMech
87from databricks .sql .auth .retry import DatabricksRetryPolicy
98
109PATCH_TARGET = 'urllib3.connectionpool.HTTPSConnectionPool._get_conn'
@@ -64,40 +63,18 @@ def get_client(self, session_id, num_retries=3):
6463 adapter .max_retries = retry_policy
6564 return client , adapter
6665
67- def test_success_no_retry (self ):
68- client , _ = self .get_client ("session-success" )
69- params = DriverConnectionParameters (
70- http_path = "test-path" , mode = DatabricksClientType .THRIFT ,
71- host_info = HostDetails (host_url = "test.databricks.com" , port = 443 ),
72- auth_mech = AuthMech .PAT
73- )
74- mock_responses = [{"status" : 200 }]
75-
76- with patch (PATCH_TARGET , return_value = create_mock_conn (mock_responses )) as mock_get_conn :
77- client .export_initial_telemetry_log (params , "test-agent" )
78- TelemetryClientFactory .close (client ._session_id_hex )
79-
80- mock_get_conn .return_value .getresponse .assert_called_once ()
81- client , _ = self .get_client ("session-retry-once" , num_retries = 1 )
82- mock_responses = [{"status" : 503 }, {"status" : 200 }]
83-
84- with patch (PATCH_TARGET , return_value = create_mock_conn (mock_responses )) as mock_get_conn :
85- client .export_failure_log ("TestError" , "Test message" )
86- TelemetryClientFactory .close (client ._session_id_hex )
87-
88- assert mock_get_conn .return_value .getresponse .call_count == 2
89-
9066 @pytest .mark .parametrize (
9167 "status_code, description" ,
9268 [
9369 (401 , "Unauthorized" ),
9470 (403 , "Forbidden" ),
9571 (501 , "Not Implemented" ),
72+ (200 , "Success" ),
9673 ],
9774 )
9875 def test_non_retryable_status_codes_are_not_retried (self , status_code , description ):
9976 """
100- Verifies that terminal error codes (401, 403, 501, etc. ) are not retried.
77+ Verifies that terminal error codes (401, 403, 501) and success codes (200 ) are not retried.
10178 """
10279 # Use the status code in the session ID for easier debugging if it fails
10380 client , _ = self .get_client (f"session-{ status_code } " )
@@ -109,24 +86,22 @@ def test_non_retryable_status_codes_are_not_retried(self, status_code, descripti
10986
11087 mock_get_conn .return_value .getresponse .assert_called_once ()
11188
112- def test_respects_retry_after_header (self ):
113- client , _ = self .get_client ("session-retry-after" , num_retries = 1 )
114- mock_responses = [{"status" : 429 , "headers" : {'Retry-After' : '1' }}, {"status" : 200 }]
115-
116- with patch (PATCH_TARGET , return_value = create_mock_conn (mock_responses )) as mock_get_conn :
117- client .export_failure_log ("TestError" , "Test message" )
118- TelemetryClientFactory .close (client ._session_id_hex )
119-
120- assert mock_get_conn .return_value .getresponse .call_count == 2
121-
12289 def test_exceeds_retry_count_limit (self ):
90+ """
91+ Verifies that the client retries up to the specified number of times before giving up.
92+ Verifies that the client respects the Retry-After header and retries on 429, 502, 503.
93+ """
12394 num_retries = 3
12495 expected_total_calls = num_retries + 1
96+ retry_after = 1
12597 client , _ = self .get_client ("session-exceed-limit" , num_retries = num_retries )
126- mock_responses = [{"status" : 503 }] * expected_total_calls
98+ mock_responses = [{"status" : 503 , "headers" : { "Retry-After" : str ( retry_after )}}, { "status" : 429 }, { "status" : 502 }, { "status" : 503 }]
12799
128100 with patch (PATCH_TARGET , return_value = create_mock_conn (mock_responses )) as mock_get_conn :
101+ start_time = time .time ()
129102 client .export_failure_log ("TestError" , "Test message" )
130103 TelemetryClientFactory .close (client ._session_id_hex )
104+ end_time = time .time ()
131105
132- assert mock_get_conn .return_value .getresponse .call_count == expected_total_calls
106+ assert mock_get_conn .return_value .getresponse .call_count == expected_total_calls
107+ assert end_time - start_time > retry_after
0 commit comments