Skip to content

Commit dd75a28

Browse files
committed
Dedicate ManagedIdentity API
1 parent 0c57056 commit dd75a28

File tree

4 files changed

+118
-43
lines changed

4 files changed

+118
-43
lines changed

msal/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,4 +33,5 @@
3333
)
3434
from .oauth2cli.oidc import Prompt
3535
from .token_cache import TokenCache, SerializableTokenCache
36+
from .imds import ManagedIdentity
3637

msal/application.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2001,21 +2001,6 @@ def acquire_token_for_client(self, scopes, claims_challenge=None, **kwargs):
20012001
- an error response would contain "error" and usually "error_description".
20022002
"""
20032003
# TBD: force_refresh behavior
2004-
if self.client_credential is None:
2005-
from .imds import _scope_to_resource, _obtain_token
2006-
response = _obtain_token(
2007-
self.http_client,
2008-
" ".join(map(_scope_to_resource, scopes)),
2009-
client_id=self.client_id, # None for system-assigned, GUID for user-assigned
2010-
)
2011-
if "error" not in response:
2012-
self.token_cache.add(dict(
2013-
client_id=self.client_id,
2014-
scope=response["scope"].split() if "scope" in response else scopes,
2015-
token_endpoint=self.authority.token_endpoint,
2016-
response=response.copy(),
2017-
))
2018-
return response
20192004
if self.authority.tenant.lower() in ["common", "organizations"]:
20202005
warnings.warn(
20212006
"Using /common or /organizations authority "

msal/imds.py

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import json
66
import logging
77
import os
8+
import socket
89
import time
910
try: # Python 2
1011
from urlparse import urlparse
@@ -57,6 +58,9 @@ def _obtain_token_on_azure_vm(http_client, resource, client_id=None):
5758
raise
5859

5960
def _obtain_token_on_app_service(http_client, endpoint, identity_header, resource, client_id=None):
61+
"""Obtains token for
62+
`App Service <https://learn.microsoft.com/en-us/azure/app-service/overview-managed-identity?tabs=portal%2Chttp#rest-endpoint-reference>`_
63+
"""
6064
# Prerequisite: Create your app service https://docs.microsoft.com/en-us/azure/app-service/quickstart-python
6165
# Assign it a managed identity https://docs.microsoft.com/en-us/azure/app-service/overview-managed-identity?tabs=portal%2Chttp
6266
# SSH into your container for testing https://docs.microsoft.com/en-us/azure/app-service/configure-linux-open-ssh-session
@@ -73,7 +77,7 @@ def _obtain_token_on_app_service(http_client, endpoint, identity_header, resourc
7377
headers={
7478
"X-IDENTITY-HEADER": identity_header,
7579
"Metadata": "true", # Unnecessary yet harmless for App Service,
76-
# It will be needed by Azure Automation
80+
# It will be needed by Azure Automation
7781
# https://docs.microsoft.com/en-us/azure/automation/enable-managed-identity-for-automation#get-access-token-for-system-assigned-managed-identity-using-http-get
7882
},
7983
)
@@ -95,3 +99,66 @@ def _obtain_token_on_app_service(http_client, endpoint, identity_header, resourc
9599
logger.debug("IMDS emits unexpected payload: %s", resp.text)
96100
raise
97101

102+
103+
class ManagedIdentity(object):
104+
_instance, _tenant = socket.getfqdn(), "managed_identity" # Placeholders
105+
106+
def __init__(self, http_client, client_id=None, token_cache=None):
107+
"""Create a managed identity object.
108+
109+
:param http_client:
110+
An http client object. For example, you can use `requests.Session()`.
111+
112+
:param str client_id:
113+
Optional.
114+
It accepts the Client ID (NOT the Object ID) of your user-assigned managed identity.
115+
If it is None, it means to use a system-assigned managed identity.
116+
117+
:param token_cache:
118+
Optional. It accepts a :class:`msal.TokenCache` instance to store tokens.
119+
"""
120+
self._http_client = http_client
121+
self._client_id = client_id
122+
self._token_cache = token_cache
123+
124+
def acquire_token(self, resource):
125+
access_token_from_cache = None
126+
if self._token_cache:
127+
matches = self._token_cache.find(
128+
self._token_cache.CredentialType.ACCESS_TOKEN,
129+
target=[resource],
130+
query=dict(
131+
client_id=self._client_id,
132+
environment=self._instance,
133+
realm=self._tenant,
134+
home_account_id=None,
135+
),
136+
)
137+
now = time.time()
138+
for entry in matches:
139+
expires_in = int(entry["expires_on"]) - now
140+
if expires_in < 5*60: # Then consider it expired
141+
continue # Removal is not necessary, it will be overwritten
142+
logger.debug("Cache hit an AT")
143+
access_token_from_cache = { # Mimic a real response
144+
"access_token": entry["secret"],
145+
"token_type": entry.get("token_type", "Bearer"),
146+
"expires_in": int(expires_in), # OAuth2 specs defines it as int
147+
}
148+
if "refresh_on" in entry and int(entry["refresh_on"]) < now: # aging
149+
break # With a fallback in hand, we break here to go refresh
150+
return access_token_from_cache # It is still good as new
151+
result = _obtain_token(self._http_client, resource, client_id=self._client_id)
152+
if self._token_cache and "access_token" in result:
153+
self._token_cache.add(dict(
154+
client_id=self._client_id,
155+
scope=[resource],
156+
token_endpoint="https://{}/{}".format(self._instance, self._tenant),
157+
response=result,
158+
params={},
159+
data={},
160+
#grant_type="placeholder",
161+
))
162+
return result
163+
return access_token_from_cache or result
164+

tests/msaltest.py

Lines changed: 49 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import getpass, logging, pprint, sys, msal
1+
import functools, getpass, logging, pprint, sys, requests, msal
22

33

44
AZURE_CLI = "04b07795-8ddb-461a-bbee-02f9e1bf7b46"
@@ -141,32 +141,55 @@ def remove_account(app):
141141
app.remove_account(account)
142142
print('Account "{}" and/or its token(s) are signed out from MSAL Python'.format(account["username"]))
143143

144-
def acquire_token_for_client(app):
145-
"""acquire_token_for_client() - Only for confidential client"""
146-
pprint.pprint(app.acquire_token_for_client(_input_scopes()))
144+
def acquire_token_for_managed_identity(app):
145+
"""acquire_token() - Only for managed identity"""
146+
pprint.pprint(app.acquire_token(_select_options([
147+
"https://management.azure.com",
148+
"https://graph.microsoft.com",
149+
],
150+
header="Acquire token for this resource",
151+
accept_nonempty_string=True)))
147152

148153
def exit(app):
149154
"""Exit"""
150155
bug_link = (
151156
"https://identitydivision.visualstudio.com/Engineering/_queries/query/79b3a352-a775-406f-87cd-a487c382a8ed/"
152-
if app._enable_broker else
157+
if getattr(app, "_enable_broker", None) else
153158
"https://github.com/AzureAD/microsoft-authentication-library-for-python/issues/new/choose"
154159
)
155160
print("Bye. If you found a bug, please report it here: {}".format(bug_link))
156161
sys.exit()
157162

163+
def _managed_identity():
164+
client_id = _select_options([
165+
{"client_id": None, "name": "System-assigned managed identity"},
166+
],
167+
option_renderer=lambda a: a["name"],
168+
header="Choose the system-assigned managed identity "
169+
"(or type in your user-assigned managed identity)",
170+
accept_nonempty_string=True)
171+
return msal.ManagedIdentity(
172+
requests.Session(),
173+
client_id=client_id["client_id"]
174+
if isinstance(client_id, dict) else client_id,
175+
token_cache=msal.TokenCache(),
176+
)
177+
158178
def main():
159-
print("Welcome to the Msal Python Console Test App, committed at 2022-5-2\n")
179+
print("Welcome to the Console Test App for MSAL Python {}\n".format(msal.__version__))
160180
chosen_app = _select_options([
161181
{"client_id": AZURE_CLI, "name": "Azure CLI (Correctly configured for MSA-PT)"},
162182
{"client_id": VISUAL_STUDIO, "name": "Visual Studio (Correctly configured for MSA-PT)"},
163183
{"client_id": "95de633a-083e-42f5-b444-a4295d8e9314", "name": "Whiteboard Services (Non MSA-PT app. Accepts AAD & MSA accounts.)"},
164-
{"client_id": None, "client_secret": None, "name": "System-assigned Managed Identity (Only works when running inside a supported environment, such as Azure VM, Azure App Service, Azure Automation)"},
184+
{"test_managed_identity": None, "name": "Managed Identity (Only works when running inside a supported environment, such as Azure VM, Azure App Service, Azure Automation)"},
165185
],
166186
option_renderer=lambda a: a["name"],
167187
header="Impersonate this app (or you can type in the client_id of your own app)",
168188
accept_nonempty_string=True)
169-
authority = _select_options([
189+
if isinstance(chosen_app, dict) and "test_managed_identity" in chosen_app:
190+
app = _managed_identity()
191+
else:
192+
authority = _select_options([
170193
"https://login.microsoftonline.com/common",
171194
"https://login.microsoftonline.com/organizations",
172195
"https://login.microsoftonline.com/microsoft.onmicrosoft.com",
@@ -175,33 +198,32 @@ def main():
175198
],
176199
header="Input authority (Note that MSA-PT apps would NOT use the /common authority)",
177200
accept_nonempty_string=True,
178-
)
179-
if isinstance(chosen_app, dict) and "client_secret" in chosen_app:
180-
app = msal.ConfidentialClientApplication(
181-
chosen_app["client_id"],
182-
client_credential=chosen_app["client_secret"],
183-
authority=authority,
184-
)
185-
else:
201+
)
186202
app = msal.PublicClientApplication(
187203
chosen_app["client_id"] if isinstance(chosen_app, dict) else chosen_app,
188204
authority=authority,
189205
allow_broker=_input_boolean("Allow broker? (Azure CLI currently only supports @microsoft.com accounts when enabling broker)"),
190206
)
191207
if _input_boolean("Enable MSAL Python's DEBUG log?"):
192208
logging.basicConfig(level=logging.DEBUG)
209+
methods_to_be_tested = functools.reduce(lambda x, y: x + y, [
210+
methods for app_type, methods in {
211+
msal.PublicClientApplication: [
212+
acquire_token_interactive,
213+
acquire_ssh_cert_silently,
214+
acquire_ssh_cert_interactive,
215+
],
216+
msal.ClientApplication: [
217+
acquire_token_silent,
218+
acquire_token_by_username_password,
219+
remove_account,
220+
],
221+
msal.ManagedIdentity: [acquire_token_for_managed_identity],
222+
}.items() if isinstance(app, app_type)])
193223
while True:
194-
func = _select_options(list(filter(None, [
195-
acquire_token_silent,
196-
acquire_token_interactive,
197-
acquire_token_by_username_password,
198-
acquire_ssh_cert_silently,
199-
acquire_ssh_cert_interactive,
200-
remove_account,
201-
acquire_token_for_client if isinstance(
202-
app, msal.ConfidentialClientApplication) else None,
203-
exit,
204-
])), option_renderer=lambda f: f.__doc__, header="MSAL Python APIs:")
224+
func = _select_options(
225+
methods_to_be_tested + [exit],
226+
option_renderer=lambda f: f.__doc__, header="MSAL Python APIs:")
205227
try:
206228
func(app)
207229
except ValueError as e:

0 commit comments

Comments
 (0)