Skip to content

Commit 2e6dc04

Browse files
committed
ManagedIdentityClient will send claims and token_sha256_to_refresh to SF
1 parent 50a5034 commit 2e6dc04

File tree

2 files changed

+93
-14
lines changed

2 files changed

+93
-14
lines changed

msal/managed_identity.py

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# All rights reserved.
33
#
44
# This code is licensed under the MIT License.
5+
import hashlib
56
import json
67
import logging
78
import os
@@ -266,8 +267,7 @@ def acquire_token_for_client(
266267
and then a *claims challenge* will be returned by the target resource,
267268
as a `claims_challenge` directive in the `www-authenticate` header,
268269
even if the app developer did not opt in for the "CP1" client capability.
269-
Upon receiving a `claims_challenge`, MSAL will skip a token cache read,
270-
and will attempt to acquire a new token.
270+
Upon receiving a `claims_challenge`, MSAL will attempt to acquire a new token.
271271
272272
.. note::
273273
@@ -278,11 +278,13 @@ def acquire_token_for_client(
278278
This is a service-side behavior that cannot be changed by this library.
279279
`Azure VM docs <https://learn.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/how-to-use-vm-token#get-a-token-using-http>`_
280280
"""
281+
access_token_to_refresh = None # This could become a public parameter in the future
281282
access_token_from_cache = None
282283
client_id_in_cache = self._managed_identity.get(
283284
ManagedIdentity.ID, "SYSTEM_ASSIGNED_MANAGED_IDENTITY")
284285
now = time.time()
285-
if not claims_challenge: # Then attempt token cache search
286+
if True: # Attempt cache search even if receiving claims_challenge,
287+
# because we want to locate the existing token (if any) and refresh it
286288
matches = self._token_cache.find(
287289
self._token_cache.CredentialType.ACCESS_TOKEN,
288290
target=[resource],
@@ -297,6 +299,11 @@ def acquire_token_for_client(
297299
expires_in = int(entry["expires_on"]) - now
298300
if expires_in < 5*60: # Then consider it expired
299301
continue # Removal is not necessary, it will be overwritten
302+
if claims_challenge and not access_token_to_refresh:
303+
# Since caller did not pinpoint the token causing claims challenge,
304+
# we have to assume it is the first token we found in cache.
305+
access_token_to_refresh = entry["secret"]
306+
break
300307
logger.debug("Cache hit an AT")
301308
access_token_from_cache = { # Mimic a real response
302309
"access_token": entry["secret"],
@@ -310,7 +317,13 @@ def acquire_token_for_client(
310317
break # With a fallback in hand, we break here to go refresh
311318
return access_token_from_cache # It is still good as new
312319
try:
313-
result = _obtain_token(self._http_client, self._managed_identity, resource)
320+
result = _obtain_token(
321+
self._http_client, self._managed_identity, resource,
322+
claims_challenge=claims_challenge,
323+
access_token_sha256_to_refresh=hashlib.sha256(
324+
access_token_to_refresh.encode("utf-8")).hexdigest()
325+
if access_token_to_refresh else None,
326+
)
314327
if "access_token" in result:
315328
expires_in = result.get("expires_in", 3600)
316329
if "refresh_in" not in result and expires_in >= 7200:
@@ -385,8 +398,14 @@ def get_managed_identity_source():
385398
return DEFAULT_TO_VM
386399

387400

388-
def _obtain_token(http_client, managed_identity, resource):
389-
# A unified low-level API that talks to different Managed Identity
401+
def _obtain_token(
402+
http_client, managed_identity, resource,
403+
*,
404+
claims_challenge: Optional[str] = None,
405+
access_token_sha256_to_refresh: Optional[str] = None,
406+
):
407+
if claims_challenge and len(claims_challenge) > 1024:
408+
claims_challenge = None # MSIv1 does not support long url
390409
if ("IDENTITY_ENDPOINT" in os.environ and "IDENTITY_HEADER" in os.environ
391410
and "IDENTITY_SERVER_THUMBPRINT" in os.environ
392411
):
@@ -402,6 +421,8 @@ def _obtain_token(http_client, managed_identity, resource):
402421
os.environ["IDENTITY_HEADER"],
403422
os.environ["IDENTITY_SERVER_THUMBPRINT"],
404423
resource,
424+
claims_challenge=claims_challenge,
425+
access_token_sha256_to_refresh=access_token_sha256_to_refresh,
405426
)
406427
if "IDENTITY_ENDPOINT" in os.environ and "IDENTITY_HEADER" in os.environ:
407428
return _obtain_token_on_app_service(
@@ -551,6 +572,9 @@ def _obtain_token_on_machine_learning(
551572

552573
def _obtain_token_on_service_fabric(
553574
http_client, endpoint, identity_header, server_thumbprint, resource,
575+
*,
576+
claims_challenge: str = None,
577+
access_token_sha256_to_refresh: str = None,
554578
):
555579
"""Obtains token for
556580
`Service Fabric <https://learn.microsoft.com/en-us/azure/service-fabric/>`_
@@ -561,7 +585,12 @@ def _obtain_token_on_service_fabric(
561585
logger.debug("Obtaining token via managed identity on Azure Service Fabric")
562586
resp = http_client.get(
563587
endpoint,
564-
params={"api-version": "2019-07-01-preview", "resource": resource},
588+
params={k: v for k, v in {
589+
"api-version": "2019-07-01-preview",
590+
"resource": resource,
591+
"claims": claims_challenge,
592+
"token_sha256_to_refresh": access_token_sha256_to_refresh,
593+
}.items() if v is not None},
565594
headers={"Secret": identity_header},
566595
)
567596
try:

tests/test_mi.py

Lines changed: 57 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import hashlib
12
import json
23
import os
34
import sys
@@ -79,7 +80,13 @@ def assertCacheStatus(self, app):
7980
"Should have expected client_id")
8081
self.assertEqual("managed_identity", at["realm"], "Should have expected realm")
8182

82-
def _test_happy_path(self, app, mocked_http, expires_in, resource="R"):
83+
def _test_happy_path(
84+
self, app, mocked_http, expires_in, *, resource="R", claims_challenge=None,
85+
):
86+
"""It tests a normal token request that is expected to hit IdP,
87+
a subsequent same token request that is expected to hit cache,
88+
and then a request with claims_challenge that shall hit IdP again.
89+
"""
8390
result = app.acquire_token_for_client(resource=resource)
8491
mocked_http.assert_called()
8592
call_count = mocked_http.call_count
@@ -115,7 +122,8 @@ def _test_happy_path(self, app, mocked_http, expires_in, resource="R"):
115122
expected_refresh_on - 5 < result["refresh_on"] <= expected_refresh_on,
116123
"Should have a refresh_on time around the middle of the token's life")
117124

118-
result = app.acquire_token_for_client(resource=resource, claims_challenge="foo")
125+
result = app.acquire_token_for_client(
126+
resource=resource, claims_challenge=claims_challenge or "placeholder")
119127
self.assertEqual("identity_provider", result["token_source"], "Should miss cache")
120128

121129

@@ -128,6 +136,14 @@ def test_happy_path(self):
128136
text='{"access_token": "AT", "expires_in": "%s", "resource": "R"}' % expires_in,
129137
)) as mocked_method:
130138
self._test_happy_path(self.app, mocked_method, expires_in)
139+
mocked_method.assert_called_with(
140+
# The last call contained claims_challenge
141+
# but since IMDS doesn't support token_sha256_to_refresh,
142+
# the request shall remain the same as before
143+
'http://169.254.169.254/metadata/identity/oauth2/token',
144+
params={'api-version': '2018-02-01', 'resource': 'R'},
145+
headers={'Metadata': 'true'},
146+
)
131147

