Skip to content

Commit 7c64c7b

Browse files
committed
compact tests
Signed-off-by: Sai Shree Pradhan <[email protected]>
1 parent e7d2779 commit 7c64c7b

File tree

1 file changed

+13
-38
lines changed

1 file changed

+13
-38
lines changed

tests/e2e/test_telemetry_retry.py

Lines changed: 13 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
import pytest
22
from unittest.mock import patch, MagicMock
33
import io
4+
import time
45

56
from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory
6-
from databricks.sql.telemetry.models.event import DriverConnectionParameters, HostDetails, DatabricksClientType
7-
from databricks.sql.telemetry.models.enums import AuthMech
87
from databricks.sql.auth.retry import DatabricksRetryPolicy
98

109
PATCH_TARGET = 'urllib3.connectionpool.HTTPSConnectionPool._get_conn'
@@ -64,40 +63,18 @@ def get_client(self, session_id, num_retries=3):
6463
adapter.max_retries = retry_policy
6564
return client, adapter
6665

67-
def test_success_no_retry(self):
68-
client, _ = self.get_client("session-success")
69-
params = DriverConnectionParameters(
70-
http_path="test-path", mode=DatabricksClientType.THRIFT,
71-
host_info=HostDetails(host_url="test.databricks.com", port=443),
72-
auth_mech=AuthMech.PAT
73-
)
74-
mock_responses = [{"status": 200}]
75-
76-
with patch(PATCH_TARGET, return_value=create_mock_conn(mock_responses)) as mock_get_conn:
77-
client.export_initial_telemetry_log(params, "test-agent")
78-
TelemetryClientFactory.close(client._session_id_hex)
79-
80-
mock_get_conn.return_value.getresponse.assert_called_once()
81-
client, _ = self.get_client("session-retry-once", num_retries=1)
82-
mock_responses = [{"status": 503}, {"status": 200}]
83-
84-
with patch(PATCH_TARGET, return_value=create_mock_conn(mock_responses)) as mock_get_conn:
85-
client.export_failure_log("TestError", "Test message")
86-
TelemetryClientFactory.close(client._session_id_hex)
87-
88-
assert mock_get_conn.return_value.getresponse.call_count == 2
89-
9066
@pytest.mark.parametrize(
9167
"status_code, description",
9268
[
9369
(401, "Unauthorized"),
9470
(403, "Forbidden"),
9571
(501, "Not Implemented"),
72+
(200, "Success"),
9673
],
9774
)
9875
def test_non_retryable_status_codes_are_not_retried(self, status_code, description):
9976
"""
100-
Verifies that terminal error codes (401, 403, 501, etc.) are not retried.
77+
Verifies that terminal error codes (401, 403, 501) and success codes (200) are not retried.
10178
"""
10279
# Use the status code in the session ID for easier debugging if it fails
10380
client, _ = self.get_client(f"session-{status_code}")
@@ -109,24 +86,22 @@ def test_non_retryable_status_codes_are_not_retried(self, status_code, descripti
10986

11087
mock_get_conn.return_value.getresponse.assert_called_once()
11188

112-
def test_respects_retry_after_header(self):
113-
client, _ = self.get_client("session-retry-after", num_retries=1)
114-
mock_responses = [{"status": 429, "headers": {'Retry-After': '1'}}, {"status": 200}]
115-
116-
with patch(PATCH_TARGET, return_value=create_mock_conn(mock_responses)) as mock_get_conn:
117-
client.export_failure_log("TestError", "Test message")
118-
TelemetryClientFactory.close(client._session_id_hex)
119-
120-
assert mock_get_conn.return_value.getresponse.call_count == 2
121-
12289
def test_exceeds_retry_count_limit(self):
90+
"""
91+
Verifies that the client retries up to the specified number of times before giving up.
92+
Verifies that the client respects the Retry-After header and retries on 429, 502, 503.
93+
"""
12394
num_retries = 3
12495
expected_total_calls = num_retries + 1
96+
retry_after = 1
12597
client, _ = self.get_client("session-exceed-limit", num_retries=num_retries)
126-
mock_responses = [{"status": 503}] * expected_total_calls
98+
mock_responses = [{"status": 503, "headers": {"Retry-After": str(retry_after)}}, {"status": 429}, {"status": 502}, {"status": 503}]
12799

128100
with patch(PATCH_TARGET, return_value=create_mock_conn(mock_responses)) as mock_get_conn:
101+
start_time = time.time()
129102
client.export_failure_log("TestError", "Test message")
130103
TelemetryClientFactory.close(client._session_id_hex)
104+
end_time = time.time()
131105

132-
assert mock_get_conn.return_value.getresponse.call_count == expected_total_calls
106+
assert mock_get_conn.return_value.getresponse.call_count == expected_total_calls
107+
assert end_time - start_time > retry_after

0 commit comments

Comments
 (0)