Skip to content

Commit b459c0a

Browse files
committed
CDT with bearer app token
1 parent cdbbe51 commit b459c0a

File tree

7 files changed

+238
-10
lines changed

7 files changed

+238
-10
lines changed

msal/application.py

Lines changed: 72 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
from __future__ import annotations
2+
import base64
3+
import datetime
24
import functools
35
import json
46
import time
@@ -166,6 +168,17 @@ def _preferred_browser():
166168
return None
167169

168170

171+
def _build_req_cnf(jwk:dict, remove_padding:bool = False) -> str:
172+
"""req_cnf usually requires base64url encoding.
173+
174+
https://datatracker.ietf.org/doc/html/draft-ietf-oauth-pop-key-distribution-07#section-4.2.1
175+
https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-rdpbcgr/e967ebeb-9e9f-443e-857a-5208802943c2
176+
"""
177+
raw = json.dumps(jwk)
178+
encoded = base64.urlsafe_b64encode(raw.encode('utf-8')).decode('utf-8')
179+
return encoded.rstrip('=') if remove_padding else encoded
180+
181+
169182
class _ClientWithCcsRoutingInfo(Client):
170183

171184
def initiate_auth_code_flow(self, **kwargs):
@@ -232,6 +245,7 @@ class ClientApplication(object):
232245
_TOKEN_SOURCE_IDP = "identity_provider"
233246
_TOKEN_SOURCE_CACHE = "cache"
234247
_TOKEN_SOURCE_BROKER = "broker"
248+
_XMS_DS_NONCE = "xms_ds_nonce"
235249