132148
def test_vm_error_should_be_returned_as_is(self):
133149
raw_error = '{"raw": "error format is undefined"}'
@@ -229,18 +245,52 @@ def test_machine_learning_error_should_be_normalized(self):
229245
})
230246
class ServiceFabricTestCase(ClientTestCase):
231247

232-
def _test_happy_path(self, app):
248+
def _test_happy_path(self, app, *, claims_challenge=None) -> callable:
233249
expires_in = 1234
234250
with patch.object(app._http_client, "get", return_value=MinimalResponse(
235251
status_code=200,
236252
text='{"access_token": "AT", "expires_on": %s, "resource": "R", "token_type": "Bearer"}' % (
237253
int(time.time()) + expires_in),
238254
)) as mocked_method:
239255
super(ServiceFabricTestCase, self)._test_happy_path(
240-
app, mocked_method, expires_in)
241-
242-
def test_happy_path(self):
243-
self._test_happy_path(self.app)
256+
app, mocked_method, expires_in, claims_challenge=claims_challenge)
257+
return mocked_method
258+
259+
def test_happy_path_with_small_claim_challenge(self):
260+
claims_challenge='{"access_token": {"nbf": {"essential": true, "value": "1563308371"}}}'
261+
mocked_method = self._test_happy_path(self.app, claims_challenge=claims_challenge)
262+
mocked_method.assert_called_with(
263+
# The last call contained claims_challenge
264+
# and the claim_challenge has a size less than 1KB,
265+
# so it should relay both claims and hash to SF
266+
'http://localhost',
267+
params={
268+
'api-version': '2019-07-01-preview',
269+
'resource': 'R',
270+
'claims': claims_challenge,
271+
'token_sha256_to_refresh': hashlib.sha256(b"AT").hexdigest(),
272+
},
273+
headers={'Secret': 'foo'}
274+
)
275+
276+
def test_happy_path_with_large_claim_challenge(self):
277+
claims_challenge=json.dumps({
278+
"filler": "x" * 1024, # 1KB payload
279+
"access_token": {"nbf": {"essential": True, "value": "1563308371"}},
280+
})
281+
mocked_method = self._test_happy_path(self.app, claims_challenge=claims_challenge)
282+
mocked_method.assert_called_with(
283+
# The last call contained claims_challenge
284+
# and the claim_challenge has a size more than 1KB,
285+
# so it should relay token_sha256_to_refresh only to SF
286+
'http://localhost',
287+
params={
288+
'api-version': '2019-07-01-preview',
289+
'resource': 'R',
290+
'token_sha256_to_refresh': hashlib.sha256(b"AT").hexdigest(),
291+
},
292+
headers={'Secret': 'foo'}
293+
)
244294

245295
def test_unified_api_service_should_ignore_unnecessary_client_id(self):
246296
self._test_happy_path(ManagedIdentityClient(

0 commit comments

Comments
 (0)