Skip to content

Commit 46af4ec

Browse files
committed
Managed Identity for Arc
1 parent 245b5a5 commit 46af4ec

File tree

3 files changed

+74
-4
lines changed

3 files changed

+74
-4
lines changed

msal/imds.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,18 @@ def _obtain_token(http_client, managed_identity, resource):
128128
managed_identity,
129129
resource,
130130
)
131+
if "IDENTITY_ENDPOINT" in os.environ and "IMDS_ENDPOINT" in os.environ:
132+
if ManagedIdentity.is_user_assigned(managed_identity):
133+
raise ValueError( # Note: Azure Identity for Python raised exception too
134+
"Ignoring managed_identity parameter. "
135+
"Azure Arc supports only system-assigned managed identity, "
136+
"See also "
137+
"https://learn.microsoft.com/en-us/azure/service-fabric/configure-existing-cluster-enable-managed-identity-token-service")
138+
return _obtain_token_on_arc(
139+
http_client,
140+
os.environ["IDENTITY_ENDPOINT"],
141+
resource,
142+
)
131143
return _obtain_token_on_azure_vm(http_client, managed_identity, resource)
132144

133145

@@ -248,6 +260,44 @@ def _obtain_token_on_service_fabric(
248260
raise
249261

250262

263+
def _obtain_token_on_arc(http_client, endpoint, resource):
264+
# https://learn.microsoft.com/en-us/azure/azure-arc/servers/managed-identity-authentication
265+
logger.debug("Obtaining token via managed identity on Azure Arc")
266+
resp = http_client.get(
267+
endpoint,
268+
params={"api-version": "2020-06-01", "resource": resource},
269+
headers={"Metadata": "true"},
270+
)
271+
www_auth = "www-authenticate" # Header in lower case
272+
challenge = {
273+
# Normalized to lowercase, because header names are case-insensitive
274+
# https://datatracker.ietf.org/doc/html/rfc7230#section-3.2
275+
k.lower(): v for k, v in resp.headers.items() if k.lower() == www_auth
276+
}.get(www_auth, "").split("=") # Output will be ["Basic realm", "content"]
277+
if not ( # https://datatracker.ietf.org/doc/html/rfc7617#section-2
278+
len(challenge) == 2 and challenge[0].lower() == "basic realm"):
279+
raise ValueError("Irrecognizable WWW-Authenticate header: {}".format(resp.headers))
280+
with open(challenge[1]) as f:
281+
secret = f.read()
282+
response = http_client.get(
283+
endpoint,
284+
params={"api-version": "2020-06-01", "resource": resource},
285+
headers={"Metadata": "true", "Authorization": "Basic {}".format(secret)},
286+
)
287+
payload = json.loads(response.text)
288+
if payload.get("access_token") and payload.get("expires_in"):
289+
# Example: https://learn.microsoft.com/en-us/azure/azure-arc/servers/media/managed-identity-authentication/bash-token-output-example.png
290+
return {
291+
"access_token": payload["access_token"],
292+
"expires_in": int(payload["expires_in"]),
293+
"token_type": payload.get("token_type", "Bearer"),
294+
"resource": payload.get("resource"),
295+
}
296+
return {
297+
"error": "invalid_request",
298+
"error_description": response.text,
299+
}
300+
251301

252302
class ManagedIdentityClient(object):
253303
_instance, _tenant = socket.getfqdn(), "managed_identity" # Placeholders

tests/http_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def close(self): # Not required, but we use it to avoid a warning in unit test
2727
class MinimalResponse(object): # Not for production use
2828
def __init__(self, requests_resp=None, status_code=None, text=None):
2929
self.status_code = status_code or requests_resp.status_code
30-
self.text = text or requests_resp.text
30+
self.text = text if text is not None else requests_resp.text
3131
self._raw_resp = requests_resp
3232

3333
def raise_for_status(self):

tests/test_mi.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
import time
44
import unittest
55
try:
6-
from unittest.mock import patch, ANY
6+
from unittest.mock import patch, ANY, mock_open
77
except:
8-
from mock import patch, ANY
8+
from mock import patch, ANY, mock_open
99
import requests
1010

1111
from tests.http_client import MinimalResponse
@@ -59,7 +59,7 @@ def _test_token_cache(self, app):
5959

6060
def _test_happy_path(self, app, mocked_http):
6161
result = app.acquire_token(resource="R")
62-
mocked_http.assert_called_once()
62+
mocked_http.assert_called()
6363
self.assertEqual({
6464
"access_token": "AT",
6565
"expires_in": 1234,
@@ -158,3 +158,23 @@ def test_app_service_error_should_be_normalized(self):
158158
}, self.app.acquire_token(resource="R"))
159159
self.assertEqual({}, self.app._token_cache._cache)
160160

161+
162+
@patch.dict(os.environ, {
163+
"IDENTITY_ENDPOINT": "http://localhost/token",
164+
"IMDS_ENDPOINT": "http://localhost",
165+
})
166+
class ArcTestCase(ClientTestCase):
167+
168+
@patch("builtins.open", mock_open(read_data="secret"))
169+
def test_happy_path(self):
170+
with patch.object(self.app._http_client, "get", side_effect=[
171+
MinimalResponse(status_code=401, text="", headers={
172+
"WWW-Authenticate": "Basic realm=/tmp/foo",
173+
}),
174+
MinimalResponse(
175+
status_code=200,
176+
text='{"access_token": "AT", "expires_in": "1234", "resource": "R"}',
177+
),
178+
]) as mocked_method:
179+
super(ArcTestCase, self)._test_happy_path(self.app, mocked_method)
180+

0 commit comments

Comments
 (0)