Skip to content

Commit 92ae1e3

Browse files
committed
WIP: ManagedIdentityClient sends claims and token_sha256_to_refresh to SF
TODO: Need to change it to accept client_capabilities and the relay it to SF via xms_cc
1 parent 0922465 commit 92ae1e3

File tree

2 files changed

+87
-13
lines changed

2 files changed

+87
-13
lines changed

msal/managed_identity.py

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# All rights reserved.
33
#
44
# This code is licensed under the MIT License.
5+
import hashlib
56
import json
67
import logging
78
import os
@@ -266,8 +267,7 @@ def acquire_token_for_client(
266267
and then a *claims challenge* will be returned by the target resource,
267268
as a `claims_challenge` directive in the `www-authenticate` header,
268269
even if the app developer did not opt in for the "CP1" client capability.
269-
Upon receiving a `claims_challenge`, MSAL will skip a token cache read,
270-
and will attempt to acquire a new token.
270+
Upon receiving a `claims_challenge`, MSAL will attempt to acquire a new token.
271271
272272
.. note::
273273
@@ -278,11 +278,13 @@ def acquire_token_for_client(
278278
This is a service-side behavior that cannot be changed by this library.
279279
`Azure VM docs <https://learn.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/how-to-use-vm-token#get-a-token-using-http>`_
280280
"""
281+
access_token_to_refresh = None # This could become a public parameter in the future
281282
access_token_from_cache = None
282283
client_id_in_cache = self._managed_identity.get(
283284
ManagedIdentity.ID, "SYSTEM_ASSIGNED_MANAGED_IDENTITY")
284285
now = time.time()
285-
if not claims_challenge: # Then attempt token cache search
286+
if True: # Attempt cache search even if receiving claims_challenge,
287+
# because we want to locate the existing token (if any) and refresh it
286288
matches = self._token_cache.find(
287289
self._token_cache.CredentialType.ACCESS_TOKEN,
288290
target=[resource],
@@ -297,6 +299,11 @@ def acquire_token_for_client(
297299
expires_in = int(entry["expires_on"]) - now
298300
if expires_in < 5*60: # Then consider it expired
299301
continue # Removal is not necessary, it will be overwritten
302+
if claims_challenge and not access_token_to_refresh:
303+
# Since caller did not pinpoint the token causing claims challenge,
304+
# we have to assume it is the first token we found in cache.
305+
access_token_to_refresh = entry["secret"]
306+
break
300307
logger.debug("Cache hit an AT")
301308
access_token_from_cache = { # Mimic a real response
302309
"access_token": entry["secret"],
@@ -310,7 +317,13 @@ def acquire_token_for_client(
310317
break # With a fallback in hand, we break here to go refresh
311318
return access_token_from_cache # It is still good as new
312319
try:
313-
result = _obtain_token(self._http_client, self._managed_identity, resource)
320+
result = _obtain_token(
321+
self._http_client, self._managed_identity, resource,
322+
claims_challenge=claims_challenge,
323+
access_token_sha256_to_refresh=hashlib.sha256(
324+
access_token_to_refresh.encode("utf-8")).hexdigest()
325+
if access_token_to_refresh else None,
326+
)
314327
if "access_token" in result:
315328
expires_in = result.get("expires_in", 3600)
316329
if "refresh_in" not in result and expires_in >= 7200:
@@ -385,8 +398,14 @@ def get_managed_identity_source():
385398
return DEFAULT_TO_VM
386399

387400

388-
def _obtain_token(http_client, managed_identity, resource):
389-
# A unified low-level API that talks to different Managed Identity
401+
def _obtain_token(
402+
http_client, managed_identity, resource,
403+
*,
404+
claims_challenge: Optional[str] = None,
405+
access_token_sha256_to_refresh: Optional[str] = None,
406+
):
407+
if claims_challenge and len(claims_challenge) > 1024:
408+
claims_challenge = None # MSIv1 does not support long url
390409
if ("IDENTITY_ENDPOINT" in os.environ and "IDENTITY_HEADER" in os.environ
391410
and "IDENTITY_SERVER_THUMBPRINT" in os.environ
392411
):
@@ -402,6 +421,8 @@ def _obtain_token(http_client, managed_identity, resource):
402421
os.environ["IDENTITY_HEADER"],
403422
os.environ["IDENTITY_SERVER_THUMBPRINT"],
404423
resource,
424+
claims_challenge=claims_challenge,
425+
access_token_sha256_to_refresh=access_token_sha256_to_refresh,
405426
)
406427
if "IDENTITY_ENDPOINT" in os.environ and "IDENTITY_HEADER" in os.environ:
407428
return _obtain_token_on_app_service(
@@ -553,6 +574,9 @@ def _obtain_token_on_machine_learning(
553574

554575
def _obtain_token_on_service_fabric(
555576
http_client, endpoint, identity_header, server_thumbprint, resource,
577+
*,
578+
claims_challenge: str = None,
579+
access_token_sha256_to_refresh: str = None,
556580
):
557581
"""Obtains token for
558582
`Service Fabric <https://learn.microsoft.com/en-us/azure/service-fabric/>`_
@@ -563,7 +587,12 @@ def _obtain_token_on_service_fabric(
563587
logger.debug("Obtaining token via managed identity on Azure Service Fabric")
564588
resp = http_client.get(
565589
endpoint,
566-
params={"api-version": "2019-07-01-preview", "resource": resource},
590+
params={k: v for k, v in {
591+
"api-version": "2019-07-01-preview",
592+
"resource": resource,
593+
"claims": claims_challenge,
594+
"token_sha256_to_refresh": access_token_sha256_to_refresh,
595+
}.items() if v is not None},
567596
headers={"Secret": identity_header},
568597
)
569598
try:

tests/test_mi.py

Lines changed: 51 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import hashlib
12
import json
23
import os
34
import sys
@@ -79,7 +80,13 @@ def assertCacheStatus(self, app):
7980
"Should have expected client_id")
8081
self.assertEqual("managed_identity", at["realm"], "Should have expected realm")
8182

82-
def _test_happy_path(self, app, mocked_http, expires_in, resource="R"):
83+
def _test_happy_path(
84+
self, app, mocked_http, expires_in, *, resource="R", claims_challenge=None,
85+
):
86+
"""It tests a normal token request that is expected to hit IdP,
87+
a subsequent same token request that is expected to hit cache,
88+
and then a request with claims_challenge that shall hit IdP again.
89+
"""
8390
result = app.acquire_token_for_client(resource=resource)
8491
mocked_http.assert_called()
8592
call_count = mocked_http.call_count
@@ -115,7 +122,8 @@ def _test_happy_path(self, app, mocked_http, expires_in, resource="R"):
115122
expected_refresh_on - 5 < result["refresh_on"] <= expected_refresh_on,
116123
"Should have a refresh_on time around the middle of the token's life")
117124

118-
result = app.acquire_token_for_client(resource=resource, claims_challenge="foo")
125+
result = app.acquire_token_for_client(
126+
resource=resource, claims_challenge=claims_challenge or "placeholder")
119127
self.assertEqual("identity_provider", result["token_source"], "Should miss cache")
120128

121129

@@ -132,6 +140,9 @@ def _test_happy_path(self) -> callable:
132140

133141
def test_happy_path_of_vm(self):
134142
self._test_happy_path().assert_called_with(
143+
# The last call contained claims_challenge
144+
# but since IMDS doesn't support token_sha256_to_refresh,
145+
# the request shall remain the same as before
135146
'http://169.254.169.254/metadata/identity/oauth2/token',
136147
params={'api-version': '2018-02-01', 'resource': 'R'},
137148
headers={'Metadata': 'true'},
@@ -245,18 +256,52 @@ def test_machine_learning_error_should_be_normalized(self):
245256
})
246257
class ServiceFabricTestCase(ClientTestCase):
247258

