2222"""Internal class for global endpoint manager implementation in the Azure Cosmos
2323database service.
2424"""
25-
25+ import logging
26+ import os
2627import threading
27- from typing import Tuple
28+ from concurrent .futures import ThreadPoolExecutor , as_completed
29+ from typing import Callable , Any
2830
2931from azure .core .exceptions import AzureError
3032
3739
3840
3941# pylint: disable=protected-access
40-
42+ logger = logging . getLogger ( "azure.cosmos._GlobalEndpointManager" )
4143
4244class _GlobalEndpointManager (object ): # pylint: disable=too-many-instance-attributes
4345 """
@@ -58,6 +60,9 @@ def __init__(self, client):
5860 self .refresh_lock = threading .RLock ()
5961 self .last_refresh_time = 0
6062 self ._database_account_cache = None
63+ self .startup = True
64+ self ._refresh_thread = None
65+ self .executor = ThreadPoolExecutor (max_workers = os .cpu_count ())
6166
6267 def get_refresh_time_interval_in_ms_stub (self ):
6368 return constants ._Constants .DefaultEndpointsRefreshTime
@@ -120,31 +125,58 @@ def refresh_endpoint_list(self, database_account, **kwargs):
120125 raise e
121126
122127 def _refresh_endpoint_list_private (self , database_account = None , ** kwargs ):
123- if database_account :
128+ # 1. If explicit database_account provided and not during startup, just update cache (no health check now)
129+ # 2. Else if refresh criteria met:
130+ # a. If not startup -> spawn background thread to do full database account + health checks
131+ # b. If startup -> apply the passed-in database account synchronously, then spawn background health checks,
132+ # then mark startup False
133+ if database_account and not self .startup :
124134 self .location_cache .perform_on_database_account_read (database_account )
125135 self .refresh_needed = False
126136 self .last_refresh_time = current_time_millis ()
127137 else :
128138 if self .location_cache .should_refresh_endpoints () or self .refresh_needed :
129139 self .refresh_needed = False
130140 self .last_refresh_time = current_time_millis ()
131- # this will perform getDatabaseAccount calls to check endpoint health
132- self ._endpoints_health_check (** kwargs )
133-
134- def _GetDatabaseAccount (self , ** kwargs ) -> Tuple [DatabaseAccount , str ]:
141+ if not self .startup :
142+ # background full refresh (database account + health checks)
143+ self ._start_background_refresh (self ._refresh_database_account_and_health , kwargs )
144+ else :
145+ self .location_cache .perform_on_database_account_read (database_account )
146+ self ._start_background_refresh (self ._endpoints_health_check , kwargs )
147+ self .startup = False
148+
149+ def _start_background_refresh (self , target : Callable [..., None ], kwargs : dict [str , Any ]):
150+ """Starts a daemon thread to run the given target if one is not already active.
151+ :param Callable target: The function to run in the background thread.
152+ :param dict kwargs: The keyword arguments to pass to the target function.
153+ """
154+ if not (self ._refresh_thread and self ._refresh_thread .is_alive ()):
155+ def runner ():
156+ try :
157+ target (** kwargs )
158+ except Exception as exception : #pylint: disable=broad-exception-caught
159+ # background failures should not crash main thread
160+ # Intentionally swallowed so the foreground is unaffected; the failure is logged below.
161+ logger .error ("Health check task failed: %s" , exception , exc_info = True )
162+ t = threading .Thread (target = runner , name = "cosmos-endpoint-refresh" , daemon = True )
163+ self ._refresh_thread = t
164+ t .start ()
165+
166+ def _GetDatabaseAccount (self , ** kwargs ) -> DatabaseAccount :
135167 """Gets the database account.
136168
137169 First tries by using the default endpoint, and if that doesn't work,
138170 use the endpoints for the preferred locations in the order they are
139171 specified, to get the database account.
140172 :returns: A `DatabaseAccount` instance representing the Cosmos DB Database Account,
141173 read from the default endpoint or, failing that, from a preferred-location endpoint.
142- :rtype: tuple of ( ~azure.cosmos.DatabaseAccount, str)
174+ :rtype: ~azure.cosmos.DatabaseAccount
143175 """
144176 try :
145177 database_account = self ._GetDatabaseAccountStub (self .DefaultEndpoint , ** kwargs )
146178 self ._database_account_cache = database_account
147- return database_account , self . DefaultEndpoint
179+ return database_account
148180 # If for any reason(non-globaldb related), we are not able to get the database
149181 # account from the above call to GetDatabaseAccount, we would try to get this
150182 # information from any of the preferred locations that the user might have
@@ -157,52 +189,34 @@ def _GetDatabaseAccount(self, **kwargs) -> Tuple[DatabaseAccount, str]:
157189 try :
158190 database_account = self ._GetDatabaseAccountStub (locational_endpoint , ** kwargs )
159191 self ._database_account_cache = database_account
160- self .location_cache .mark_endpoint_available (locational_endpoint )
161- return database_account , locational_endpoint
192+ return database_account
162193 except (exceptions .CosmosHttpResponseError , AzureError ):
163194 self ._mark_endpoint_unavailable (locational_endpoint , "_GetDatabaseAccount" )
164195 raise
165196
166197 def _endpoints_health_check (self , ** kwargs ):
167- """Gets the database account for each endpoint.
168-
169- Validating if the endpoint is healthy else marking it as unavailable.
170- """
171- database_account , attempted_endpoint = self ._GetDatabaseAccount (** kwargs )
172- self .location_cache .perform_on_database_account_read (database_account )
173- # get all the regional routing contexts to check
198+ """Performs concurrent health checks for each endpoint (background-safe)."""
174199 endpoints = self .location_cache .endpoints_to_health_check ()
175- success_count = 0
176- for endpoint in endpoints :
177- if endpoint != attempted_endpoint :
178- if success_count >= 4 :
179- break
180- # save current dba timeouts
181- previous_dba_read_timeout = self .client .connection_policy .DBAReadTimeout
182- previous_dba_connection_timeout = self .client .connection_policy .DBAConnectionTimeout
183- try :
184- if (endpoint in
185- self .location_cache .location_unavailability_info_by_endpoint ):
186- # if the endpoint is unavailable, we need to lower the timeouts to be more aggressive in the
187- # health check. This helps reduce the time the health check is blocking all requests.
188- self .client .connection_policy ._override_dba_timeouts (constants ._Constants
189- .UnavailableEndpointDBATimeouts ,
190- constants ._Constants
191- .UnavailableEndpointDBATimeouts )
192- self .client ._GetDatabaseAccountCheck (endpoint , ** kwargs )
193- else :
194- self .client ._GetDatabaseAccountCheck (endpoint , ** kwargs )
195- success_count += 1
196- self .location_cache .mark_endpoint_available (endpoint )
197- except (exceptions .CosmosHttpResponseError , AzureError ):
198- self ._mark_endpoint_unavailable (endpoint , "_endpoints_health_check" )
199200
200- finally :
201- # after the health check for that endpoint setting the timeouts back to their original values
202- self .client .connection_policy ._override_dba_timeouts (previous_dba_read_timeout ,
203- previous_dba_connection_timeout )
201+ def _health_check (endpoint : str ):
202+ try :
203+ self .client .health_check (endpoint , ** kwargs )
204+ self .location_cache .mark_endpoint_available (endpoint )
205+ except (exceptions .CosmosHttpResponseError , AzureError ):
206+ self ._mark_endpoint_unavailable (endpoint , "_endpoints_health_check" )
207+
208+ futures = [self .executor .submit (_health_check , ep ) for ep in endpoints ]
209+ for f in as_completed (futures ):
210+ # propagate unexpected exceptions (should be none besides those swallowed in health check)
211+ _ = f .result ()
212+ # After all probes, update cache once
204213 self .location_cache .update_location_cache ()
205214
215+ def _refresh_database_account_and_health (self , ** kwargs ):
216+ database_account = self ._GetDatabaseAccount (** kwargs )
217+ self .location_cache .perform_on_database_account_read (database_account )
218+ self ._endpoints_health_check (** kwargs )
219+
206220 def _GetDatabaseAccountStub (self , endpoint , ** kwargs ):
207221 """Stub for getting database account from the client.
208222 This can be used for mocking purposes as well.
@@ -211,21 +225,4 @@ def _GetDatabaseAccountStub(self, endpoint, **kwargs):
211225 :returns: A `DatabaseAccount` instance representing the Cosmos DB Database Account.
212226 :rtype: ~azure.cosmos.DatabaseAccount
213227 """
214- if endpoint in self .location_cache .location_unavailability_info_by_endpoint :
215- previous_dba_read_timeout = self .client .connection_policy .DBAReadTimeout
216- previous_dba_connection_timeout = self .client .connection_policy .DBAConnectionTimeout
217- try :
218- # if the endpoint is unavailable, we need to lower the timeouts to be more aggressive in the
219- # health check. This helps reduce the time the health check is blocking all requests.
220- self .client .connection_policy ._override_dba_timeouts (constants ._Constants
221- .UnavailableEndpointDBATimeouts ,
222- constants ._Constants
223- .UnavailableEndpointDBATimeouts )
224- database_account = self .client .GetDatabaseAccount (endpoint , ** kwargs )
225- finally :
226- # after the health check for that endpoint setting the timeouts back to their original values
227- self .client .connection_policy ._override_dba_timeouts (previous_dba_read_timeout ,
228- previous_dba_connection_timeout )
229- else :
230- database_account = self .client .GetDatabaseAccount (endpoint , ** kwargs )
231- return database_account
228+ return self .client .GetDatabaseAccount (endpoint , ** kwargs )
0 commit comments