2121import msal .telemetry
2222from .region import _detect_region
2323from .throttled_http_client import ThrottledHttpClient
24+ from .cloudshell import _is_running_in_cloud_shell
2425
2526
2627# The __init__.py will import this. Not the other way around.
27- __version__ = "1.17.0 " # When releasing, also check and bump our dependencies's versions if needed
28+ __version__ = "1.18.0b1 " # When releasing, also check and bump our dependencies's versions if needed
2829
2930logger = logging .getLogger (__name__ )
30-
31+ _AUTHORITY_TYPE_CLOUDSHELL = "CLOUDSHELL"
3132
3233def extract_certs (public_cert_content ):
3334 # Parses raw public certificate file contents and returns a list of strings
@@ -636,6 +637,7 @@ def initiate_auth_code_flow(
636637 domain_hint = None , # type: Optional[str]
637638 claims_challenge = None ,
638639 max_age = None ,
640+ response_mode = None , # type: Optional[str]
639641 ):
640642 """Initiate an auth code flow.
641643
@@ -677,6 +679,20 @@ def initiate_auth_code_flow(
677679
678680 New in version 1.15.
679681
682+ :param str response_mode:
683+ OPTIONAL. Specifies the method with which response parameters should be returned.
684+ The default value is equivalent to ``query``, which is still secure enough in MSAL Python
685+ (because MSAL Python does not transfer tokens via query parameter in the first place).
686+ For even better security, we recommend using the value ``form_post``.
687+ In "form_post" mode, response parameters
688+ will be encoded as HTML form values that are transmitted via the HTTP POST method and
689+ encoded in the body using the application/x-www-form-urlencoded format.
690+ Valid values can be either "form_post" for HTTP POST to callback URI or
691+ "query" (the default) for HTTP GET with parameters encoded in query string.
692+ More information on possible values
693+ `here <https://openid.net/specs/oauth-v2-multiple-response-types-1_0.html#ResponseModes>`
694+ and `here <https://openid.net/specs/oauth-v2-form-post-response-mode-1_0.html#FormPostResponseMode>`
695+
680696 :return:
681697 The auth code flow. It is a dict in this form::
682698
@@ -707,6 +723,7 @@ def initiate_auth_code_flow(
707723 claims = _merge_claims_challenge_and_capabilities (
708724 self ._client_capabilities , claims_challenge ),
709725 max_age = max_age ,
726+ response_mode = response_mode ,
710727 )
711728 flow ["claims_challenge" ] = claims_challenge
712729 return flow
@@ -970,6 +987,10 @@ def get_accounts(self, username=None):
970987 return accounts
971988
972989 def _find_msal_accounts (self , environment ):
990+ interested_authority_types = [
991+ TokenCache .AuthorityType .ADFS , TokenCache .AuthorityType .MSSTS ]
992+ if _is_running_in_cloud_shell ():
993+ interested_authority_types .append (_AUTHORITY_TYPE_CLOUDSHELL )
973994 grouped_accounts = {
974995 a .get ("home_account_id" ): # Grouped by home tenant's id
975996 { # These are minimal amount of non-tenant-specific account info
@@ -985,8 +1006,7 @@ def _find_msal_accounts(self, environment):
9851006 for a in self .token_cache .find (
9861007 TokenCache .CredentialType .ACCOUNT ,
9871008 query = {"environment" : environment })
988- if a ["authority_type" ] in (
989- TokenCache .AuthorityType .ADFS , TokenCache .AuthorityType .MSSTS )
1009+ if a ["authority_type" ] in interested_authority_types
9901010 }
9911011 return list (grouped_accounts .values ())
9921012
@@ -1046,6 +1066,21 @@ def _forget_me(self, home_account):
10461066 TokenCache .CredentialType .ACCOUNT , query = owned_by_home_account ):
10471067 self .token_cache .remove_account (a )
10481068
1069+ def _acquire_token_by_cloud_shell (self , scopes , data = None ):
1070+ from .cloudshell import _obtain_token
1071+ response = _obtain_token (
1072+ self .http_client , scopes , client_id = self .client_id , data = data )
1073+ if "error" not in response :
1074+ self .token_cache .add (dict (
1075+ client_id = self .client_id ,
1076+ scope = response ["scope" ].split () if "scope" in response else scopes ,
1077+ token_endpoint = self .authority .token_endpoint ,
1078+ response = response .copy (),
1079+ data = data or {},
1080+ authority_type = _AUTHORITY_TYPE_CLOUDSHELL ,
1081+ ))
1082+ return response
1083+
10491084 def acquire_token_silent (
10501085 self ,
10511086 scopes , # type: List[str]
@@ -1179,6 +1214,7 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
11791214 authority , # This can be different than self.authority
11801215 force_refresh = False , # type: Optional[boolean]
11811216 claims_challenge = None ,
1217+ correlation_id = None ,
11821218 ** kwargs ):
11831219 access_token_from_cache = None
11841220 if not (force_refresh or claims_challenge ): # Bypass AT when desired or using claims
@@ -1217,9 +1253,13 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
12171253 refresh_reason = msal .telemetry .FORCE_REFRESH # TODO: It could also mean claims_challenge
12181254 assert refresh_reason , "It should have been established at this point"
12191255 try :
1256+ if account and account .get ("authority_type" ) == _AUTHORITY_TYPE_CLOUDSHELL :
1257+ return self ._acquire_token_by_cloud_shell (
1258+ scopes , data = kwargs .get ("data" ))
12201259 result = _clean_up (self ._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family (
12211260 authority , self ._decorate_scope (scopes ), account ,
12221261 refresh_reason = refresh_reason , claims_challenge = claims_challenge ,
1262+ correlation_id = correlation_id ,
12231263 ** kwargs ))
12241264 if (result and "error" not in result ) or (not access_token_from_cache ):
12251265 return result
@@ -1558,6 +1598,9 @@ def acquire_token_interactive(
15581598 - A dict containing an "error" key, when token refresh failed.
15591599 """
15601600 self ._validate_ssh_cert_input_data (kwargs .get ("data" , {}))
1601+ if _is_running_in_cloud_shell () and prompt == "none" :
1602+ return self ._acquire_token_by_cloud_shell (
1603+ scopes , data = kwargs .pop ("data" , {}))
15611604 claims = _merge_claims_challenge_and_capabilities (
15621605 self ._client_capabilities , claims_challenge )
15631606 telemetry_context = self ._build_telemetry_context (
@@ -1659,6 +1702,11 @@ def acquire_token_for_client(self, scopes, claims_challenge=None, **kwargs):
16591702 - an error response would contain "error" and usually "error_description".
16601703 """
16611704 # TBD: force_refresh behavior
1705+ if self .authority .tenant .lower () in ["common" , "organizations" ]:
1706+ warnings .warn (
1707+ "Using /common or /organizations authority "
1708+ "in acquire_token_for_client() is unreliable. "
1709+ "Please use a specific tenant instead." , DeprecationWarning )
16621710 self ._validate_ssh_cert_input_data (kwargs .get ("data" , {}))
16631711 telemetry_context = self ._build_telemetry_context (
16641712 self .ACQUIRE_TOKEN_FOR_CLIENT_ID )
0 commit comments