Skip to content

Commit 6bb5033

Browse files
tvaron3Copilot
andauthored
Health Check (Azure#43339)
* change workloads based on feedback * add staging yml file * add staging yml file * change health checks to use health probe * remove unnecessary return type of database accounts method * remove unused imports * add more tests and fix tests * fix pylint, changelog, and test * fix tests * fix tests * fix tests * move to use only one thread pool executor * fix tests * Apply suggestion from @Copilot Co-authored-by: Copilot <[email protected]> * Apply suggestion from @Copilot Co-authored-by: Copilot <[email protected]> * Apply suggestion from @Copilot Co-authored-by: Copilot <[email protected]> * Update sdk/cosmos/azure-cosmos/tests/test_health_check.py Co-authored-by: Copilot <[email protected]> * Update sdk/cosmos/azure-cosmos/tests/test_retry_policy_async.py Co-authored-by: Copilot <[email protected]> * react to design meeting * remove configurability for database account calls * fix pylint * fix tests * fix tests * fix tests * fix tests * fix tests * fix tests * fix tests * fix test * increase retry factor and max retry after, updated documentation * react to comments * fix test pipeline issue and upgrade to pypy311 --------- Co-authored-by: Copilot <[email protected]>
1 parent f6a767c commit 6bb5033

31 files changed

+335
-287
lines changed

sdk/cosmos/azure-cosmos/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
* Fixed bug where customer provided excluded region was not always being honored during certain transient failures. See [PR 43602](https://github.com/Azure/azure-sdk-for-python/pull/43602)
1111

1212
#### Other Changes
13+
* Further optimized health checks for sync and async clients. See [PR 43339](https://github.com/Azure/azure-sdk-for-python/pull/43339)
1314
* Enhanced logging to ensure when a region is marked unavailable we have the proper context. See [PR 43602](https://github.com/Azure/azure-sdk-for-python/pull/43602)
1415

1516
### 4.14.0 (2025-10-13)

sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ class _Constants:
4040
Name: Literal["name"] = "name"
4141
DatabaseAccountEndpoint: Literal["databaseAccountEndpoint"] = "databaseAccountEndpoint"
4242
DefaultEndpointsRefreshTime: int = 5 * 60 * 1000 # milliseconds
43-
UnavailableEndpointDBATimeouts: int = 1 # seconds
4443

4544
# ServiceDocument Resource
4645
EnableMultipleWritableLocations: Literal["enableMultipleWriteLocations"] = "enableMultipleWriteLocations"
@@ -61,11 +60,11 @@ class _Constants:
6160
INFERENCE_SERVICE_DEFAULT_SCOPE = "https://dbinference.azure.com/.default"
6261
SEMANTIC_RERANKER_INFERENCE_ENDPOINT: str = "AZURE_COSMOS_SEMANTIC_RERANKER_INFERENCE_ENDPOINT"
6362

64-
# Database Account Retry Policy constants
63+
# Health Check Retry Policy constants
6564
AZURE_COSMOS_HEALTH_CHECK_MAX_RETRIES: str = "AZURE_COSMOS_HEALTH_CHECK_MAX_RETRIES"
6665
AZURE_COSMOS_HEALTH_CHECK_MAX_RETRIES_DEFAULT: int = 3
6766
AZURE_COSMOS_HEALTH_CHECK_RETRY_AFTER_MS: str = "AZURE_COSMOS_HEALTH_CHECK_RETRY_AFTER_MS"
68-
AZURE_COSMOS_HEALTH_CHECK_RETRY_AFTER_MS_DEFAULT: int = 100
67+
AZURE_COSMOS_HEALTH_CHECK_RETRY_AFTER_MS_DEFAULT: int = 500
6968

7069
# Only applicable when circuit breaker is enabled -------------------------
7170
CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_READ: str = "AZURE_COSMOS_CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_READ"

sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ def __init__( # pylint: disable=too-many-statements
248248
# Routing map provider
249249
self._routing_map_provider = routing_map_provider.SmartRoutingMapProvider(self)
250250

251-
database_account, _ = self._global_endpoint_manager._GetDatabaseAccount(**kwargs)
251+
database_account = self._global_endpoint_manager._GetDatabaseAccount(**kwargs)
252252
self._global_endpoint_manager.force_refresh_on_startup(database_account)
253253

254254
# Use database_account if no consistency passed in to verify consistency level to be used
@@ -2691,22 +2691,21 @@ def GetDatabaseAccount(
26912691
response_hook(last_response_headers, result)
26922692
return database_account
26932693

2694-
def _GetDatabaseAccountCheck(
2694+
def health_check(
26952695
self,
26962696
url_connection: Optional[str] = None,
26972697
**kwargs: Any
26982698
):
2699-
"""Gets database account info.
2699+
""" Send a request to check the health of region.
27002700
2701-
:param str url_connection: the endpoint used to get the database account
2702-
:return: The Database Account.
2703-
:rtype: documents.DatabaseAccount
2701+
:param str url_connection: the endpoint for the region being checked
27042702
"""
27052703
if url_connection is None:
27062704
url_connection = self.url_connection
27072705

27082706
headers = base.GetHeaders(self, self.default_headers, "get", "", "", "",
2709-
documents._OperationType.Read,{}, client_id=self.client_id)
2707+
documents._OperationType.Read, {},
2708+
client_id=self.client_id)
27102709
request_params = RequestObject(http_constants.ResourceType.DatabaseAccount,
27112710
documents._OperationType.Read,
27122711
headers,

sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py

Lines changed: 62 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,11 @@
2222
"""Internal class for global endpoint manager implementation in the Azure Cosmos
2323
database service.
2424
"""
25-
25+
import logging
26+
import os
2627
import threading
27-
from typing import Tuple
28+
from concurrent.futures import ThreadPoolExecutor, as_completed
29+
from typing import Callable, Any
2830

2931
from azure.core.exceptions import AzureError
3032

@@ -37,7 +39,7 @@
3739

3840

3941
# pylint: disable=protected-access
40-
42+
logger = logging.getLogger("azure.cosmos._GlobalEndpointManager")
4143

4244
class _GlobalEndpointManager(object): # pylint: disable=too-many-instance-attributes
4345
"""
@@ -58,6 +60,9 @@ def __init__(self, client):
5860
self.refresh_lock = threading.RLock()
5961
self.last_refresh_time = 0
6062
self._database_account_cache = None
63+
self.startup = True
64+
self._refresh_thread = None
65+
self.executor = ThreadPoolExecutor(max_workers=os.cpu_count())
6166

6267
def get_refresh_time_interval_in_ms_stub(self):
6368
return constants._Constants.DefaultEndpointsRefreshTime
@@ -120,31 +125,58 @@ def refresh_endpoint_list(self, database_account, **kwargs):
120125
raise e
121126

122127
def _refresh_endpoint_list_private(self, database_account=None, **kwargs):
123-
if database_account:
128+
# 1. If explicit database_account provided and not during startup, just update cache (no health check now)
129+
# 2. Else if refresh criteria met:
130+
# a. If not startup -> spawn background thread to do full database account + health checks
131+
# b. If startup -> get database account synchronously, then spawn background health checks,
132+
# then mark startup False
133+
if database_account and not self.startup:
124134
self.location_cache.perform_on_database_account_read(database_account)
125135
self.refresh_needed = False
126136
self.last_refresh_time = current_time_millis()
127137
else:
128138
if self.location_cache.should_refresh_endpoints() or self.refresh_needed:
129139
self.refresh_needed = False
130140
self.last_refresh_time = current_time_millis()
131-
# this will perform getDatabaseAccount calls to check endpoint health
132-
self._endpoints_health_check(**kwargs)
133-
134-
def _GetDatabaseAccount(self, **kwargs) -> Tuple[DatabaseAccount, str]:
141+
if not self.startup:
142+
# background full refresh (database account + health checks)
143+
self._start_background_refresh(self._refresh_database_account_and_health, kwargs)
144+
else:
145+
self.location_cache.perform_on_database_account_read(database_account)
146+
self._start_background_refresh(self._endpoints_health_check, kwargs)
147+
self.startup = False
148+
149+
def _start_background_refresh(self, target: Callable[..., None], kwargs: dict[str, Any]):
150+
"""Starts a daemon thread to run the given target if one is not already active.
151+
:param Callable target: The function to run in the background thread.
152+
:param dict kwargs: The keyword arguments to pass to the target function.
153+
"""
154+
if not (self._refresh_thread and self._refresh_thread.is_alive()):
155+
def runner():
156+
try:
157+
target(**kwargs)
158+
except Exception as exception: #pylint: disable=broad-exception-caught
159+
# background failures should not crash main thread
160+
# Intentionally swallow to avoid affecting foreground; logging could be added.
161+
logger.error("Health check task failed: %s", exception, exc_info=True)
162+
t = threading.Thread(target=runner, name="cosmos-endpoint-refresh", daemon=True)
163+
self._refresh_thread = t
164+
t.start()
165+
166+
def _GetDatabaseAccount(self, **kwargs) -> DatabaseAccount:
135167
"""Gets the database account.
136168
137169
First tries by using the default endpoint, and if that doesn't work,
138170
use the endpoints for the preferred locations in the order they are
139171
specified, to get the database account.
140172
:returns: A `DatabaseAccount` instance representing the Cosmos DB Database Account
141173
and the endpoint that was used for the request.
142-
:rtype: tuple of (~azure.cosmos.DatabaseAccount, str)
174+
:rtype: ~azure.cosmos.DatabaseAccount
143175
"""
144176
try:
145177
database_account = self._GetDatabaseAccountStub(self.DefaultEndpoint, **kwargs)
146178
self._database_account_cache = database_account
147-
return database_account, self.DefaultEndpoint
179+
return database_account
148180
# If for any reason(non-globaldb related), we are not able to get the database
149181
# account from the above call to GetDatabaseAccount, we would try to get this
150182
# information from any of the preferred locations that the user might have
@@ -157,52 +189,34 @@ def _GetDatabaseAccount(self, **kwargs) -> Tuple[DatabaseAccount, str]:
157189
try:
158190
database_account = self._GetDatabaseAccountStub(locational_endpoint, **kwargs)
159191
self._database_account_cache = database_account
160-
self.location_cache.mark_endpoint_available(locational_endpoint)
161-
return database_account, locational_endpoint
192+
return database_account
162193
except (exceptions.CosmosHttpResponseError, AzureError):
163194
self._mark_endpoint_unavailable(locational_endpoint, "_GetDatabaseAccount")
164195
raise
165196

166197
def _endpoints_health_check(self, **kwargs):
167-
"""Gets the database account for each endpoint.
168-
169-
Validating if the endpoint is healthy else marking it as unavailable.
170-
"""
171-
database_account, attempted_endpoint = self._GetDatabaseAccount(**kwargs)
172-
self.location_cache.perform_on_database_account_read(database_account)
173-
# get all the regional routing contexts to check
198+
"""Performs concurrent health checks for each endpoint (background-safe)."""
174199
endpoints = self.location_cache.endpoints_to_health_check()
175-
success_count = 0
176-
for endpoint in endpoints:
177-
if endpoint != attempted_endpoint:
178-
if success_count >= 4:
179-
break
180-
# save current dba timeouts
181-
previous_dba_read_timeout = self.client.connection_policy.DBAReadTimeout
182-
previous_dba_connection_timeout = self.client.connection_policy.DBAConnectionTimeout
183-
try:
184-
if (endpoint in
185-
self.location_cache.location_unavailability_info_by_endpoint):
186-
# if the endpoint is unavailable, we need to lower the timeouts to be more aggressive in the
187-
# health check. This helps reduce the time the health check is blocking all requests.
188-
self.client.connection_policy._override_dba_timeouts(constants._Constants
189-
.UnavailableEndpointDBATimeouts,
190-
constants._Constants
191-
.UnavailableEndpointDBATimeouts)
192-
self.client._GetDatabaseAccountCheck(endpoint, **kwargs)
193-
else:
194-
self.client._GetDatabaseAccountCheck(endpoint, **kwargs)
195-
success_count += 1
196-
self.location_cache.mark_endpoint_available(endpoint)
197-
except (exceptions.CosmosHttpResponseError, AzureError):
198-
self._mark_endpoint_unavailable(endpoint, "_endpoints_health_check")
199200

200-
finally:
201-
# after the health check for that endpoint setting the timeouts back to their original values
202-
self.client.connection_policy._override_dba_timeouts(previous_dba_read_timeout,
203-
previous_dba_connection_timeout)
201+
def _health_check(endpoint: str):
202+
try:
203+
self.client.health_check(endpoint, **kwargs)
204+
self.location_cache.mark_endpoint_available(endpoint)
205+
except (exceptions.CosmosHttpResponseError, AzureError):
206+
self._mark_endpoint_unavailable(endpoint, "_endpoints_health_check")
207+
208+
futures = [self.executor.submit(_health_check, ep) for ep in endpoints]
209+
for f in as_completed(futures):
210+
# propagate unexpected exceptions (should be none besides those swallowed in health check)
211+
_ = f.result()
212+
# After all probes, update cache once
204213
self.location_cache.update_location_cache()
205214

215+
def _refresh_database_account_and_health(self, **kwargs):
216+
database_account = self._GetDatabaseAccount(**kwargs)
217+
self.location_cache.perform_on_database_account_read(database_account)
218+
self._endpoints_health_check(**kwargs)
219+
206220
def _GetDatabaseAccountStub(self, endpoint, **kwargs):
207221
"""Stub for getting database account from the client.
208222
This can be used for mocking purposes as well.
@@ -211,21 +225,4 @@ def _GetDatabaseAccountStub(self, endpoint, **kwargs):
211225
:returns: A `DatabaseAccount` instance representing the Cosmos DB Database Account.
212226
:rtype: ~azure.cosmos.DatabaseAccount
213227
"""
214-
if endpoint in self.location_cache.location_unavailability_info_by_endpoint:
215-
previous_dba_read_timeout = self.client.connection_policy.DBAReadTimeout
216-
previous_dba_connection_timeout = self.client.connection_policy.DBAConnectionTimeout
217-
try:
218-
# if the endpoint is unavailable, we need to lower the timeouts to be more aggressive in the
219-
# health check. This helps reduce the time the health check is blocking all requests.
220-
self.client.connection_policy._override_dba_timeouts(constants._Constants
221-
.UnavailableEndpointDBATimeouts,
222-
constants._Constants
223-
.UnavailableEndpointDBATimeouts)
224-
database_account = self.client.GetDatabaseAccount(endpoint, **kwargs)
225-
finally:
226-
# after the health check for that endpoint setting the timeouts back to their original values
227-
self.client.connection_policy._override_dba_timeouts(previous_dba_read_timeout,
228-
previous_dba_connection_timeout)
229-
else:
230-
database_account = self.client.GetDatabaseAccount(endpoint, **kwargs)
231-
return database_account
228+
return self.client.GetDatabaseAccount(endpoint, **kwargs)

sdk/cosmos/azure-cosmos/azure/cosmos/_database_account_retry_policy.py renamed to sdk/cosmos/azure-cosmos/azure/cosmos/_health_check_retry_policy.py

Lines changed: 23 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -19,24 +19,17 @@
1919
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
2020
# SOFTWARE.
2121

22-
"""Internal class for database account retry policy implementation in the
22+
"""Internal class for health check retry policy implementation in the
2323
Azure Cosmos database service.
2424
"""
2525
import os
26-
from azure.core.exceptions import ServiceRequestError, ServiceResponseError
2726
from azure.cosmos import _constants
2827

2928

30-
class DatabaseAccountRetryPolicy(object):
31-
"""Implements retry logic for database account reads in Azure Cosmos DB."""
29+
class HealthCheckRetryPolicy(object):
30+
"""Implements retry logic for health checks in Azure Cosmos DB."""
3231

33-
# List of HTTP status codes considered transient errors for retry logic.
34-
transient_status_codes = [502, 503, 504]
35-
36-
# Tuple of exception types considered transient errors for retry logic.
37-
transient_exceptions = (ServiceRequestError, ServiceResponseError)
38-
39-
def __init__(self, connection_policy):
32+
def __init__(self, connection_policy, *args):
4033
self.retry_count = 0
4134
self.retry_after_in_milliseconds = int(os.getenv(
4235
_constants._Constants.AZURE_COSMOS_HEALTH_CHECK_RETRY_AFTER_MS,
@@ -47,8 +40,12 @@ def __init__(self, connection_policy):
4740
str(_constants._Constants.AZURE_COSMOS_HEALTH_CHECK_MAX_RETRIES_DEFAULT)
4841
))
4942
self.connection_policy = connection_policy
43+
self.retry_factor = 2
44+
self.max_retry_after_in_milliseconds = 1000 * 60 * 3 # 3 minutes
45+
self.initial_connection_timeout = 5
46+
self.request = args[0] if args else None
5047

51-
def ShouldRetry(self, exception):
48+
def ShouldRetry(self, exception):# pylint: disable=unused-argument
5249
"""
5350
Determines if the given exception is transient and if a retry should be attempted.
5451
@@ -57,19 +54,20 @@ def ShouldRetry(self, exception):
5754
:return: True if the exception is transient and retry attempts to remain, False otherwise.
5855
:rtype: bool
5956
"""
60-
61-
is_transient = False
62-
63-
# Check for transient HTTP status codes
64-
status_code = getattr(exception, "status_code", None)
65-
if status_code in self.transient_status_codes:
66-
is_transient = True
67-
68-
# Check for transient exception types
69-
if isinstance(exception, self.transient_exceptions):
70-
is_transient = True
71-
72-
if is_transient and self.retry_count < self.max_retry_attempt_count:
57+
if self.retry_count > 0:
58+
self.retry_after_in_milliseconds = min(self.retry_after_in_milliseconds +
59+
self.retry_factor ** self.retry_count,
60+
self.max_retry_after_in_milliseconds)
61+
if self.request:
62+
# increase read timeout for each retry
63+
if self.request.read_timeout_override:
64+
self.request.read_timeout_override = min(self.request.read_timeout_override ** 2,
65+
self.connection_policy.ReadTimeout)
66+
else:
67+
self.request.read_timeout_override = self.initial_connection_timeout
68+
69+
70+
if self.retry_count < self.max_retry_attempt_count:
7371
self.retry_count += 1
7472
return True
7573

sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def __init__(
4646
self.excluded_locations: Optional[list[str]] = None
4747
self.excluded_locations_circuit_breaker: list[str] = []
4848
self.healthy_tentative_location: Optional[str] = None
49+
self.read_timeout_override: Optional[int] = None
4950
self.pk_val = pk_val
5051
self.retry_write: int = 0
5152

0 commit comments

Comments
 (0)