1010import pytest
1111
1212import azure .cosmos .cosmos_client as cosmos_client
13+ from azure .cosmos .container import ContainerProxy
1314import test_config
15+ from _fault_injection_transport import FaultInjectionTransport
16+ from test_fault_injection_transport import TestFaultInjectionTransport
17+ from typing import List , Callable
18+ from azure .core .rest import HttpRequest
1419
1520try :
1621 from unittest .mock import Mock
@@ -30,8 +35,32 @@ def reset(self):
3035 def emit (self , record ):
3136 self .messages .append (record )
3237
33-
34-
38+ CONFIG = test_config .TestConfig
39+ L1 = "Location1"
40+ L2 = "Location2"
41+ L1_URL = test_config .TestConfig .local_host
42+ L2_URL = L1_URL .replace ("localhost" , "127.0.0.1" )
43+ URL_TO_LOCATIONS = {
44+ L1_URL : L1 ,
45+ L2_URL : L2 }
46+
47+
48+ def create_logger (name : str , mock_handler : MockHandler , level : int = logging .INFO ) -> logging .Logger :
49+ logger = logging .getLogger (name )
50+ logger .addHandler (mock_handler )
51+ logger .setLevel (level )
52+
53+ return logger
54+
55+ def get_locations_list (msg : str ) -> List [str ]:
56+ msg = msg .replace (' ' , '' )
57+ msg = msg .replace ('\' ' , '' )
58+ # Find the substring between the first '[' and the last ']'
59+ start = msg .find ('[' ) + 1
60+ end = msg .rfind (']' )
61+ # Extract the substring and convert it to a list using ast.literal_eval
62+ msg = msg [start :end ]
63+ return msg .split (',' )
3564
3665@pytest .mark .cosmosEmulator
3766class TestCosmosHttpLogger (unittest .TestCase ):
@@ -54,12 +83,8 @@ def setUpClass(cls):
5483 "tests." )
5584 cls .mock_handler_default = MockHandler ()
5685 cls .mock_handler_diagnostic = MockHandler ()
57- cls .logger_default = logging .getLogger ("testloggerdefault" )
58- cls .logger_default .addHandler (cls .mock_handler_default )
59- cls .logger_default .setLevel (logging .INFO )
60- cls .logger_diagnostic = logging .getLogger ("testloggerdiagnostic" )
61- cls .logger_diagnostic .addHandler (cls .mock_handler_diagnostic )
62- cls .logger_diagnostic .setLevel (logging .INFO )
86+ cls .logger_default = create_logger ("testloggerdefault" , cls .mock_handler_default )
87+ cls .logger_diagnostic = create_logger ("testloggerdiagnostic" , cls .mock_handler_diagnostic )
6388 cls .client_default = cosmos_client .CosmosClient (cls .host , cls .masterKey ,
6489 consistency_level = "Session" ,
6590 connection_policy = cls .connectionPolicy ,
@@ -136,6 +161,65 @@ def test_cosmos_http_logging_policy(self):
136161
137162 self .mock_handler_diagnostic .reset ()
138163
164+ def test_client_settings (self ):
165+ # Test data
166+ all_locations = [L1 , L2 ]
167+ client_excluded_locations = [L1 ]
168+ multiple_write_locations = True
169+
170+ # Client setup
171+ mock_handler = MockHandler ()
172+ logger = create_logger ("test_logger_client_settings" , mock_handler )
173+
174+ custom_transport = FaultInjectionTransport ()
175+ is_get_account_predicate : Callable [[HttpRequest ], bool ] = lambda \
176+ r : FaultInjectionTransport .predicate_is_database_account_call (r )
177+ emulator_as_multi_write_region_account_transformation = \
178+ lambda r , inner : FaultInjectionTransport .transform_topology_mwr (
179+ first_region_name = L1 ,
180+ second_region_name = L2 ,
181+ inner = inner ,
182+ first_region_url = L1_URL ,
183+ second_region_url = L2_URL ,
184+ )
185+ custom_transport .add_response_transformation (
186+ is_get_account_predicate ,
187+ emulator_as_multi_write_region_account_transformation )
188+
189+ initialized_objects = TestFaultInjectionTransport .setup_method_with_custom_transport (
190+ custom_transport ,
191+ default_endpoint = CONFIG .host ,
192+ key = CONFIG .masterKey ,
193+ database_id = CONFIG .TEST_DATABASE_ID ,
194+ container_id = CONFIG .TEST_SINGLE_PARTITION_CONTAINER_ID ,
195+ preferred_locations = all_locations ,
196+ excluded_locations = client_excluded_locations ,
197+ multiple_write_locations = multiple_write_locations ,
198+ custom_logger = logger
199+ )
200+ mock_handler .reset ()
201+
202+ # create an item
203+ id_value : str = str (uuid .uuid4 ())
204+ document_definition = {'id' : id_value , 'pk' : id_value }
205+ container : ContainerProxy = initialized_objects ["col" ]
206+ container .create_item (body = document_definition )
207+
208+ # Verify endpoint locations
209+ messages_split = mock_handler .messages [1 ].message .split ("\n " )
210+ for message in messages_split :
211+ if "Client Preferred Regions:" in message :
212+ locations = get_locations_list (message )
213+ assert all_locations == locations
214+ elif "Client Excluded Regions:" in message :
215+ locations = get_locations_list (message )
216+ assert client_excluded_locations == locations
217+ elif "Client Account Read Regions:" in message :
218+ locations = get_locations_list (message )
219+ assert all_locations == locations
220+ elif "Client Account Write Regions:" in message :
221+ locations = get_locations_list (message )
222+ assert all_locations == locations
139223
140224if __name__ == "__main__" :
141225 unittest .main ()
0 commit comments