@@ -56,8 +56,6 @@ class TestHealthCheckAsync:
5656 connectionPolicy = test_config .TestConfig .connectionPolicy
5757 TEST_DATABASE_ID = test_config .TestConfig .TEST_DATABASE_ID
5858 TEST_CONTAINER_SINGLE_PARTITION_ID = test_config .TestConfig .TEST_SINGLE_PARTITION_CONTAINER_ID
59- # health check in all these tests should check the endpoints for the first two write regions and the first two read regions
60- # without checking the same endpoint twice
6159
6260 @pytest .mark .parametrize ("preferred_location, use_write_global_endpoint, use_read_global_endpoint" , health_check ())
6361 async def test_health_check_success_startup_async (self , setup , preferred_location , use_write_global_endpoint , use_read_global_endpoint ):
@@ -79,10 +77,7 @@ async def test_health_check_success_startup_async(self, setup, preferred_locatio
7977 expected_regional_routing_context = []
8078
8179 locational_endpoint = _location_cache .LocationCache .GetLocationalEndpoint (self .host , REGION_1 )
82- if use_read_global_endpoint :
83- assert mock_get_database_account_check .counter == 1
84- else :
85- assert mock_get_database_account_check .counter == 2
80+ assert mock_get_database_account_check .counter == 2
8681 endpoint = self .host if use_read_global_endpoint else locational_endpoint
8782 expected_regional_routing_context .append (RegionalRoutingContext (endpoint , endpoint ))
8883 locational_endpoint = _location_cache .LocationCache .GetLocationalEndpoint (self .host , REGION_2 )
@@ -107,12 +102,8 @@ async def test_health_check_failure_startup_async(self, setup, preferred_locatio
107102 _global_endpoint_manager_async ._GlobalEndpointManager ._GetDatabaseAccountStub = self .original_getDatabaseAccountStub
108103 expected_endpoints = []
109104
110- if not use_read_global_endpoint :
111- for region in REGIONS :
112- locational_endpoint = _location_cache .LocationCache .GetLocationalEndpoint (self .host , region )
113- expected_endpoints .append (locational_endpoint )
114- else :
115- locational_endpoint = _location_cache .LocationCache .GetLocationalEndpoint (self .host , REGION_2 )
105+ for region in REGIONS :
106+ locational_endpoint = _location_cache .LocationCache .GetLocationalEndpoint (self .host , region )
116107 expected_endpoints .append (locational_endpoint )
117108
118109 unavailable_endpoint_info = client .client_connection ._global_endpoint_manager .location_cache .location_unavailability_info_by_endpoint
@@ -149,7 +140,7 @@ async def test_health_check_background_fail(self, setup):
149140 _global_endpoint_manager_async ._GlobalEndpointManager ._endpoints_health_check = self .original_health_check
150141
151142 @pytest .mark .parametrize ("preferred_location, use_write_global_endpoint, use_read_global_endpoint" , health_check ())
152- async def test_health_check_success (self , setup , preferred_location , use_write_global_endpoint , use_read_global_endpoint ):
143+ async def test_health_check_success_async (self , setup , preferred_location , use_write_global_endpoint , use_read_global_endpoint ):
153144 # checks the background health check works as expected when all endpoints healthy
154145 self .original_getDatabaseAccountStub = _global_endpoint_manager_async ._GlobalEndpointManager ._GetDatabaseAccountStub
155146 self .original_getDatabaseAccountCheck = _cosmos_client_connection_async .CosmosClientConnection ._GetDatabaseAccountCheck
@@ -183,7 +174,7 @@ async def test_health_check_success(self, setup, preferred_location, use_write_g
183174
184175
185176 @pytest .mark .parametrize ("preferred_location, use_write_global_endpoint, use_read_global_endpoint" , health_check ())
186- async def test_health_check_failure (self , setup , preferred_location , use_write_global_endpoint , use_read_global_endpoint ):
177+ async def test_health_check_failure_async (self , setup , preferred_location , use_write_global_endpoint , use_read_global_endpoint ):
187178 # checks the background health check works as expected when all endpoints unhealthy - it should mark the endpoints unavailable
188179 setup [COLLECTION ].client_connection ._global_endpoint_manager .location_cache .location_unavailability_info_by_endpoint .clear ()
189180 self .original_getDatabaseAccountStub = _global_endpoint_manager_async ._GlobalEndpointManager ._GetDatabaseAccountStub
@@ -198,16 +189,12 @@ async def test_health_check_failure(self, setup, preferred_location, use_write_g
198189 for i in range (2 ):
199190 await setup [COLLECTION ].create_item (body = {'id' : 'item' + str (uuid .uuid4 ()), 'pk' : 'pk' })
200191 # wait for background task to finish
201- await asyncio .sleep (1 )
192+ await asyncio .sleep (2 )
202193 finally :
203194 _global_endpoint_manager_async ._GlobalEndpointManager ._GetDatabaseAccountStub = self .original_getDatabaseAccountStub
204195 setup [COLLECTION ].client_connection .connection_policy .PreferredLocations = self .original_preferred_locations
205196
206- if not use_write_global_endpoint :
207- num_unavailable_endpoints = len (REGIONS )
208- else :
209- num_unavailable_endpoints = 1
210-
197+ num_unavailable_endpoints = len (REGIONS )
211198 unavailable_endpoint_info = setup [COLLECTION ].client_connection ._global_endpoint_manager .location_cache .location_unavailability_info_by_endpoint
212199 assert len (unavailable_endpoint_info ) == num_unavailable_endpoints
213200
0 commit comments