14
14
import azure .cosmos .cosmos_client as cosmos_client
15
15
import test_config
16
16
from azure .cosmos import DatabaseProxy , ContainerProxy , exceptions
17
-
17
+ from azure . core . exceptions import HttpResponseError
18
18
19
19
def _remove_padding (encoded_string ):
20
20
while encoded_string .endswith ("=" ):
@@ -34,7 +34,6 @@ def get_test_item(num):
34
34
35
35
36
36
class CosmosEmulatorCredential (object ):
37
-
38
37
def get_token (self , * scopes , ** kwargs ):
39
38
# type: (*str, **Any) -> AccessToken
40
39
"""Request an access token for the emulator. Based on Azure Core's Access Token Credential.
@@ -118,33 +117,126 @@ def test_aad_credentials(self):
118
117
assert e .status_code == 403
119
118
print ("403 error assertion success" )
120
119
121
- def test_aad_scope_override (self ):
122
- override_scope = "https://my.custom.scope/.default"
123
- os .environ ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE" ] = override_scope
124
120
121
+ def _run_with_scope_capture (self , credential_cls , action , * args , ** kwargs ):
125
122
scopes_captured = []
126
- original_get_token = CosmosEmulatorCredential .get_token
123
+ original_get_token = credential_cls .get_token
127
124
128
125
def capturing_get_token (self , * scopes , ** kwargs ):
129
126
scopes_captured .extend (scopes )
130
127
return original_get_token (self , * scopes , ** kwargs )
131
128
132
- CosmosEmulatorCredential .get_token = capturing_get_token
133
-
129
+ credential_cls .get_token = capturing_get_token
134
130
try :
131
+ result = action (scopes_captured , * args , ** kwargs )
132
+ finally :
133
+ credential_cls .get_token = original_get_token
134
+ return scopes_captured , result
135
+
136
+ def test_override_scope_no_fallback (self ):
137
+ """When override scope is provided, only that scope is used and no fallback occurs."""
138
+ override_scope = "https://my.custom.scope/.default"
139
+ os .environ ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE" ] = override_scope
140
+
141
+ def action (scopes_captured ):
135
142
credential = CosmosEmulatorCredential ()
136
143
client = cosmos_client .CosmosClient (self .host , credential )
137
144
db = client .get_database_client (self .configs .TEST_DATABASE_ID )
138
145
container = db .get_container_client (self .configs .TEST_SINGLE_PARTITION_CONTAINER_ID )
139
- container .create_item (get_test_item (1 ))
140
- assert override_scope in scopes_captured
146
+ container .create_item (get_test_item (10 ))
147
+ return container
148
+
149
+ scopes , container = self ._run_with_scope_capture (CosmosEmulatorCredential , action )
150
+ try :
151
+ assert all (scope == override_scope for scope in scopes ), f"Expected only override scope(s), got: { scopes } "
141
152
finally :
142
- CosmosEmulatorCredential .get_token = original_get_token
143
153
del os .environ ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE" ]
144
154
try :
145
- container .delete_item (item = 'Item_1 ' , partition_key = 'pk' )
155
+ container .delete_item (item = 'Item_10 ' , partition_key = 'pk' )
146
156
except Exception :
147
157
pass
148
158
159
+ def test_override_scope_auth_error_no_fallback (self ):
160
+ """When override scope is provided and auth fails, no fallback to other scopes occurs."""
161
+ override_scope = "https://my.custom.scope/.default"
162
+ os .environ ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE" ] = override_scope
163
+
164
+ class FailingCredential (CosmosEmulatorCredential ):
165
+ def get_token (self , * scopes , ** kwargs ):
166
+ raise Exception ("Simulated auth error for override scope" )
167
+
168
+ def action (scopes_captured ):
169
+ with pytest .raises (Exception ) as excinfo :
170
+ client = cosmos_client .CosmosClient (self .host , FailingCredential ())
171
+ db = client .get_database_client (self .configs .TEST_DATABASE_ID )
172
+ container = db .get_container_client (self .configs .TEST_SINGLE_PARTITION_CONTAINER_ID )
173
+ container .create_item (get_test_item (11 ))
174
+ assert "Simulated auth error" in str (excinfo .value )
175
+ return None
176
+
177
+ scopes , _ = self ._run_with_scope_capture (FailingCredential , action )
178
+ try :
179
+ assert scopes == [override_scope ], f"Expected only override scope, got: { scopes } "
180
+ finally :
181
+ del os .environ ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE" ]
182
+
183
+ def test_account_scope_only (self ):
184
+ """When account scope is provided, only that scope is used."""
185
+ account_scope = "https://localhost/.default"
186
+ os .environ ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE" ] = ""
187
+
188
+ def action (scopes_captured ):
189
+ credential = CosmosEmulatorCredential ()
190
+ client = cosmos_client .CosmosClient (self .host , credential )
191
+ db = client .get_database_client (self .configs .TEST_DATABASE_ID )
192
+ container = db .get_container_client (self .configs .TEST_SINGLE_PARTITION_CONTAINER_ID )
193
+ container .create_item (get_test_item (12 ))
194
+ return container
195
+
196
+ scopes , container = self ._run_with_scope_capture (CosmosEmulatorCredential , action )
197
+ try :
198
+ # Accept multiple calls, but only the account_scope should be used
199
+ assert all (scope == account_scope for scope in scopes ), f"Expected only account scope, got: { scopes } "
200
+ finally :
201
+ try :
202
+ container .delete_item (item = 'Item_12' , partition_key = 'pk' )
203
+ except Exception :
204
+ pass
205
+
206
+ def test_account_scope_fallback_on_error (self ):
207
+ """When account scope is provided and auth fails, fallback to default scope occurs."""
208
+ account_scope = "https://localhost/.default"
209
+ fallback_scope = "https://cosmos.azure.com/.default"
210
+ os .environ ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE" ] = ""
211
+
212
+ class FallbackCredential (CosmosEmulatorCredential ):
213
+ def __init__ (self ):
214
+ self .call_count = 0
215
+
216
+ def get_token (self , * scopes , ** kwargs ):
217
+ self .call_count += 1
218
+ if self .call_count == 1 :
219
+ raise HttpResponseError (message = "AADSTS500011: Simulated error for fallback" )
220
+ return super ().get_token (* scopes , ** kwargs )
221
+
222
+ def action (scopes_captured ):
223
+ credential = FallbackCredential ()
224
+ client = cosmos_client .CosmosClient (self .host , credential )
225
+ db = client .get_database_client (self .configs .TEST_DATABASE_ID )
226
+ container = db .get_container_client (self .configs .TEST_SINGLE_PARTITION_CONTAINER_ID )
227
+ container .create_item (get_test_item (13 ))
228
+ return container
229
+
230
+ scopes , container = self ._run_with_scope_capture (FallbackCredential , action )
231
+ try :
232
+ # Accept multiple calls, but the first should be account_scope, and fallback_scope should appear after error
233
+ assert account_scope in scopes and fallback_scope in scopes , f"Expected fallback to default scope, got: { scopes } "
234
+ finally :
235
+ try :
236
+ container .delete_item (item = 'Item_13' , partition_key = 'pk' )
237
+ except Exception :
238
+ pass
239
+
240
+
149
241
if __name__ == "__main__" :
150
242
unittest .main ()
0 commit comments