Skip to content

Commit 1d3c41a

Browse files
dibahlfiCopilot
andauthored
DatabaseAccountRetryPolicy Improvements (#42525)
* DatabaseAccountRetry - initial commit * Update sdk/cosmos/azure-cosmos/tests/test_retry_policy_async.py Co-authored-by: Copilot <[email protected]> * DatabaseAccountRetry - adding comments * fixing pylink comments * updated CHANGELOG.md * fixing bug for write region unavailability * fix: comments * fix: comments * fix: comments * fix: change log update --------- Co-authored-by: Copilot <[email protected]>
1 parent a298356 commit 1d3c41a

9 files changed

+321
-30
lines changed

sdk/cosmos/azure-cosmos/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
#### Breaking Changes
99

1010
#### Bugs Fixed
11+
* Improved the resilience of Database Account Read metadata operation against short-lived network issues by increasing number of retries. See [PR 42525](https://github.com/Azure/azure-sdk-for-python/pull/42525).
12+
* Fixed bug where during health checks read regions were marked as unavailable for write operations. See [PR 42525](https://github.com/Azure/azure-sdk-for-python/pull/42525).
1113

1214
#### Other Changes
1315
* Added session token false progress merge logic. See [42393](https://github.com/Azure/azure-sdk-for-python/pull/42393)

sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,13 @@ class _Constants:
5656
CIRCUIT_BREAKER_ENABLED_CONFIG: str = "AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"
5757
CIRCUIT_BREAKER_ENABLED_CONFIG_DEFAULT: str = "False"
5858
AAD_SCOPE_OVERRIDE: str = "AZURE_COSMOS_AAD_SCOPE_OVERRIDE"
59+
60+
# Database Account Retry Policy constants
61+
AZURE_COSMOS_HEALTH_CHECK_MAX_RETRIES: str = "AZURE_COSMOS_HEALTH_CHECK_MAX_RETRIES"
62+
AZURE_COSMOS_HEALTH_CHECK_MAX_RETRIES_DEFAULT: int = 3
63+
AZURE_COSMOS_HEALTH_CHECK_RETRY_AFTER_MS: str = "AZURE_COSMOS_HEALTH_CHECK_RETRY_AFTER_MS"
64+
AZURE_COSMOS_HEALTH_CHECK_RETRY_AFTER_MS_DEFAULT: int = 100
65+
5966
# Only applicable when circuit breaker is enabled -------------------------
6067
CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_READ: str = "AZURE_COSMOS_CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_READ"
6168
CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_READ_DEFAULT: int = 10

sdk/cosmos/azure-cosmos/azure/cosmos/_database_account_retry_policy.py

Lines changed: 39 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,28 +22,55 @@
2222
"""Internal class for database account retry policy implementation in the
2323
Azure Cosmos database service.
2424
"""
25+
import os
26+
from azure.core.exceptions import ServiceRequestError, ServiceResponseError
27+
from azure.cosmos import _constants
28+
2529

2630
class DatabaseAccountRetryPolicy(object):
27-
"""The database account retry policy which should only retry once regardless of errors.
28-
"""
31+
"""Implements retry logic for database account reads in Azure Cosmos DB."""
32+
33+
# List of HTTP status codes considered transient errors for retry logic.
34+
transient_status_codes = [502, 503, 504]
35+
36+
# Tuple of exception types considered transient errors for retry logic.
37+
transient_exceptions = (ServiceRequestError, ServiceResponseError)
2938

3039
def __init__(self, connection_policy):
3140
self.retry_count = 0
32-
self.retry_after_in_milliseconds = 0
33-
self.max_retry_attempt_count = 1
41+
self.retry_after_in_milliseconds = int(os.getenv(
42+
_constants._Constants.AZURE_COSMOS_HEALTH_CHECK_RETRY_AFTER_MS,
43+
str(_constants._Constants.AZURE_COSMOS_HEALTH_CHECK_RETRY_AFTER_MS_DEFAULT)
44+
))
45+
self.max_retry_attempt_count = int(os.getenv(
46+
_constants._Constants.AZURE_COSMOS_HEALTH_CHECK_MAX_RETRIES,
47+
str(_constants._Constants.AZURE_COSMOS_HEALTH_CHECK_MAX_RETRIES_DEFAULT)
48+
))
3449
self.connection_policy = connection_policy
3550

36-
def ShouldRetry(self, exception): # pylint: disable=unused-argument
37-
"""Returns true if the request should retry based on the passed-in exception.
51+
def ShouldRetry(self, exception):
52+
"""
53+
Determines if the given exception is transient and if a retry should be attempted.
3854
39-
:param exceptions.CosmosHttpResponseError exception:
40-
:returns: a boolean stating whether the request should be retried
55+
:param exception: The exception instance to evaluate.
56+
:type exception: Exception
57+
:return: True if the exception is transient and retry attempts to remain, False otherwise.
4158
:rtype: bool
4259
"""
4360

44-
if self.retry_count >= self.max_retry_attempt_count:
45-
return False
61+
is_transient = False
62+
63+
# Check for transient HTTP status codes
64+
status_code = getattr(exception, "status_code", None)
65+
if status_code in self.transient_status_codes:
66+
is_transient = True
67+
68+
# Check for transient exception types
69+
if isinstance(exception, self.transient_exceptions):
70+
is_transient = True
4671

47-
self.retry_count += 1
72+
if is_transient and self.retry_count < self.max_retry_attempt_count:
73+
self.retry_count += 1
74+
return True
4875

49-
return True
76+
return False

sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,15 @@ def force_refresh_on_startup(self, database_account):
9696
def update_location_cache(self):
9797
self.location_cache.update_location_cache()
9898

99+
def _mark_endpoint_unavailable(self, endpoint: str):
100+
"""Marks an endpoint as unavailable for the appropriate operations.
101+
:param str endpoint: The endpoint to mark as unavailable.
102+
"""
103+
write_endpoints = self.location_cache.get_all_write_endpoints()
104+
self.mark_endpoint_unavailable_for_read(endpoint, False)
105+
if endpoint in write_endpoints:
106+
self.mark_endpoint_unavailable_for_write(endpoint, False)
107+
99108
def refresh_endpoint_list(self, database_account, **kwargs):
100109
if current_time_millis() - self.last_refresh_time > self.refresh_time_interval_in_ms:
101110
self.refresh_needed = True
@@ -145,8 +154,7 @@ def _GetDatabaseAccount(self, **kwargs) -> Tuple[DatabaseAccount, str]:
145154
except (exceptions.CosmosHttpResponseError, AzureError):
146155
# when atm is available, L: 145, 146 should be removed as the global endpoint shouldn't be used
147156
# for dataplane operations anymore
148-
self.mark_endpoint_unavailable_for_read(self.DefaultEndpoint, False)
149-
self.mark_endpoint_unavailable_for_write(self.DefaultEndpoint, False)
157+
self._mark_endpoint_unavailable(self.DefaultEndpoint)
150158
for location_name in self.PreferredLocations:
151159
locational_endpoint = LocationCache.GetLocationalEndpoint(self.DefaultEndpoint, location_name)
152160
try:
@@ -155,8 +163,7 @@ def _GetDatabaseAccount(self, **kwargs) -> Tuple[DatabaseAccount, str]:
155163
self.location_cache.mark_endpoint_available(locational_endpoint)
156164
return database_account, locational_endpoint
157165
except (exceptions.CosmosHttpResponseError, AzureError):
158-
self.mark_endpoint_unavailable_for_read(locational_endpoint, False)
159-
self.mark_endpoint_unavailable_for_write(locational_endpoint, False)
166+
self._mark_endpoint_unavailable(locational_endpoint)
160167
raise
161168

162169
def _endpoints_health_check(self, **kwargs):
@@ -191,8 +198,8 @@ def _endpoints_health_check(self, **kwargs):
191198
success_count += 1
192199
self.location_cache.mark_endpoint_available(endpoint)
193200
except (exceptions.CosmosHttpResponseError, AzureError):
194-
self.mark_endpoint_unavailable_for_read(endpoint, False)
195-
self.mark_endpoint_unavailable_for_write(endpoint, False)
201+
self._mark_endpoint_unavailable(endpoint)
202+
196203
finally:
197204
# after the health check for that endpoint setting the timeouts back to their original values
198205
self.client.connection_policy.override_dba_timeouts(previous_dba_read_timeout,

sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,13 @@ def perform_on_database_account_read(self, database_account):
203203
database_account._EnableMultipleWritableLocations,
204204
)
205205

206+
def get_all_write_endpoints(self) -> Set[str]:
207+
return {
208+
endpoint
209+
for context in self.get_write_regional_routing_contexts()
210+
for endpoint in (context.get_primary(), context.get_alternate())
211+
}
212+
206213
def get_ordered_write_locations(self):
207214
return self.account_write_locations
208215

sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,15 @@ async def force_refresh_on_startup(self, database_account):
100100
def update_location_cache(self):
101101
self.location_cache.update_location_cache()
102102

103+
def _mark_endpoint_unavailable(self, endpoint: str):
104+
"""Marks an endpoint as unavailable for the appropriate operations.
105+
:param str endpoint: The endpoint to mark as unavailable.
106+
"""
107+
write_endpoints = self.location_cache.get_all_write_endpoints()
108+
self.mark_endpoint_unavailable_for_read(endpoint, False)
109+
if endpoint in write_endpoints:
110+
self.mark_endpoint_unavailable_for_write(endpoint, False)
111+
103112
async def refresh_endpoint_list(self, database_account, **kwargs):
104113
if self.refresh_task and self.refresh_task.done():
105114
try:
@@ -142,8 +151,7 @@ async def _database_account_check(self, endpoint: str, **kwargs: Dict[str, Any])
142151
await self.client._GetDatabaseAccountCheck(endpoint, **kwargs)
143152
self.location_cache.mark_endpoint_available(endpoint)
144153
except (exceptions.CosmosHttpResponseError, AzureError):
145-
self.mark_endpoint_unavailable_for_read(endpoint, False)
146-
self.mark_endpoint_unavailable_for_write(endpoint, False)
154+
self._mark_endpoint_unavailable(endpoint)
147155

148156
async def _endpoints_health_check(self, **kwargs):
149157
"""Gets the database account for each endpoint.
@@ -185,8 +193,7 @@ async def _GetDatabaseAccount(self, **kwargs) -> Tuple[DatabaseAccount, str]:
185193
# until we get the database account and return None at the end, if we are not able
186194
# to get that info from any endpoints
187195
except (exceptions.CosmosHttpResponseError, AzureError):
188-
self.mark_endpoint_unavailable_for_read(self.DefaultEndpoint, False)
189-
self.mark_endpoint_unavailable_for_write(self.DefaultEndpoint, False)
196+
self._mark_endpoint_unavailable(self.DefaultEndpoint)
190197
for location_name in self.PreferredLocations:
191198
locational_endpoint = LocationCache.GetLocationalEndpoint(self.DefaultEndpoint, location_name)
192199
try:
@@ -195,8 +202,7 @@ async def _GetDatabaseAccount(self, **kwargs) -> Tuple[DatabaseAccount, str]:
195202
self.location_cache.mark_endpoint_available(locational_endpoint)
196203
return database_account, locational_endpoint
197204
except (exceptions.CosmosHttpResponseError, AzureError):
198-
self.mark_endpoint_unavailable_for_read(locational_endpoint, False)
199-
self.mark_endpoint_unavailable_for_write(locational_endpoint, False)
205+
self._mark_endpoint_unavailable(locational_endpoint)
200206
raise
201207

202208
async def _GetDatabaseAccountStub(self, endpoint, **kwargs):

sdk/cosmos/azure-cosmos/tests/test_health_check_async.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -186,17 +186,34 @@ async def test_health_check_failure_async(self, setup, preferred_location, use_w
186186
try:
187187
setup[COLLECTION].client_connection._global_endpoint_manager.startup = False
188188
setup[COLLECTION].client_connection._global_endpoint_manager.refresh_needed = True
189-
for i in range(2):
190-
await setup[COLLECTION].create_item(body={'id': 'item' + str(uuid.uuid4()), 'pk': 'pk'})
191-
# wait for background task to finish
192-
await asyncio.sleep(2)
189+
# Trigger the background health check
190+
await setup[COLLECTION].create_item(body={'id': 'item' + str(uuid.uuid4()), 'pk': 'pk'})
191+
192+
# Poll until the background task marks the endpoints as unavailable
193+
start_time = time.time()
194+
while (len(setup[COLLECTION].client_connection._global_endpoint_manager.location_cache.location_unavailability_info_by_endpoint) < len(
195+
REGIONS) and time.time() - start_time < 10):
196+
await asyncio.sleep(0.1)
193197
finally:
194198
_global_endpoint_manager_async._GlobalEndpointManager._GetDatabaseAccountStub = self.original_getDatabaseAccountStub
195199
setup[COLLECTION].client_connection.connection_policy.PreferredLocations = self.original_preferred_locations
196200

197201
num_unavailable_endpoints = len(REGIONS)
198202
unavailable_endpoint_info = setup[COLLECTION].client_connection._global_endpoint_manager.location_cache.location_unavailability_info_by_endpoint
199203
assert len(unavailable_endpoint_info) == num_unavailable_endpoints
204+
# Allow both global and regional endpoint to be considered write endpoints when global write is enabled
205+
write_endpoints = {
206+
_location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_1)
207+
}
208+
if use_write_global_endpoint:
209+
write_endpoints.add(self.host)
210+
211+
for endpoint, info in unavailable_endpoint_info.items():
212+
assert _location_cache.EndpointOperationType.ReadType in info["operationType"]
213+
if endpoint in write_endpoints:
214+
assert _location_cache.EndpointOperationType.WriteType in info["operationType"]
215+
else:
216+
assert _location_cache.EndpointOperationType.WriteType not in info["operationType"]
200217

201218
async def mock_health_check(self, **kwargs):
202219
await asyncio.sleep(100)

sdk/cosmos/azure-cosmos/tests/test_retry_policy.py

Lines changed: 105 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,12 @@
1111
import azure.cosmos.exceptions as exceptions
1212
import test_config
1313
from azure.cosmos import _retry_utility, PartitionKey, documents
14-
from azure.cosmos.http_constants import HttpHeaders, StatusCodes
14+
from azure.cosmos.http_constants import HttpHeaders, StatusCodes, ResourceType
1515
from _fault_injection_transport import FaultInjectionTransport
16+
import os
17+
from azure.core.exceptions import ServiceResponseError
18+
from azure.cosmos._database_account_retry_policy import DatabaseAccountRetryPolicy
19+
from azure.cosmos._constants import _Constants
1620

1721

1822
def setup_method_with_custom_transport(
@@ -483,6 +487,90 @@ def test_patch_replace_no_retry(self):
483487
container.replace_item(item=doc['id'], body=doc)
484488
assert connection_retry_policy.counter == 0
485489

490+
491+
def test_database_account_read_retry_policy(self):
492+
os.environ['AZURE_COSMOS_HEALTH_CHECK_MAX_RETRIES'] = '5'
493+
os.environ['AZURE_COSMOS_HEALTH_CHECK_RETRY_AFTER_MS'] = '100'
494+
max_retries = int(os.environ['AZURE_COSMOS_HEALTH_CHECK_MAX_RETRIES'])
495+
retry_after_ms = int(os.environ['AZURE_COSMOS_HEALTH_CHECK_RETRY_AFTER_MS'])
496+
self.original_execute_function = _retry_utility.ExecuteFunction
497+
mock_execute = self.MockExecuteFunctionDBA(self.original_execute_function)
498+
_retry_utility.ExecuteFunction = mock_execute
499+
500+
try:
501+
with self.assertRaises(exceptions.CosmosHttpResponseError) as context:
502+
cosmos_client.CosmosClient(self.host, self.masterKey)
503+
# Client initialization triggers database account read
504+
505+
self.assertEqual(context.exception.status_code, 503)
506+
self.assertEqual(mock_execute.counter, max_retries + 1)
507+
policy = DatabaseAccountRetryPolicy(self.connectionPolicy)
508+
self.assertEqual(policy.retry_after_in_milliseconds, retry_after_ms)
509+
finally:
510+
del os.environ["AZURE_COSMOS_HEALTH_CHECK_MAX_RETRIES"]
511+
del os.environ["AZURE_COSMOS_HEALTH_CHECK_RETRY_AFTER_MS"]
512+
_retry_utility.ExecuteFunction = self.original_execute_function
513+
514+
def test_database_account_read_retry_policy_defaults(self):
515+
self.original_execute_function = _retry_utility.ExecuteFunction
516+
mock_execute = self.MockExecuteFunctionDBA(self.original_execute_function)
517+
_retry_utility.ExecuteFunction = mock_execute
518+
519+
try:
520+
with self.assertRaises(exceptions.CosmosHttpResponseError) as context:
521+
cosmos_client.CosmosClient(self.host, self.masterKey)
522+
# Triggers database account read
523+
524+
self.assertEqual(context.exception.status_code, 503)
525+
self.assertEqual(
526+
mock_execute.counter,
527+
_Constants.AZURE_COSMOS_HEALTH_CHECK_MAX_RETRIES_DEFAULT + 1
528+
)
529+
policy = DatabaseAccountRetryPolicy(self.connectionPolicy)
530+
self.assertEqual(
531+
policy.retry_after_in_milliseconds,
532+
_Constants.AZURE_COSMOS_HEALTH_CHECK_RETRY_AFTER_MS_DEFAULT
533+
)
534+
finally:
535+
_retry_utility.ExecuteFunction = self.original_execute_function
536+
537+
def test_database_account_read_retry_with_service_response_error(self):
538+
self.original_execute_function = _retry_utility.ExecuteFunction
539+
mock_execute = self.MockExecuteFunctionDBAServiceRequestError(self.original_execute_function)
540+
_retry_utility.ExecuteFunction = mock_execute
541+
542+
try:
543+
with self.assertRaises(ServiceResponseError):
544+
cosmos_client.CosmosClient(self.host, self.masterKey)
545+
# Client initialization triggers database account read
546+
547+
# Should use default retry attempts from _constants.py
548+
self.assertEqual(
549+
mock_execute.counter,
550+
_Constants.AZURE_COSMOS_HEALTH_CHECK_MAX_RETRIES_DEFAULT + 1
551+
)
552+
policy = DatabaseAccountRetryPolicy(self.connectionPolicy)
553+
self.assertEqual(
554+
policy.retry_after_in_milliseconds,
555+
_Constants.AZURE_COSMOS_HEALTH_CHECK_RETRY_AFTER_MS_DEFAULT
556+
)
557+
finally:
558+
_retry_utility.ExecuteFunction = self.original_execute_function
559+
560+
class MockExecuteFunctionDBAServiceRequestError(object):
561+
def __init__(self, org_func):
562+
self.org_func = org_func
563+
self.counter = 0
564+
565+
def __call__(self, func, *args, **kwargs):
566+
# The second argument to the internal _request function is the RequestObject.
567+
request_object = args[1]
568+
if (request_object.operation_type == documents._OperationType.Read and
569+
request_object.resource_type == ResourceType.DatabaseAccount):
570+
self.counter += 1
571+
raise ServiceResponseError("mocked service response error")
572+
return self.org_func(func, *args, **kwargs)
573+
486574
def _MockExecuteFunction(self, function, *args, **kwargs):
487575
response = test_config.FakeResponse({HttpHeaders.RetryAfterInMilliseconds: self.retry_after_in_milliseconds})
488576
raise exceptions.CosmosHttpResponseError(
@@ -516,5 +604,21 @@ def __call__(self, func, *args, **kwargs):
516604
message="Connection was reset",
517605
response=test_config.FakeResponse({}))
518606

607+
class MockExecuteFunctionDBA(object):
608+
def __init__(self, org_func):
609+
self.org_func = org_func
610+
self.counter = 0
611+
612+
def __call__(self, func, *args, **kwargs):
613+
request_object = args[1]
614+
if (request_object.operation_type == documents._OperationType.Read and
615+
request_object.resource_type == ResourceType.DatabaseAccount):
616+
self.counter += 1
617+
raise exceptions.CosmosHttpResponseError(
618+
status_code=503,
619+
message="Service Unavailable.")
620+
return self.org_func(func, *args, **kwargs)
621+
622+
519623
if __name__ == '__main__':
520624
unittest.main()

0 commit comments

Comments
 (0)