7
7
import aiohttp
8
8
from azure .search .documents .aio import SearchClient
9
9
from azure .search .documents .indexes .models import SearchIndex
10
+ from jose import jwt
10
11
from msal import ConfidentialClientApplication
11
12
from msal .token_cache import TokenCache
13
+ from tenacity import (
14
+ AsyncRetrying ,
15
+ retry_if_exception_type ,
16
+ stop_after_attempt ,
17
+ wait_random_exponential ,
18
+ )
12
19
13
20
14
21
# AuthError is raised when the authentication token sent by the client UI cannot be parsed or there is an authentication error accessing the graph API
@@ -40,6 +47,15 @@ def __init__(
40
47
self .client_app_id = client_app_id
41
48
self .tenant_id = tenant_id
42
49
self .authority = f"https://login.microsoftonline.com/{ tenant_id } "
50
+ # Depending on if requestedAccessTokenVersion is 1 or 2, the issuer and audience of the token may be different
51
+ # See https://learn.microsoft.com/graph/api/resources/apiapplication
52
+ self .valid_issuers = [
53
+ f"https://sts.windows.net/{ tenant_id } /" ,
54
+ f"https://login.microsoftonline.com/{ tenant_id } /v2.0" ,
55
+ ]
56
+ self .valid_audiences = [f"api://{ server_app_id } " , str (server_app_id )]
57
+ # See https://learn.microsoft.com/entra/identity-platform/access-tokens#validate-the-issuer for more information on token validation
58
+ self .key_url = f"{ self .authority } /discovery/v2.0/keys"
43
59
44
60
if self .use_authentication :
45
61
field_names = [field .name for field in search_index .fields ] if search_index else []
@@ -182,6 +198,11 @@ async def get_auth_claims_if_enabled(self, headers: dict) -> dict[str, Any]:
182
198
# The scope is set to the Microsoft Graph API, which may need to be called for more authorization information
183
199
# https://learn.microsoft.com/en-us/azure/active-directory/develop/v2-oauth2-on-behalf-of-flow
184
200
auth_token = AuthenticationHelper .get_token_auth_header (headers )
201
+ # Validate the token before use
202
+ await self .validate_access_token (auth_token )
203
+
204
+ # Use the on-behalf-of-flow to acquire another token for use with Microsoft Graph
205
+ # See https://learn.microsoft.com/entra/identity-platform/v2-oauth2-on-behalf-of-flow for more information
185
206
graph_resource_access_token = self .confidential_client .acquire_token_on_behalf_of (
186
207
user_assertion = auth_token , scopes = ["https://graph.microsoft.com/.default" ]
187
208
)
@@ -207,7 +228,6 @@ async def get_auth_claims_if_enabled(self, headers: dict) -> dict[str, Any]:
207
228
auth_claims ["groups" ] = await AuthenticationHelper .list_groups (graph_resource_access_token )
208
229
return auth_claims
209
230
except AuthError as e :
210
- print (e .error )
211
231
logging .exception ("Exception getting authorization information - " + json .dumps (e .error ))
212
232
if self .require_access_control :
213
233
raise
@@ -237,3 +257,74 @@ async def check_path_auth(self, path: str, auth_claims: dict[str, Any], search_c
237
257
break
238
258
239
259
return allowed
260
+
261
+ # See https://github.com/Azure-Samples/ms-identity-python-on-behalf-of/blob/939be02b11f1604814532fdacc2c2eccd198b755/FlaskAPI/helpers/authorization.py#L44
262
+ async def validate_access_token (self , token : str ):
263
+ """
264
+ Validate an access token is issued by Entra
265
+ """
266
+ jwks = None
267
+ async for attempt in AsyncRetrying (
268
+ retry = retry_if_exception_type (AuthError ),
269
+ wait = wait_random_exponential (min = 15 , max = 60 ),
270
+ stop = stop_after_attempt (5 ),
271
+ ):
272
+ with attempt :
273
+ async with aiohttp .ClientSession () as session :
274
+ async with session .get (url = self .key_url ) as resp :
275
+ resp_status = resp .status
276
+ if resp_status in [500 , 502 , 503 , 504 ]:
277
+ raise AuthError (
278
+ error = f"Failed to get keys info: { await resp .text ()} " , status_code = resp_status
279
+ )
280
+ jwks = await resp .json ()
281
+
282
+ if not jwks or "keys" not in jwks :
283
+ raise AuthError ({"code" : "invalid_keys" , "description" : "Unable to get keys to validate auth token." }, 401 )
284
+
285
+ rsa_key = None
286
+ issuer = None
287
+ audience = None
288
+ try :
289
+ unverified_header = jwt .get_unverified_header (token )
290
+ unverified_claims = jwt .get_unverified_claims (token )
291
+ issuer = unverified_claims .get ("iss" )
292
+ audience = unverified_claims .get ("aud" )
293
+ for key in jwks ["keys" ]:
294
+ if key ["kid" ] == unverified_header ["kid" ]:
295
+ rsa_key = {"kty" : key ["kty" ], "kid" : key ["kid" ], "use" : key ["use" ], "n" : key ["n" ], "e" : key ["e" ]}
296
+ break
297
+ except Exception as exc :
298
+ raise AuthError (
299
+ {"code" : "invalid_header" , "description" : "Unable to parse authorization token." }, 401
300
+ ) from exc
301
+ if not rsa_key :
302
+ raise AuthError ({"code" : "invalid_header" , "description" : "Unable to find appropriate key" }, 401 )
303
+
304
+ if issuer not in self .valid_issuers :
305
+ raise AuthError (
306
+ {"code" : "invalid_header" , "description" : f"Issuer { issuer } not in { ',' .join (self .valid_issuers )} " }, 401
307
+ )
308
+
309
+ if audience not in self .valid_audiences :
310
+ raise AuthError (
311
+ {
312
+ "code" : "invalid_header" ,
313
+ "description" : f"Audience { audience } not in { ',' .join (self .valid_audiences )} " ,
314
+ },
315
+ 401 ,
316
+ )
317
+
318
+ try :
319
+ jwt .decode (token , rsa_key , algorithms = ["RS256" ], audience = audience , issuer = issuer )
320
+ except jwt .ExpiredSignatureError as jwt_expired_exc :
321
+ raise AuthError ({"code" : "token_expired" , "description" : "token is expired" }, 401 ) from jwt_expired_exc
322
+ except jwt .JWTClaimsError as jwt_claims_exc :
323
+ raise AuthError (
324
+ {"code" : "invalid_claims" , "description" : "incorrect claims," "please check the audience and issuer" },
325
+ 401 ,
326
+ ) from jwt_claims_exc
327
+ except Exception as exc :
328
+ raise AuthError (
329
+ {"code" : "invalid_header" , "description" : "Unable to parse authorization token." }, 401
330
+ ) from exc
0 commit comments