1414import azure .cosmos .cosmos_client as cosmos_client
1515import test_config
1616from azure .cosmos import DatabaseProxy , ContainerProxy , exceptions
17-
17+ from azure . core . exceptions import HttpResponseError
1818
1919def _remove_padding (encoded_string ):
2020 while encoded_string .endswith ("=" ):
@@ -34,7 +34,6 @@ def get_test_item(num):
3434
3535
3636class CosmosEmulatorCredential (object ):
37-
3837 def get_token (self , * scopes , ** kwargs ):
3938 # type: (*str, **Any) -> AccessToken
4039 """Request an access token for the emulator. Based on Azure Core's Access Token Credential.
@@ -118,33 +117,126 @@ def test_aad_credentials(self):
118117 assert e .status_code == 403
119118 print ("403 error assertion success" )
120119
121- def test_aad_scope_override (self ):
122- override_scope = "https://my.custom.scope/.default"
123- os .environ ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE" ] = override_scope
124120
121+ def _run_with_scope_capture (self , credential_cls , action , * args , ** kwargs ):
125122 scopes_captured = []
126- original_get_token = CosmosEmulatorCredential .get_token
123+ original_get_token = credential_cls .get_token
127124
128125 def capturing_get_token (self , * scopes , ** kwargs ):
129126 scopes_captured .extend (scopes )
130127 return original_get_token (self , * scopes , ** kwargs )
131128
132- CosmosEmulatorCredential .get_token = capturing_get_token
133-
129+ credential_cls .get_token = capturing_get_token
134130 try :
131+ result = action (scopes_captured , * args , ** kwargs )
132+ finally :
133+ credential_cls .get_token = original_get_token
134+ return scopes_captured , result
135+
136+ def test_override_scope_no_fallback (self ):
137+ """When override scope is provided, only that scope is used and no fallback occurs."""
138+ override_scope = "https://my.custom.scope/.default"
139+ os .environ ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE" ] = override_scope
140+
141+ def action (scopes_captured ):
135142 credential = CosmosEmulatorCredential ()
136143 client = cosmos_client .CosmosClient (self .host , credential )
137144 db = client .get_database_client (self .configs .TEST_DATABASE_ID )
138145 container = db .get_container_client (self .configs .TEST_SINGLE_PARTITION_CONTAINER_ID )
139- container .create_item (get_test_item (1 ))
140- assert override_scope in scopes_captured
146+ container .create_item (get_test_item (10 ))
147+ return container
148+
149+ scopes , container = self ._run_with_scope_capture (CosmosEmulatorCredential , action )
150+ try :
151+ assert all (scope == override_scope for scope in scopes ), f"Expected only override scope(s), got: { scopes } "
141152 finally :
142- CosmosEmulatorCredential .get_token = original_get_token
143153 del os .environ ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE" ]
144154 try :
145- container .delete_item (item = 'Item_1 ' , partition_key = 'pk' )
155+ container .delete_item (item = 'Item_10 ' , partition_key = 'pk' )
146156 except Exception :
147157 pass
148158
159+ def test_override_scope_auth_error_no_fallback (self ):
160+ """When override scope is provided and auth fails, no fallback to other scopes occurs."""
161+ override_scope = "https://my.custom.scope/.default"
162+ os .environ ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE" ] = override_scope
163+
164+ class FailingCredential (CosmosEmulatorCredential ):
165+ def get_token (self , * scopes , ** kwargs ):
166+ raise Exception ("Simulated auth error for override scope" )
167+
168+ def action (scopes_captured ):
169+ with pytest .raises (Exception ) as excinfo :
170+ client = cosmos_client .CosmosClient (self .host , FailingCredential ())
171+ db = client .get_database_client (self .configs .TEST_DATABASE_ID )
172+ container = db .get_container_client (self .configs .TEST_SINGLE_PARTITION_CONTAINER_ID )
173+ container .create_item (get_test_item (11 ))
174+ assert "Simulated auth error" in str (excinfo .value )
175+ return None
176+
177+ scopes , _ = self ._run_with_scope_capture (FailingCredential , action )
178+ try :
179+ assert scopes == [override_scope ], f"Expected only override scope, got: { scopes } "
180+ finally :
181+ del os .environ ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE" ]
182+
183+ def test_account_scope_only (self ):
184+ """When account scope is provided, only that scope is used."""
185+ account_scope = "https://localhost/.default"
186+ os .environ ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE" ] = ""
187+
188+ def action (scopes_captured ):
189+ credential = CosmosEmulatorCredential ()
190+ client = cosmos_client .CosmosClient (self .host , credential )
191+ db = client .get_database_client (self .configs .TEST_DATABASE_ID )
192+ container = db .get_container_client (self .configs .TEST_SINGLE_PARTITION_CONTAINER_ID )
193+ container .create_item (get_test_item (12 ))
194+ return container
195+
196+ scopes , container = self ._run_with_scope_capture (CosmosEmulatorCredential , action )
197+ try :
198+ # Accept multiple calls, but only the account_scope should be used
199+ assert all (scope == account_scope for scope in scopes ), f"Expected only account scope, got: { scopes } "
200+ finally :
201+ try :
202+ container .delete_item (item = 'Item_12' , partition_key = 'pk' )
203+ except Exception :
204+ pass
205+
206+ def test_account_scope_fallback_on_error (self ):
207+ """When account scope is provided and auth fails, fallback to default scope occurs."""
208+ account_scope = "https://localhost/.default"
209+ fallback_scope = "https://cosmos.azure.com/.default"
210+ os .environ ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE" ] = ""
211+
212+ class FallbackCredential (CosmosEmulatorCredential ):
213+ def __init__ (self ):
214+ self .call_count = 0
215+
216+ def get_token (self , * scopes , ** kwargs ):
217+ self .call_count += 1
218+ if self .call_count == 1 :
219+ raise HttpResponseError (message = "AADSTS500011: Simulated error for fallback" )
220+ return super ().get_token (* scopes , ** kwargs )
221+
222+ def action (scopes_captured ):
223+ credential = FallbackCredential ()
224+ client = cosmos_client .CosmosClient (self .host , credential )
225+ db = client .get_database_client (self .configs .TEST_DATABASE_ID )
226+ container = db .get_container_client (self .configs .TEST_SINGLE_PARTITION_CONTAINER_ID )
227+ container .create_item (get_test_item (13 ))
228+ return container
229+
230+ scopes , container = self ._run_with_scope_capture (FallbackCredential , action )
231+ try :
232+ # Accept multiple calls, but the first should be account_scope, and fallback_scope should appear after error
233+ assert account_scope in scopes and fallback_scope in scopes , f"Expected fallback to default scope, got: { scopes } "
234+ finally :
235+ try :
236+ container .delete_item (item = 'Item_13' , partition_key = 'pk' )
237+ except Exception :
238+ pass
239+
240+
149241if __name__ == "__main__" :
150242 unittest .main ()
0 commit comments