2424"""
2525import collections
2626import logging
27- from typing import Set , Mapping , List
27+ from typing import Set , Mapping , OrderedDict , Dict
28+ from typing import List
2829from urllib .parse import urlparse
2930
3031from . import documents , _base as base
3132from .http_constants import ResourceType
32- from .documents import _OperationType
33+ from .documents import _OperationType , ConnectionPolicy
3334from ._request_object import RequestObject
3435
3536# pylint: disable=protected-access
@@ -43,8 +44,8 @@ class EndpointOperationType(object):
4344
4445class RegionalRoutingContext (object ):
4546 def __init__ (self , primary_endpoint : str , alternate_endpoint : str ):
46- self .primary_endpoint = primary_endpoint
47- self .alternate_endpoint = alternate_endpoint
47+ self .primary_endpoint : str = primary_endpoint
48+ self .alternate_endpoint : str = alternate_endpoint
4849
4950 def set_primary (self , endpoint : str ):
5051 self .primary_endpoint = endpoint
@@ -65,13 +66,13 @@ def __eq__(self, other):
6566 def __str__ (self ):
6667 return "Primary: " + self .primary_endpoint + ", Alternate: " + self .alternate_endpoint
6768
68- def get_endpoints_by_location (new_locations ,
69- old_endpoints_by_location ,
70- default_regional_endpoint ,
71- writes ,
72- use_multiple_write_locations ):
69+ def get_endpoints_by_location (new_locations : List [ Dict [ str , str ]] ,
70+ old_regional_routing_contexts_by_location : Dict [ str , RegionalRoutingContext ] ,
71+ default_regional_endpoint : RegionalRoutingContext ,
72+ writes : bool ,
73+ use_multiple_write_locations : bool ):
7374 # construct from previous object
74- endpoints_by_location = collections .OrderedDict ()
75+ regional_routing_context_by_location : OrderedDict [ str , RegionalRoutingContext ] = collections .OrderedDict ()
7576 parsed_locations = []
7677
7778
@@ -86,8 +87,8 @@ def get_endpoints_by_location(new_locations,
8687 parsed_locations .append (new_location ["name" ])
8788 if not writes or use_multiple_write_locations :
8889 regional_object = RegionalRoutingContext (region_uri , region_uri )
89- elif new_location ["name" ] in old_endpoints_by_location :
90- regional_object = old_endpoints_by_location [new_location ["name" ]]
90+ elif new_location ["name" ] in old_regional_routing_contexts_by_location :
91+ regional_object = old_regional_routing_contexts_by_location [new_location ["name" ]]
9192 current = regional_object .get_primary ()
9293 # swap the previous with current and current with new region_uri received from the gateway
9394 if current != region_uri :
@@ -108,15 +109,14 @@ def get_endpoints_by_location(new_locations,
108109 default_regional_endpoint .get_primary (),
109110 new_location ["name" ])
110111 regional_object .set_alternate (constructed_region_uri )
111- # pass in object with region uri , last known good, curr etc
112- endpoints_by_location .update ({new_location ["name" ]: regional_object })
112+ regional_routing_context_by_location .update ({new_location ["name" ]: regional_object })
113113 except Exception as e :
114114 raise e
115115
116116 # Also store a hash map of endpoints for each location
117- locations_by_endpoints = {value .get_primary (): key for key , value in endpoints_by_location .items ()}
117+ locations_by_endpoints = {value .get_primary (): key for key , value in regional_routing_context_by_location .items ()}
118118
119- return endpoints_by_location , locations_by_endpoints , parsed_locations
119+ return regional_routing_context_by_location , locations_by_endpoints , parsed_locations
120120
121121def _get_health_check_endpoints (regional_routing_contexts ) -> Set [str ]:
122122 # should use the endpoints in the order returned from gateway and only the ones specified in preferred locations
@@ -154,22 +154,24 @@ class LocationCache(object): # pylint: disable=too-many-public-methods,too-many
154154
155155 def __init__ (
156156 self ,
157- default_endpoint ,
158- connection_policy ,
157+ default_endpoint : str ,
158+ connection_policy : ConnectionPolicy ,
159159 ):
160- self .default_regional_routing_context = RegionalRoutingContext (default_endpoint , default_endpoint )
161- self .enable_multiple_writable_locations = False
162- self .write_regional_routing_contexts = [self .default_regional_routing_context ]
163- self .read_regional_routing_contexts = [self .default_regional_routing_context ]
164- self .location_unavailability_info_by_endpoint = {}
165- self .last_cache_update_time_stamp = 0
166- self .account_read_regional_routing_contexts_by_location = {} # pylint: disable=name-too-long
167- self .account_write_regional_routing_contexts_by_location = {} # pylint: disable=name-too-long
168- self .account_locations_by_read_endpoints = {} # pylint: disable=name-too-long
169- self .account_locations_by_write_endpoints = {} # pylint: disable=name-too-long
170- self .account_write_locations = []
171- self .account_read_locations = []
172- self .connection_policy = connection_policy
160+ self .default_regional_routing_context : RegionalRoutingContext = RegionalRoutingContext (default_endpoint ,
161+ default_endpoint )
162+ self .effective_preferred_locations : List [str ] = []
163+ self .enable_multiple_writable_locations : bool = False
164+ self .write_regional_routing_contexts : List [RegionalRoutingContext ] = [self .default_regional_routing_context ]
165+ self .read_regional_routing_contexts : List [RegionalRoutingContext ] = [self .default_regional_routing_context ]
166+ self .location_unavailability_info_by_endpoint : Dict [str , Dict [str , Set [EndpointOperationType ]]] = {}
167+ self .last_cache_update_time_stamp : int = 0
168+ self .account_read_regional_routing_contexts_by_location : Dict [str , RegionalRoutingContext ] = {} # pylint: disable=name-too-long
169+ self .account_write_regional_routing_contexts_by_location : Dict [str , RegionalRoutingContext ] = {} # pylint: disable=name-too-long
170+ self .account_locations_by_read_endpoints : Dict [str , str ] = {} # pylint: disable=name-too-long
171+ self .account_locations_by_write_endpoints : Dict [str , str ] = {} # pylint: disable=name-too-long
172+ self .account_write_locations : List [str ] = []
173+ self .account_read_locations : List [str ] = []
174+ self .connection_policy : ConnectionPolicy = connection_policy
173175
174176 def get_write_regional_routing_contexts (self ):
175177 return self .write_regional_routing_contexts
@@ -310,8 +312,7 @@ def resolve_service_endpoint(self, request):
310312 return regional_routing_context .get_primary ()
311313
312314 def should_refresh_endpoints (self ): # pylint: disable=too-many-return-statements
313- most_preferred_location = self .connection_policy .PreferredLocations [0 ] \
314- if self .connection_policy .PreferredLocations else None
315+ most_preferred_location = self .effective_preferred_locations [0 ] if self .effective_preferred_locations else None
315316
316317 # we should schedule refresh in background if we are unable to target the user's most preferredLocation.
317318 if self .connection_policy .EnableEndpointDiscovery :
@@ -379,7 +380,7 @@ def is_endpoint_unavailable_internal(self, endpoint: str, expected_available_ope
379380 return True
380381
381382 def mark_endpoint_unavailable (
382- self , unavailable_endpoint : str , unavailable_operation_type : str , refresh_cache : bool ):
383+ self , unavailable_endpoint : str , unavailable_operation_type : EndpointOperationType , refresh_cache : bool ):
383384 logger .warning ("Marking %s unavailable for %s " ,
384385 unavailable_endpoint ,
385386 unavailable_operation_type )
@@ -431,6 +432,15 @@ def update_location_cache(self, write_locations=None, read_locations=None, enabl
431432 self .connection_policy .UseMultipleWriteLocations
432433 )
433434
435+ # if preferred locations is empty and the default endpoint is a global endpoint,
436+ # we should use the read locations from gateway as effective preferred locations
437+ if self .connection_policy .PreferredLocations :
438+ self .effective_preferred_locations = self .connection_policy .PreferredLocations
439+ elif self .is_default_endpoint_regional ():
440+ self .effective_preferred_locations = []
441+ elif not self .effective_preferred_locations :
442+ self .effective_preferred_locations = self .account_read_locations
443+
434444 self .write_regional_routing_contexts = self .get_preferred_regional_routing_contexts (
435445 self .account_write_regional_routing_contexts_by_location ,
436446 self .account_write_locations ,
@@ -456,12 +466,12 @@ def get_preferred_regional_routing_contexts(
456466 or expected_available_operation == EndpointOperationType .ReadType
457467 ):
458468 unavailable_endpoints = []
459- if self .connection_policy . PreferredLocations :
469+ if self .effective_preferred_locations :
460470 # When client can not use multiple write locations, preferred locations
461471 # list should only be used determining read endpoints order. If client
462472 # can use multiple write locations, preferred locations list should be
463473 # used for determining both read and write endpoints order.
464- for location in self .connection_policy . PreferredLocations :
474+ for location in self .effective_preferred_locations :
465475 regional_endpoint = endpoints_by_location [location ] if location in endpoints_by_location \
466476 else None
467477 if regional_endpoint :
@@ -486,6 +496,13 @@ def get_preferred_regional_routing_contexts(
486496
487497 return regional_endpoints
488498
499+ # if the endpoint is returned from the gateway in the account topology, it is a regional endpoint
500+ def is_default_endpoint_regional (self ) -> bool :
501+ return any (
502+ context .get_primary () == self .default_regional_routing_context .get_primary ()
503+ for context in self .account_read_regional_routing_contexts_by_location .values ()
504+ )
505+
489506 def can_use_multiple_write_locations (self ):
490507 return self .connection_policy .UseMultipleWriteLocations and self .enable_multiple_writable_locations
491508
0 commit comments