236250
_enable_broker = False
237251
_AUTH_SCHEME_UNSUPPORTED = (
@@ -241,8 +255,17 @@ class ClientApplication(object):
241255

242256
_TOKEN_CACHE_DATA: dict[str, str] = { # field_in_data: field_in_cache
243257
"key_id": "key_id", # Some token types (SSH-certs, POP) are bound to a key
258+
"req_ds_cnf": "req_ds_cnf", # Used in CDT scenario
244259
}
245260

261+
@functools.lru_cache(maxsize=2)
262+
def __get_rsa_key(self, _bucket): # _bucket is used with lru_cache pattern
263+
from .crypto import _generate_rsa_key
264+
return _generate_rsa_key()
265+
266+
def _get_rsa_key(self, _bucket=None): # Return the same RSA key, cached for a day
267+
return self.__get_rsa_key(_bucket or datetime.date.today())
268+
246269
def __init__(
247270
self, client_id,
248271
client_credential=None, authority=None, validate_authority=True,
@@ -656,7 +679,12 @@ def __init__(
656679

657680
self._decide_broker(allow_broker, enable_pii_log)
658681
self.token_cache = token_cache or TokenCache()
659-
self.token_cache._set(data_to_at=self._TOKEN_CACHE_DATA)
682+
self.token_cache._set(
683+
data_to_at=self._TOKEN_CACHE_DATA,
684+
response_to_at={ # field_in_resp: field_in_cache
685+
"xms_ds_nonce": "xms_ds_nonce",
686+
},
687+
)
660688
self._region_configured = azure_region
661689
self._region_detected = None
662690
self.client, self._regional_client = self._build_client(
@@ -1559,6 +1587,9 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
15591587
"expires_in": int(expires_in), # OAuth2 specs defines it as int
15601588
self._TOKEN_SOURCE: self._TOKEN_SOURCE_CACHE,
15611589
}
1590+
if self._XMS_DS_NONCE in entry: # CDT needs this
1591+
access_token_from_cache[self._XMS_DS_NONCE] = entry[
1592+
self._XMS_DS_NONCE]
15621593
if "refresh_on" in entry:
15631594
access_token_from_cache["refresh_on"] = int(entry["refresh_on"])
15641595
if int(entry["refresh_on"]) < now: # aging
@@ -2347,7 +2378,16 @@ class ConfidentialClientApplication(ClientApplication): # server-side web app
23472378
except that ``allow_broker`` parameter shall remain ``None``.
23482379
"""
23492380

2350-
def acquire_token_for_client(self, scopes, claims_challenge=None, **kwargs):
2381+
def acquire_token_for_client(
2382+
self,
2383+
scopes,
2384+
claims_challenge=None,
2385+
*,
2386+
delegation_constraints: Optional[list] = None,
2387+
delegation_confirmation_key=None, # A Cyprtography's RSAPrivateKey-like object
2388+
# TODO: Support ECC key? https://github.com/pyca/cryptography/issues/4093
2389+
**kwargs
2390+
):
23512391
"""Acquires token for the current confidential client, not for an end user.
23522392
23532393
Since MSAL Python 1.23, it will automatically look for token from cache,
@@ -2370,8 +2410,36 @@ def acquire_token_for_client(self, scopes, claims_challenge=None, **kwargs):
23702410
raise ValueError( # We choose to disallow force_refresh
23712411
"Historically, this method does not support force_refresh behavior. "
23722412
)
2373-
return _clean_up(self._acquire_token_silent_with_error(
2374-
scopes, None, claims_challenge=claims_challenge, **kwargs))
2413+
if delegation_constraints:
2414+
private_key = delegation_confirmation_key or self._get_rsa_key()
2415+
from .crypto import _convert_rsa_keys
2416+
_, jwk = _convert_rsa_keys(private_key)
2417+
result = _clean_up(self._acquire_token_silent_with_error(
2418+
scopes, None, claims_challenge=claims_challenge, data=dict(
2419+
kwargs.pop("data", {}),
2420+
req_ds_cnf=_build_req_cnf(jwk) # It is part of token cache key
2421+
if delegation_constraints else None,
2422+
),
2423+
**kwargs))
2424+
if delegation_constraints and not result.get("error"):
2425+
if not result.get(self._XMS_DS_NONCE): # Available in cached token, too
2426+
raise ValueError(
2427+
"The resource did not opt in to xms_ds_cnf claim. "
2428+
"After its opt-in, call this function again with "
2429+
"a new app object or a new delegation_confirmation_key"
2430+
# in order to invalidate the token in cache
2431+
)
2432+
import jwt # Lazy loading
2433+
cdt_envelope = jwt.encode({
2434+
"constraints": delegation_constraints,
2435+
self._XMS_DS_NONCE: result[self._XMS_DS_NONCE],
2436+
}, private_key, algorithm="PS256")
2437+
result["access_token"] = jwt.encode({
2438+
"t": result["access_token"],
2439+
"c": cdt_envelope,
2440+
}, None, algorithm=None, headers={"typ": "cdt+jwt"})
2441+
del result[self._XMS_DS_NONCE] # Caller shouldn't need to know that
2442+
return result
23752443

23762444
def _acquire_token_for_client(
23772445
self,

msal/crypto.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
from cryptography.hazmat.primitives.asymmetric import rsa
2+
3+
4+
def _urlsafe_b64encode(n:int, bit_size:int) -> str:
5+
from base64 import urlsafe_b64encode
6+
return urlsafe_b64encode(n.to_bytes(
7+
length=int(bit_size/8),
8+
byteorder="big",
9+
)).decode("utf-8").rstrip("=")
10+
11+
12+
def _to_jwk(public_key: rsa.RSAPublicKey) -> dict:
13+
"""Equivalent to:
14+
15+
numbers = public_key.public_numbers()
16+
result = {
17+
"kty": "RSA",
18+
"n": _urlsafe_b64encode(numbers.n, public_key.key_size),
19+
"e": _urlsafe_b64encode(numbers.e, 24),
20+
}
21+
return result
22+
"""
23+
import jwt
24+
return jwt.get_algorithm_by_name( # PyJWT 2.5.0 https://github.com/jpadilla/pyjwt/releases/tag/2.5.0
25+
"RS256"
26+
).to_jwk(
27+
public_key,
28+
as_dict=True, # PyJWT 2.7.0 https://github.com/jpadilla/pyjwt/releases/tag/2.7.0
29+
)
30+
31+
def _convert_rsa_keys(private_key: rsa.RSAPrivateKey):
32+
return "pairs.private_bytes()", _to_jwk(private_key.public_key())
33+
34+
def _generate_rsa_key() -> rsa.RSAPrivateKey:
35+
# https://cryptography.io/en/latest/hazmat/primitives/asymmetric/rsa/#cryptography.hazmat.primitives.asymmetric.rsa.generate_private_key
36+
return rsa.generate_private_key(public_exponent=65537, key_size=2048)
37+

msal/token_cache.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from __future__ import annotations
2+
import hashlib
23
import json
34
import threading
45
import time
@@ -82,6 +83,7 @@ def __init__(self):
8283
realm=None, target=None,
8384
# Note: New field(s) can be added here
8485
#key_id=None,
86+
req_ds_cnf=None,
8587
**ignored_payload_from_a_real_token:
8688
"-".join([ # Note: Could use a hash here to shorten key length
8789
home_account_id or "",
@@ -91,6 +93,13 @@ def __init__(self):
9193
realm or "",
9294
target or "",
9395
#key_id or "", # So ATs of different key_id can coexist
96+
hashlib.sha256(req_ds_cnf.encode()).hexdigest()
97+
# TODO: Could hash the entire key eventually.
98+
# But before that project, we better first
99+
# change the scope to use input scope
100+
# instead of response scope,
101+
# so that a search() can probably have O(1) hit.
102+
if req_ds_cnf else "", # CDT
94103
]).lower(),
95104
self.CredentialType.ID_TOKEN:
96105
lambda home_account_id=None, environment=None, client_id=None,

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ install_requires =
4444

4545
# MSAL does not use jwt.decode(),
4646
# therefore is insusceptible to CVE-2022-29217 so no need to bump to PyJWT 2.4+
47-
PyJWT[crypto]>=1.0.0,<3
47+
PyJWT[crypto]>=2.7.0,<3
4848

4949
# load_key_and_certificates() is available since 2.5
5050
# https://cryptography.io/en/latest/hazmat/primitives/asymmetric/serialization/#cryptography.hazmat.primitives.serialization.pkcs12.load_key_and_certificates

tests/test_application.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
ClientApplication, PublicClientApplication, ConfidentialClientApplication,
1212
_str2bytes, _merge_claims_challenge_and_capabilities,
1313
)
14+
from msal.oauth2cli.oidc import decode_part
1415
from tests import unittest
1516
from tests.test_token_cache import build_id_token, build_response
1617
from tests.http_client import MinimalHttpClient, MinimalResponse
@@ -856,3 +857,85 @@ def test_app_did_not_register_redirect_uri_should_error_out(self):
856857
)
857858
self.assertEqual(result.get("error"), "broker_error")
858859

860+
861+
class CdtTestCase(unittest.TestCase):
862+
863+
def createConstraint(self, typ: str, action: str, targets: list[str]) -> dict:
864+
return {"ver": "1.0", "typ": typ, "a": action, "target": [
865+
{"val": t} for t in targets
866+
]}
867+
868+
def test_constraint_format(self):
869+
self.assertEqual([
870+
self.createConstraint("ns:usr", "create", ["guid1", "guid2"]),
871+
self.createConstraint("ns:app", "update", ["guid3", "guid4"]),
872+
self.createConstraint("ns:subscription", "read", ["guid5", "guid6"]),
873+
], [ # Format defined in https://microsoft-my.sharepoint-df.com/:w:/p/rohitshende/EZgP9niwOvhKn-CUbj1NgG4BTZ6FSD9_16vXvsaXTiUzkg?e=j5DcQu&nav=eyJoIjoiODU5NDAyNjI4In0
874+
{"ver": "1.0", "typ": "ns:usr", "a": "create", "target": [
875+
{"val": "guid1"}, {"val": "guid2"},
876+
],
877+
},
878+
{"ver": "1.0", "typ": "ns:app", "a": "update", "target": [
879+
{"val": "guid3"}, {"val": "guid4"},
880+
],
881+
},
882+
{"ver": "1.0", "typ": "ns:subscription", "a": "read", "target": [
883+
{"val": "guid5"}, {"val": "guid6"},
884+
],
885+
},
886+
], "Constraint format is correct") # MSAL actually accepts arbitrary JSON blob
887+
888+
def assertCdt(self, result: dict, constraints: list[dict]) -> None:
889+
self.assertIsNotNone(
890+
result.get("access_token"), "Encountered {}: {}".format(
891+
result.get("error"), result.get("error_description")))
892+
_expectancy = "The return value should look like a Bearer response"
893+
self.assertEqual(result["token_type"], "Bearer", _expectancy)
894+
self.assertNotIn("xms_ds_nonce", result, _expectancy)
895+
headers = json.loads(decode_part(result["access_token"].split(".")[0]))
896+
self.assertEqual(headers.get("typ"), "cdt+jwt", "typ should be cdt+jwt")
897+
payload = json.loads(decode_part(result["access_token"].split(".")[1]))
898+
self.assertIsNotNone(payload.get("t") and payload.get("c"))
899+
cdt_envelope = json.loads(decode_part(payload["c"].split(".")[1]))
900+
self.assertIn("xms_ds_nonce", cdt_envelope)
901+
self.assertEqual(cdt_envelope["constraints"], constraints)
902+
903+
def assertAppObtainsCdt(self, client_app, scopes) -> None:
904+
constraints1 = [self.createConstraint("ns:usr", "create", ["guid1"])]
905+
result = client_app.acquire_token_for_client(
906+
scopes, delegation_constraints=constraints1,
907+
)
908+
self.assertCdt(result, constraints1)
909+
910+
constraints2 = [self.createConstraint("ns:app", "update", ["guid2"])]
911+
result = client_app.acquire_token_for_client(
912+
scopes, delegation_constraints=constraints2,
913+
)
914+
self.assertEqual(result["token_source"], "cache", "App token Should hit cache")
915+
self.assertCdt(result, constraints2)
916+
917+
result = client_app.acquire_token_for_client(
918+
scopes, delegation_constraints=constraints2,
919+
delegation_confirmation_key=client_app._get_rsa_key("new"),
920+
)
921+
self.assertEqual(
922+
result["token_source"], "identity_provider",
923+
"Different key should result in a new app token")
924+
self.assertCdt(result, constraints2)
925+
926+
@patch("msal.authority.tenant_discovery", new=Mock(return_value={
927+
"authorization_endpoint": "https://contoso.com/placeholder",
928+
"token_endpoint": "https://contoso.com/placeholder",
929+
}))
930+
def test_acquire_token_for_client_should_return_a_cdt(self):
931+
app = msal.ConfidentialClientApplication("id", client_credential="secret")
932+
with patch.object(app.http_client, "post", return_value=MinimalResponse(
933+
status_code=200, text=json.dumps({
934+
"token_type": "Bearer",
935+
"access_token": "app token",
936+
"expires_in": 3600,
937+
"xms_ds_nonce": "nonce",
938+
}))) as mocked_post:
939+
self.assertAppObtainsCdt(app, ["scope1", "scope2"])
940+
self.assertEqual(mocked_post.call_count, 2)
941+

tests/test_crypto.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from unittest import TestCase
2+
3+
from msal.crypto import _generate_rsa_key, _convert_rsa_keys
4+
5+
6+
class CryptoTestCase(TestCase):
7+
def test_key_generation(self):
8+
key = _generate_rsa_key()
9+
_, jwk = _convert_rsa_keys(key)
10+
self.assertEqual(jwk.get("kty"), "RSA")
11+
self.assertIsNotNone(jwk.get("n") and jwk.get("e"))
12+

tests/test_e2e.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,10 @@
2727

2828
import msal
2929
from tests.http_client import MinimalHttpClient, MinimalResponse
30+
from tests.test_application import CdtTestCase
3031
from msal.oauth2cli import AuthCodeReceiver
3132
from msal.oauth2cli.oidc import decode_part
33+
from msal.application import _build_req_cnf
3234

3335
try:
3436
import pymsalruntime
@@ -533,7 +535,7 @@ def tearDownClass(cls):
533535
cls.session.close()
534536

535537
@classmethod
536-
def get_lab_app_object(cls, client_id=None, **query): # https://msidlab.com/swagger/index.html
538+
def get_lab_app_object(cls, client_id=None, **query) -> dict: # https://msidlab.com/swagger/index.html
537539
url = "https://msidlab.com/api/app/{}".format(client_id or "")
538540
resp = cls.session.get(url, params=query)
539541
result = resp.json()[0]
@@ -791,12 +793,12 @@ def test_user_account(self):
791793
self._test_user_account()
792794

793795

794-
def _data_for_pop(key):
795-
raw_req_cnf = json.dumps({"kid": key, "xms_ksl": "sw"})
796+
def _data_for_pop(key_id):
796797
return { # Sampled from Azure CLI's plugin connectedk8s
797798
'token_type': 'pop',
798-
'key_id': key,
799-
"req_cnf": base64.urlsafe_b64encode(raw_req_cnf.encode('utf-8')).decode('utf-8').rstrip('='),
799+
'key_id': key_id,
800+
"req_cnf": _build_req_cnf(
801+
{"kid": key_id, "xms_ksl": "sw"}, remove_padding=True),
800802
# Note: Sending raw_req_cnf without base64 encoding would result in an http 500 error
801803
} # See also https://github.com/Azure/azure-cli-extensions/blob/main/src/connectedk8s/azext_connectedk8s/_clientproxyutils.py#L86-L92
802804

@@ -817,6 +819,23 @@ def test_user_account(self):
817819
self._test_user_account()
818820

819821

822+
class CdtTestCase(LabBasedTestCase, CdtTestCase):
823+
def test_acquire_token_for_client_should_return_a_cdt(self):
824+
resource = self.get_lab_app_object( # This resource has opted in to CDT
825+
publicClient="no", signinAudience="AzureAdMyOrg")
826+
client_app = msal.ConfidentialClientApplication(
827+
# Any CCA can use a CDT, as long as the resource opted in for a CDT
828+
# Here we use the OBO app which is in same tenant as the resource.
829+
os.getenv("LAB_OBO_CONFIDENTIAL_CLIENT_ID"),
830+
client_credential=os.getenv("LAB_OBO_CLIENT_SECRET"),
831+
authority="{}{}.onmicrosoft.com".format(
832+
resource["authority"],
833+
resource["labName"].lower().rstrip(".com"),
834+
),
835+
)
836+
self.assertAppObtainsCdt(client_app, [f"{resource['appId']}/.default"])
837+
838+
820839
class WorldWideTestCase(LabBasedTestCase):
821840

822841
def test_aad_managed_user(self): # Pure cloud

0 commit comments

Comments
 (0)