248-
def _test_happy_path(self, app):
259+
def _test_happy_path(self, app, *, claims_challenge=None) -> callable:
249260
expires_in = 1234
250261
with patch.object(app._http_client, "get", return_value=MinimalResponse(
251262
status_code=200,
252263
text='{"access_token": "AT", "expires_on": %s, "resource": "R", "token_type": "Bearer"}' % (
253264
int(time.time()) + expires_in),
254265
)) as mocked_method:
255266
super(ServiceFabricTestCase, self)._test_happy_path(
256-
app, mocked_method, expires_in)
267+
app, mocked_method, expires_in, claims_challenge=claims_challenge)
268+
return mocked_method
257269

258-
def test_happy_path(self):
259-
self._test_happy_path(self.app)
270+
def test_happy_path_with_small_claim_challenge(self):
271+
claims_challenge='{"access_token": {"nbf": {"essential": true, "value": "1563308371"}}}'
272+
last_call = self._test_happy_path(self.app, claims_challenge=claims_challenge)
273+
last_call.assert_called_with(
274+
# The last call contained claims_challenge
275+
# and the claim_challenge has a size less than 1KB,
276+
# so it should relay both claims and hash to SF
277+
'http://localhost',
278+
params={
279+
'api-version': '2019-07-01-preview',
280+
'resource': 'R',
281+
'claims': claims_challenge,
282+
'token_sha256_to_refresh': hashlib.sha256(b"AT").hexdigest(),
283+
},
284+
headers={'Secret': 'foo'}
285+
)
286+
287+
def test_happy_path_with_large_claim_challenge(self):
288+
claims_challenge=json.dumps({
289+
"filler": "x" * 1024, # 1KB payload
290+
"access_token": {"nbf": {"essential": True, "value": "1563308371"}},
291+
})
292+
last_call = self._test_happy_path(self.app, claims_challenge=claims_challenge)
293+
last_call.assert_called_with(
294+
# The last call contained claims_challenge
295+
# and the claim_challenge has a size more than 1KB,
296+
# so it should relay token_sha256_to_refresh only to SF
297+
'http://localhost',
298+
params={
299+
'api-version': '2019-07-01-preview',
300+
'resource': 'R',
301+
'token_sha256_to_refresh': hashlib.sha256(b"AT").hexdigest(),
302+
},
303+
headers={'Secret': 'foo'}
304+
)
260305

261306
def test_unified_api_service_should_ignore_unnecessary_client_id(self):
262307
self._test_happy_path(ManagedIdentityClient(

0 commit comments

Comments
 (0)