1+ import sys
2+
3+ if sys .version_info [0 ] >= 3 and sys .version_info [1 ] >= 10 :
4+ # Python 3.10 and above
5+ from collections .abc import Iterable
6+ else :
7+ from collections .abc import Iterable
8+
19import copy
210import json
311import os
@@ -153,7 +161,7 @@ def exchange_token(self, uri, code: str) -> dict:
153161 response = self .do_post (uri = uri , body = body , params = None )
154162 resp = response .json ()
155163 jwt_response = self .generate_jwt_response (
156- resp , response .cookies .get (REFRESH_SESSION_COOKIE_NAME )
164+ resp , response .cookies .get (REFRESH_SESSION_COOKIE_NAME ), None
157165 )
158166 return jwt_response
159167
@@ -275,7 +283,7 @@ def exchange_access_key(self, access_key: str) -> dict:
275283 server_response = self .do_post (uri = uri , body = {}, params = None , pswd = access_key )
276284 json = server_response .json ()
277285 return self ._generate_auth_info (
278- response_body = json , refresh_token = None , user_jwt = False
286+ response_body = json , refresh_token = None , user_jwt = False , audience = None
279287 )
280288
281289 @staticmethod
@@ -421,19 +429,25 @@ def adjust_properties(self, jwt_response: dict, user_jwt: bool):
421429 return jwt_response
422430
423431 def _generate_auth_info (
424- self , response_body : dict , refresh_token : str , user_jwt : bool
432+ self ,
433+ response_body : dict ,
434+ refresh_token : str ,
435+ user_jwt : bool ,
436+ audience : str | Iterable [str ] | None = None ,
425437 ) -> dict :
426438 jwt_response = {}
427439 st_jwt = response_body .get ("sessionJwt" , "" )
428440 if st_jwt :
429- jwt_response [SESSION_TOKEN_NAME ] = self ._validate_token (st_jwt )
441+ jwt_response [SESSION_TOKEN_NAME ] = self ._validate_token (st_jwt , audience )
430442 rt_jwt = response_body .get ("refreshJwt" , "" )
431443 if refresh_token :
432444 jwt_response [REFRESH_SESSION_TOKEN_NAME ] = self ._validate_token (
433- refresh_token
445+ refresh_token , audience
434446 )
435447 elif rt_jwt :
436- jwt_response [REFRESH_SESSION_TOKEN_NAME ] = self ._validate_token (rt_jwt )
448+ jwt_response [REFRESH_SESSION_TOKEN_NAME ] = self ._validate_token (
449+ rt_jwt , audience
450+ )
437451
438452 jwt_response = self .adjust_properties (jwt_response , user_jwt )
439453
@@ -447,8 +461,15 @@ def _generate_auth_info(
447461
448462 return jwt_response
449463
450- def generate_jwt_response (self , response_body : dict , refresh_cookie : str ) -> dict :
451- jwt_response = self ._generate_auth_info (response_body , refresh_cookie , True )
464+ def generate_jwt_response (
465+ self ,
466+ response_body : dict ,
467+ refresh_cookie : str ,
468+ audience : str | Iterable [str ] | None = None ,
469+ ) -> dict :
470+ jwt_response = self ._generate_auth_info (
471+ response_body , refresh_cookie , True , audience
472+ )
452473
453474 jwt_response ["user" ] = response_body .get ("user" , {})
454475 jwt_response ["firstSeen" ] = response_body .get ("firstSeen" , True )
@@ -471,7 +492,9 @@ def _get_default_headers(self, pswd: str = None):
471492 return headers
472493
473494 # Validate a token and load the public key if needed
474- def _validate_token (self , token : str ) -> dict :
495+ def _validate_token (
496+ self , token : str , audience : str | Iterable [str ] | None = None
497+ ) -> dict :
475498 if not token :
476499 raise AuthException (
477500 500 ,
@@ -527,6 +550,7 @@ def _validate_token(self, token: str) -> dict:
527550 jwt = token ,
528551 key = copy_key [0 ].key ,
529552 algorithms = [alg_header ],
553+ audience = audience ,
530554 leeway = self .jwt_validation_leeway ,
531555 )
532556 except ImmatureSignatureError :
@@ -539,7 +563,9 @@ def _validate_token(self, token: str) -> dict:
539563 claims ["jwt" ] = token
540564 return claims
541565
542- def validate_session (self , session_token : str ) -> dict :
566+ def validate_session (
567+ self , session_token : str , audience : str | Iterable [str ] | None = None
568+ ) -> dict :
543569 if not session_token :
544570 raise AuthException (
545571 400 ,
@@ -548,7 +574,7 @@ def validate_session(self, session_token: str) -> dict:
548574 )
549575
550576 try :
551- res = self ._validate_token (session_token )
577+ res = self ._validate_token (session_token , audience )
552578 res [SESSION_TOKEN_NAME ] = copy .deepcopy (
553579 res
554580 ) # Duplicate for saving backward compatibility but keep the same structure as the refresh operation response
@@ -560,7 +586,9 @@ def validate_session(self, session_token: str) -> dict:
560586 401 , ERROR_TYPE_INVALID_TOKEN , f"Invalid session token: { e } "
561587 )
562588
563- def refresh_session (self , refresh_token : str ) -> dict :
589+ def refresh_session (
590+ self , refresh_token : str , audience : str | Iterable [str ] | None = None
591+ ) -> dict :
564592 if not refresh_token :
565593 raise AuthException (
566594 400 ,
@@ -569,7 +597,7 @@ def refresh_session(self, refresh_token: str) -> dict:
569597 )
570598
571599 try :
572- self ._validate_token (refresh_token )
600+ self ._validate_token (refresh_token , audience )
573601 except RateLimitException as e :
574602 raise e
575603 except Exception as e :
@@ -582,10 +610,13 @@ def refresh_session(self, refresh_token: str) -> dict:
582610 response = self .do_post (uri = uri , body = {}, params = None , pswd = refresh_token )
583611
584612 resp = response .json ()
585- return self .generate_jwt_response (resp , refresh_token )
613+ return self .generate_jwt_response (resp , refresh_token , audience )
586614
587615 def validate_and_refresh_session (
588- self , session_token : str = None , refresh_token : str = None
616+ self ,
617+ session_token : str = None ,
618+ refresh_token : str = None ,
619+ audience : str | Iterable [str ] | None = None ,
589620 ) -> dict :
590621 if not session_token and not refresh_token :
591622 raise AuthException (
@@ -595,10 +626,10 @@ def validate_and_refresh_session(
595626 )
596627
597628 try :
598- return self .validate_session (session_token )
629+ return self .validate_session (session_token , audience )
599630 except Exception :
600631 # Session is invalid - try to refresh it
601- return self .refresh_session (refresh_token )
632+ return self .refresh_session (refresh_token , audience )
602633
603634 @staticmethod
604635 def extract_masked_address (response : dict , method : DeliveryMethod ) -> str :
0 commit comments