2525import collections
2626import logging
2727import time
28- from typing import Set
28+ from typing import Set , Mapping , List
2929from urllib .parse import urlparse
3030
3131from . import documents
3232from . import http_constants
3333from .documents import _OperationType
34+ from ._request_object import RequestObject
3435
3536# pylint: disable=protected-access
3637
@@ -113,7 +114,10 @@ def get_endpoints_by_location(new_locations,
113114 except Exception as e :
114115 raise e
115116
116- return endpoints_by_location , parsed_locations
117+ # Also store a hash map of endpoints for each location
118+ locations_by_endpoints = {value .get_primary (): key for key , value in endpoints_by_location .items ()}
119+
120+ return endpoints_by_location , locations_by_endpoints , parsed_locations
117121
118122def add_endpoint_if_preferred (endpoint : str , preferred_endpoints : Set [str ], endpoints : Set [str ]) -> bool :
119123 if endpoint in preferred_endpoints :
@@ -150,31 +154,44 @@ def _get_health_check_endpoints(
150154
151155 return endpoints
152156
157+ def _get_applicable_regional_routing_contexts (regional_routing_contexts : List [RegionalRoutingContext ],
158+ location_name_by_endpoint : Mapping [str , str ],
159+ fall_back_regional_routing_context : RegionalRoutingContext ,
160+ exclude_location_list : List [str ]) -> List [RegionalRoutingContext ]:
161+ # filter endpoints by excluded locations
162+ applicable_regional_routing_contexts = []
163+ for regional_routing_context in regional_routing_contexts :
164+ if location_name_by_endpoint .get (regional_routing_context .get_primary ()) not in exclude_location_list :
165+ applicable_regional_routing_contexts .append (regional_routing_context )
166+
167+ # if endpoint is empty add fallback endpoint
168+ if not applicable_regional_routing_contexts :
169+ applicable_regional_routing_contexts .append (fall_back_regional_routing_context )
170+
171+ return applicable_regional_routing_contexts
153172
154173class LocationCache (object ): # pylint: disable=too-many-public-methods,too-many-instance-attributes
155174 def current_time_millis (self ):
156175 return int (round (time .time () * 1000 ))
157176
158177 def __init__ (
159178 self ,
160- preferred_locations ,
161179 default_endpoint ,
162- enable_endpoint_discovery ,
163- use_multiple_write_locations ,
180+ connection_policy ,
164181 ):
165- self .preferred_locations = preferred_locations
166182 self .default_regional_routing_context = RegionalRoutingContext (default_endpoint , default_endpoint )
167- self .enable_endpoint_discovery = enable_endpoint_discovery
168- self .use_multiple_write_locations = use_multiple_write_locations
169183 self .enable_multiple_writable_locations = False
170184 self .write_regional_routing_contexts = [self .default_regional_routing_context ]
171185 self .read_regional_routing_contexts = [self .default_regional_routing_context ]
172186 self .location_unavailability_info_by_endpoint = {}
173187 self .last_cache_update_time_stamp = 0
174188 self .account_read_regional_routing_contexts_by_location = {} # pylint: disable=name-too-long
175189 self .account_write_regional_routing_contexts_by_location = {} # pylint: disable=name-too-long
190+ self .account_locations_by_read_regional_routing_context = {} # pylint: disable=name-too-long
191+ self .account_locations_by_write_regional_routing_context = {} # pylint: disable=name-too-long
176192 self .account_write_locations = []
177193 self .account_read_locations = []
194+ self .connection_policy = connection_policy
178195
179196 def get_write_regional_routing_contexts (self ):
180197 return self .write_regional_routing_contexts
@@ -207,6 +224,44 @@ def get_ordered_write_locations(self):
207224 def get_ordered_read_locations (self ):
208225 return self .account_read_locations
209226
227+ def _get_configured_excluded_locations (self , request : RequestObject ) -> List [str ]:
228+ # If excluded locations were configured on request, use request level excluded locations.
229+ excluded_locations = request .excluded_locations
230+ if excluded_locations is None :
231+ # If excluded locations were only configured on client(connection_policy), use client level
232+ excluded_locations = self .connection_policy .ExcludedLocations
233+ return excluded_locations
234+
235+ def _get_applicable_read_regional_routing_contexts (self , request : RequestObject ) -> List [RegionalRoutingContext ]:
236+ # Get configured excluded locations
237+ excluded_locations = self ._get_configured_excluded_locations (request )
238+
239+ # If excluded locations were configured, return filtered regional endpoints by excluded locations.
240+ if excluded_locations :
241+ return _get_applicable_regional_routing_contexts (
242+ self .get_read_regional_routing_contexts (),
243+ self .account_locations_by_read_regional_routing_context ,
244+ self .get_write_regional_routing_contexts ()[0 ],
245+ excluded_locations )
246+
247+ # Else, return all regional endpoints
248+ return self .get_read_regional_routing_contexts ()
249+
250+ def _get_applicable_write_regional_routing_contexts (self , request : RequestObject ) -> List [RegionalRoutingContext ]:
251+ # Get configured excluded locations
252+ excluded_locations = self ._get_configured_excluded_locations (request )
253+
254+ # If excluded locations were configured, return filtered regional endpoints by excluded locations.
255+ if excluded_locations :
256+ return _get_applicable_regional_routing_contexts (
257+ self .get_write_regional_routing_contexts (),
258+ self .account_locations_by_write_regional_routing_context ,
259+ self .default_regional_routing_context ,
260+ excluded_locations )
261+
262+ # Else, return all regional endpoints
263+ return self .get_write_regional_routing_contexts ()
264+
210265 def resolve_service_endpoint (self , request ):
211266 if request .location_endpoint_to_route :
212267 return request .location_endpoint_to_route
@@ -227,7 +282,7 @@ def resolve_service_endpoint(self, request):
227282 # For non-document resource types in case of client can use multiple write locations
228283 # or when client cannot use multiple write locations, flip-flop between the
229284 # first and the second writable region in DatabaseAccount (for manual failover)
230- if self .enable_endpoint_discovery and self .account_write_locations :
285+ if self .connection_policy . EnableEndpointDiscovery and self .account_write_locations :
231286 location_index = min (location_index % 2 , len (self .account_write_locations ) - 1 )
232287 write_location = self .account_write_locations [location_index ]
233288 if (self .account_write_regional_routing_contexts_by_location
@@ -247,9 +302,9 @@ def resolve_service_endpoint(self, request):
247302 return self .default_regional_routing_context .get_primary ()
248303
249304 regional_routing_contexts = (
250- self .get_write_regional_routing_contexts ( )
305+ self ._get_applicable_write_regional_routing_contexts ( request )
251306 if documents ._OperationType .IsWriteOperation (request .operation_type )
252- else self .get_read_regional_routing_contexts ( )
307+ else self ._get_applicable_read_regional_routing_contexts ( request )
253308 )
254309 regional_routing_context = regional_routing_contexts [location_index % len (regional_routing_contexts )]
255310 if (
@@ -263,12 +318,14 @@ def resolve_service_endpoint(self, request):
263318 return regional_routing_context .get_primary ()
264319
265320 def should_refresh_endpoints (self ): # pylint: disable=too-many-return-statements
266- most_preferred_location = self .preferred_locations [0 ] if self .preferred_locations else None
321+ most_preferred_location = self .connection_policy .PreferredLocations [0 ] \
322+ if self .connection_policy .PreferredLocations else None
267323
268324 # we should schedule refresh in background if we are unable to target the user's most preferredLocation.
269- if self .enable_endpoint_discovery :
325+ if self .connection_policy . EnableEndpointDiscovery :
270326
271- should_refresh = self .use_multiple_write_locations and not self .enable_multiple_writable_locations
327+ should_refresh = (self .connection_policy .UseMultipleWriteLocations
328+ and not self .enable_multiple_writable_locations )
272329
273330 if (most_preferred_location and most_preferred_location in
274331 self .account_read_regional_routing_contexts_by_location ):
@@ -358,25 +415,27 @@ def update_location_cache(self, write_locations=None, read_locations=None, enabl
358415 if enable_multiple_writable_locations :
359416 self .enable_multiple_writable_locations = enable_multiple_writable_locations
360417
361- if self .enable_endpoint_discovery :
418+ if self .connection_policy . EnableEndpointDiscovery :
362419 if read_locations :
363420 (self .account_read_regional_routing_contexts_by_location ,
421+ self .account_locations_by_read_regional_routing_context ,
364422 self .account_read_locations ) = get_endpoints_by_location (
365423 read_locations ,
366424 self .account_read_regional_routing_contexts_by_location ,
367425 self .default_regional_routing_context ,
368426 False ,
369- self .use_multiple_write_locations
427+ self .connection_policy . UseMultipleWriteLocations
370428 )
371429
372430 if write_locations :
373431 (self .account_write_regional_routing_contexts_by_location ,
432+ self .account_locations_by_write_regional_routing_context ,
374433 self .account_write_locations ) = get_endpoints_by_location (
375434 write_locations ,
376435 self .account_write_regional_routing_contexts_by_location ,
377436 self .default_regional_routing_context ,
378437 True ,
379- self .use_multiple_write_locations
438+ self .connection_policy . UseMultipleWriteLocations
380439 )
381440
382441 self .write_regional_routing_contexts = self .get_preferred_regional_routing_contexts (
@@ -399,18 +458,18 @@ def get_preferred_regional_routing_contexts(
399458 regional_endpoints = []
400459 # if enableEndpointDiscovery is false, we always use the defaultEndpoint that
401460 # user passed in during documentClient init
402- if self .enable_endpoint_discovery and endpoints_by_location : # pylint: disable=too-many-nested-blocks
461+ if self .connection_policy . EnableEndpointDiscovery and endpoints_by_location : # pylint: disable=too-many-nested-blocks
403462 if (
404463 self .can_use_multiple_write_locations ()
405464 or expected_available_operation == EndpointOperationType .ReadType
406465 ):
407466 unavailable_endpoints = []
408- if self .preferred_locations :
467+ if self .connection_policy . PreferredLocations :
409468 # When client can not use multiple write locations, preferred locations
410469 # list should only be used determining read endpoints order. If client
411470 # can use multiple write locations, preferred locations list should be
412471 # used for determining both read and write endpoints order.
413- for location in self .preferred_locations :
472+ for location in self .connection_policy . PreferredLocations :
414473 regional_endpoint = endpoints_by_location [location ] if location in endpoints_by_location \
415474 else None
416475 if regional_endpoint :
@@ -436,11 +495,12 @@ def get_preferred_regional_routing_contexts(
436495 return regional_endpoints
437496
438497 def can_use_multiple_write_locations (self ):
439- return self .use_multiple_write_locations and self .enable_multiple_writable_locations
498+ return self .connection_policy . UseMultipleWriteLocations and self .enable_multiple_writable_locations
440499
441500 def can_use_multiple_write_locations_for_request (self , request ): # pylint: disable=name-too-long
442501 return self .can_use_multiple_write_locations () and (
443502 request .resource_type == http_constants .ResourceType .Document
503+ or request .resource_type == http_constants .ResourceType .PartitionKey
444504 or (
445505 request .resource_type == http_constants .ResourceType .StoredProcedure
446506 and request .operation_type == documents ._OperationType .ExecuteJavaScript
0 commit comments