|
11 | 11 | import azure.cosmos.exceptions as exceptions
|
12 | 12 | import test_config
|
13 | 13 | from azure.cosmos import _retry_utility, PartitionKey, documents
|
14 |
| -from azure.cosmos.http_constants import HttpHeaders, StatusCodes |
| 14 | +from azure.cosmos.http_constants import HttpHeaders, StatusCodes, ResourceType |
15 | 15 | from _fault_injection_transport import FaultInjectionTransport
|
| 16 | +import os |
| 17 | +from azure.core.exceptions import ServiceResponseError |
| 18 | +from azure.cosmos._database_account_retry_policy import DatabaseAccountRetryPolicy |
| 19 | +from azure.cosmos._constants import _Constants |
16 | 20 |
|
17 | 21 |
|
18 | 22 | def setup_method_with_custom_transport(
|
@@ -483,6 +487,90 @@ def test_patch_replace_no_retry(self):
|
483 | 487 | container.replace_item(item=doc['id'], body=doc)
|
484 | 488 | assert connection_retry_policy.counter == 0
|
485 | 489 |
|
| 490 | + |
| 491 | + def test_database_account_read_retry_policy(self): |
| 492 | + os.environ['AZURE_COSMOS_HEALTH_CHECK_MAX_RETRIES'] = '5' |
| 493 | + os.environ['AZURE_COSMOS_HEALTH_CHECK_RETRY_AFTER_MS'] = '100' |
| 494 | + max_retries = int(os.environ['AZURE_COSMOS_HEALTH_CHECK_MAX_RETRIES']) |
| 495 | + retry_after_ms = int(os.environ['AZURE_COSMOS_HEALTH_CHECK_RETRY_AFTER_MS']) |
| 496 | + self.original_execute_function = _retry_utility.ExecuteFunction |
| 497 | + mock_execute = self.MockExecuteFunctionDBA(self.original_execute_function) |
| 498 | + _retry_utility.ExecuteFunction = mock_execute |
| 499 | + |
| 500 | + try: |
| 501 | + with self.assertRaises(exceptions.CosmosHttpResponseError) as context: |
| 502 | + cosmos_client.CosmosClient(self.host, self.masterKey) |
| 503 | + # Client initialization triggers database account read |
| 504 | + |
| 505 | + self.assertEqual(context.exception.status_code, 503) |
| 506 | + self.assertEqual(mock_execute.counter, max_retries + 1) |
| 507 | + policy = DatabaseAccountRetryPolicy(self.connectionPolicy) |
| 508 | + self.assertEqual(policy.retry_after_in_milliseconds, retry_after_ms) |
| 509 | + finally: |
| 510 | + del os.environ["AZURE_COSMOS_HEALTH_CHECK_MAX_RETRIES"] |
| 511 | + del os.environ["AZURE_COSMOS_HEALTH_CHECK_RETRY_AFTER_MS"] |
| 512 | + _retry_utility.ExecuteFunction = self.original_execute_function |
| 513 | + |
| 514 | + def test_database_account_read_retry_policy_defaults(self): |
| 515 | + self.original_execute_function = _retry_utility.ExecuteFunction |
| 516 | + mock_execute = self.MockExecuteFunctionDBA(self.original_execute_function) |
| 517 | + _retry_utility.ExecuteFunction = mock_execute |
| 518 | + |
| 519 | + try: |
| 520 | + with self.assertRaises(exceptions.CosmosHttpResponseError) as context: |
| 521 | + cosmos_client.CosmosClient(self.host, self.masterKey) |
| 522 | + # Triggers database account read |
| 523 | + |
| 524 | + self.assertEqual(context.exception.status_code, 503) |
| 525 | + self.assertEqual( |
| 526 | + mock_execute.counter, |
| 527 | + _Constants.AZURE_COSMOS_HEALTH_CHECK_MAX_RETRIES_DEFAULT + 1 |
| 528 | + ) |
| 529 | + policy = DatabaseAccountRetryPolicy(self.connectionPolicy) |
| 530 | + self.assertEqual( |
| 531 | + policy.retry_after_in_milliseconds, |
| 532 | + _Constants.AZURE_COSMOS_HEALTH_CHECK_RETRY_AFTER_MS_DEFAULT |
| 533 | + ) |
| 534 | + finally: |
| 535 | + _retry_utility.ExecuteFunction = self.original_execute_function |
| 536 | + |
| 537 | + def test_database_account_read_retry_with_service_response_error(self): |
| 538 | + self.original_execute_function = _retry_utility.ExecuteFunction |
| 539 | + mock_execute = self.MockExecuteFunctionDBAServiceRequestError(self.original_execute_function) |
| 540 | + _retry_utility.ExecuteFunction = mock_execute |
| 541 | + |
| 542 | + try: |
| 543 | + with self.assertRaises(ServiceResponseError): |
| 544 | + cosmos_client.CosmosClient(self.host, self.masterKey) |
| 545 | + # Client initialization triggers database account read |
| 546 | + |
| 547 | + # Should use default retry attempts from _constants.py |
| 548 | + self.assertEqual( |
| 549 | + mock_execute.counter, |
| 550 | + _Constants.AZURE_COSMOS_HEALTH_CHECK_MAX_RETRIES_DEFAULT + 1 |
| 551 | + ) |
| 552 | + policy = DatabaseAccountRetryPolicy(self.connectionPolicy) |
| 553 | + self.assertEqual( |
| 554 | + policy.retry_after_in_milliseconds, |
| 555 | + _Constants.AZURE_COSMOS_HEALTH_CHECK_RETRY_AFTER_MS_DEFAULT |
| 556 | + ) |
| 557 | + finally: |
| 558 | + _retry_utility.ExecuteFunction = self.original_execute_function |
| 559 | + |
| 560 | + class MockExecuteFunctionDBAServiceRequestError(object): |
| 561 | + def __init__(self, org_func): |
| 562 | + self.org_func = org_func |
| 563 | + self.counter = 0 |
| 564 | + |
| 565 | + def __call__(self, func, *args, **kwargs): |
| 566 | + # The second argument to the internal _request function is the RequestObject. |
| 567 | + request_object = args[1] |
| 568 | + if (request_object.operation_type == documents._OperationType.Read and |
| 569 | + request_object.resource_type == ResourceType.DatabaseAccount): |
| 570 | + self.counter += 1 |
| 571 | + raise ServiceResponseError("mocked service response error") |
| 572 | + return self.org_func(func, *args, **kwargs) |
| 573 | + |
486 | 574 | def _MockExecuteFunction(self, function, *args, **kwargs):
|
487 | 575 | response = test_config.FakeResponse({HttpHeaders.RetryAfterInMilliseconds: self.retry_after_in_milliseconds})
|
488 | 576 | raise exceptions.CosmosHttpResponseError(
|
@@ -516,5 +604,21 @@ def __call__(self, func, *args, **kwargs):
|
516 | 604 | message="Connection was reset",
|
517 | 605 | response=test_config.FakeResponse({}))
|
518 | 606 |
|
| 607 | + class MockExecuteFunctionDBA(object): |
| 608 | + def __init__(self, org_func): |
| 609 | + self.org_func = org_func |
| 610 | + self.counter = 0 |
| 611 | + |
| 612 | + def __call__(self, func, *args, **kwargs): |
| 613 | + request_object = args[1] |
| 614 | + if (request_object.operation_type == documents._OperationType.Read and |
| 615 | + request_object.resource_type == ResourceType.DatabaseAccount): |
| 616 | + self.counter += 1 |
| 617 | + raise exceptions.CosmosHttpResponseError( |
| 618 | + status_code=503, |
| 619 | + message="Service Unavailable.") |
| 620 | + return self.org_func(func, *args, **kwargs) |
| 621 | + |
| 622 | + |
519 | 623 | if __name__ == '__main__':
|
520 | 624 | unittest.main()
|
0 commit comments