Skip to content

Commit 94fe489

Browse files
committed
PoC: req_ds_cnf
Refactor req_ds_cnf to delegation_scope_key Move logic to allow token cache to work wip
1 parent 6d80cc5 commit 94fe489

File tree

6 files changed

+220
-9
lines changed

6 files changed

+220
-9
lines changed

msal/application.py

Lines changed: 65 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import base64
2+
import datetime
13
import functools
24
import json
35
import time
@@ -165,6 +167,17 @@ def _preferred_browser():
165167
return None
166168

167169

170+
def _build_req_cnf(jwk:dict, remove_padding:bool = False) -> str:
171+
"""req_cnf usually requires base64url encoding.
172+
173+
https://datatracker.ietf.org/doc/html/draft-ietf-oauth-pop-key-distribution-07#section-4.2.1
174+
https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-rdpbcgr/e967ebeb-9e9f-443e-857a-5208802943c2
175+
"""
176+
raw = json.dumps(jwk)
177+
encoded = base64.urlsafe_b64encode(raw.encode('utf-8')).decode('utf-8')
178+
return encoded.rstrip('=') if remove_padding else encoded
179+
180+
168181
class _ClientWithCcsRoutingInfo(Client):
169182

170183
def initiate_auth_code_flow(self, **kwargs):
@@ -231,13 +244,22 @@ class ClientApplication(object):
231244
_TOKEN_SOURCE_IDP = "identity_provider"
232245
_TOKEN_SOURCE_CACHE = "cache"
233246
_TOKEN_SOURCE_BROKER = "broker"
247+
_XMS_DS_NONCE = "xms_ds_nonce"
234248

235249
_enable_broker = False
236250
_AUTH_SCHEME_UNSUPPORTED = (
237251
"auth_scheme is currently only available from broker. "
238252
"You can enable broker by following these instructions. "
239253
"https://msal-python.readthedocs.io/en/latest/#publicclientapplication")
240254

255+
@functools.lru_cache(maxsize=2)
256+
def __get_rsa_key(self, _bucket): # _bucket is used with lru_cache pattern
257+
from .crypto import _generate_rsa_key
258+
return _generate_rsa_key()
259+
260+
def _get_rsa_key(self, _bucket=None): # Return the same RSA key, cached for a day
261+
return self.__get_rsa_key(_bucket or datetime.date.today())
262+
241263
def __init__(
242264
self, client_id,
243265
client_credential=None, authority=None, validate_authority=True,
@@ -1552,6 +1574,9 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
15521574
"expires_in": int(expires_in), # OAuth2 specs defines it as int
15531575
self._TOKEN_SOURCE: self._TOKEN_SOURCE_CACHE,
15541576
}
1577+
if self._XMS_DS_NONCE in entry: # CDT needs this
1578+
access_token_from_cache[self._XMS_DS_NONCE] = entry[
1579+
self._XMS_DS_NONCE]
15551580
if "refresh_on" in entry:
15561581
access_token_from_cache["refresh_on"] = int(entry["refresh_on"])
15571582
if int(entry["refresh_on"]) < now: # aging
@@ -2340,7 +2365,16 @@ class ConfidentialClientApplication(ClientApplication): # server-side web app
23402365
except that ``allow_broker`` parameter shall remain ``None``.
23412366
"""
23422367

2343-
def acquire_token_for_client(self, scopes, claims_challenge=None, **kwargs):
2368+
def acquire_token_for_client(
2369+
self,
2370+
scopes,
2371+
claims_challenge=None,
2372+
*,
2373+
delegation_constraints: Optional[list] = None,
2374+
delegation_confirmation_key=None, # A Cyprtography's RSAPrivateKey-like object
2375+
# TODO: Support ECC key? https://github.com/pyca/cryptography/issues/4093
2376+
**kwargs
2377+
):
23442378
"""Acquires token for the current confidential client, not for an end user.
23452379
23462380
Since MSAL Python 1.23, it will automatically look for token from cache,
@@ -2363,8 +2397,36 @@ def acquire_token_for_client(self, scopes, claims_challenge=None, **kwargs):
23632397
raise ValueError( # We choose to disallow force_refresh
23642398
"Historically, this method does not support force_refresh behavior. "
23652399
)
2366-
return _clean_up(self._acquire_token_silent_with_error(
2367-
scopes, None, claims_challenge=claims_challenge, **kwargs))
2400+
if delegation_constraints:
2401+
private_key = delegation_confirmation_key or self._get_rsa_key()
2402+
from .crypto import _convert_rsa_keys
2403+
_, jwk = _convert_rsa_keys(private_key)
2404+
result = _clean_up(self._acquire_token_silent_with_error(
2405+
scopes, None, claims_challenge=claims_challenge, data=dict(
2406+
kwargs.pop("data", {}),
2407+
req_ds_cnf=_build_req_cnf(jwk) # It is part of token cache key
2408+
if delegation_constraints else None,
2409+
),
2410+
**kwargs))
2411+
if delegation_constraints and not result.get("error"):
2412+
if not result.get(self._XMS_DS_NONCE): # Available in cached token, too
2413+
raise ValueError(
2414+
"The resource did not opt in to xms_ds_cnf claim. "
2415+
"After its opt-in, call this function again with "
2416+
"a new app object or a new delegation_confirmation_key"
2417+
# in order to invalidate the token in cache
2418+
)
2419+
import jwt # Lazy loading
2420+
cdt_envelope = jwt.encode({
2421+
"constraints": delegation_constraints,
2422+
self._XMS_DS_NONCE: result[self._XMS_DS_NONCE],
2423+
}, private_key, algorithm="PS256")
2424+
result["access_token"] = jwt.encode({
2425+
"t": result["access_token"],
2426+
"c": cdt_envelope,
2427+
}, None, algorithm=None, headers={"typ": "cdt+jwt"})
2428+
del result[self._XMS_DS_NONCE] # Caller shouldn't need to know that
2429+
return result
23682430

23692431
def _acquire_token_for_client(
23702432
self,

msal/crypto.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from base64 import urlsafe_b64encode
2+
3+
from cryptography.hazmat.primitives.asymmetric import rsa
4+
5+
6+
def _urlsafe_b64encode(n:int, bit_size:int) -> str:
7+
return urlsafe_b64encode(n.to_bytes(length=int(bit_size/8))).decode("utf-8")
8+
9+
10+
def _to_jwk(public_key: rsa.RSAPublicKey) -> dict:
11+
numbers = public_key.public_numbers()
12+
return {
13+
"kty": "RSA",
14+
"n": _urlsafe_b64encode(numbers.n, public_key.key_size),
15+
"e": _urlsafe_b64encode(numbers.e, 24), # TODO: TBD. PyJWT/jwt/algorithms.py RSAAlgorithm.to_jwk()
16+
}
17+
18+
def _convert_rsa_keys(private_key: rsa.RSAPrivateKey):
19+
return "pairs.private_bytes()", _to_jwk(private_key.public_key())
20+
21+
def _generate_rsa_key() -> rsa.RSAPrivateKey:
22+
return rsa.generate_private_key(public_exponent=65537, key_size=2048)
23+

msal/token_cache.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
import json
1+
import hashlib
2+
import json
23
import threading
34
import time
45
import logging
@@ -61,6 +62,7 @@ def __init__(self):
6162
realm=None, target=None,
6263
# Note: New field(s) can be added here
6364
#key_id=None,
65+
req_ds_cnf=None,
6466
**ignored_payload_from_a_real_token:
6567
"-".join([ # Note: Could use a hash here to shorten key length
6668
home_account_id or "",
@@ -70,6 +72,13 @@ def __init__(self):
7072
realm or "",
7173
target or "",
7274
#key_id or "", # So ATs of different key_id can coexist
75+
hashlib.sha256(req_ds_cnf.encode()).hexdigest()
76+
# TODO: Could hash the entire key eventually.
77+
# But before that project, we better first
78+
# change the scope to use input scope
79+
# instead of response scope,
80+
# so that a search() can probably have O(1) hit.
81+
if req_ds_cnf else "", # CDT
7382
]).lower(),
7483
self.CredentialType.ID_TOKEN:
7584
lambda home_account_id=None, environment=None, client_id=None,
@@ -267,10 +276,13 @@ def __add(self, event, now=None):
267276
"expires_on": str(now + expires_in), # Same here
268277
"extended_expires_on": str(now + ext_expires_in) # Same here
269278
}
279+
if response.get("xms_ds_nonce"): # Available for CDT
280+
at["xms_ds_nonce"] = response["xms_ds_nonce"]
270281
at.update({k: data[k] for k in data if k in {
271282
# Also store extra data which we explicitly allow
272283
# So that we won't accidentally store a user's password etc.
273284
"key_id", # It happens in SSH-cert or POP scenario
285+
"req_ds_cnf", # Used in CDT
274286
}})
275287
if "refresh_in" in response:
276288
refresh_in = response["refresh_in"] # It is an integer

tests/test_application.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
ClientApplication, PublicClientApplication, ConfidentialClientApplication,
1212
_str2bytes, _merge_claims_challenge_and_capabilities,
1313
)
14+
from msal.oauth2cli.oidc import decode_part
1415
from tests import unittest
1516
from tests.test_token_cache import build_id_token, build_response
1617
from tests.http_client import MinimalHttpClient, MinimalResponse
@@ -856,3 +857,85 @@ def test_app_did_not_register_redirect_uri_should_error_out(self):
856857
)
857858
self.assertEqual(result.get("error"), "broker_error")
858859

860+
861+
class CdtTestCase(unittest.TestCase):
862+
863+
def createConstraint(self, typ: str, action: str, targets: list[str]) -> dict:
864+
return {"ver": "1.0", "typ": typ, "a": action, "target": [
865+
{"val": t} for t in targets
866+
]}
867+
868+
def test_constraint_format(self):
869+
self.assertEqual([
870+
self.createConstraint("ns:usr", "create", ["guid1", "guid2"]),
871+
self.createConstraint("ns:app", "update", ["guid3", "guid4"]),
872+
self.createConstraint("ns:subscription", "read", ["guid5", "guid6"]),
873+
], [ # Format defined in https://microsoft-my.sharepoint-df.com/:w:/p/rohitshende/EZgP9niwOvhKn-CUbj1NgG4BTZ6FSD9_16vXvsaXTiUzkg?e=j5DcQu&nav=eyJoIjoiODU5NDAyNjI4In0
874+
{"ver": "1.0", "typ": "ns:usr", "a": "create", "target": [
875+
{"val": "guid1"}, {"val": "guid2"},
876+
],
877+
},
878+
{"ver": "1.0", "typ": "ns:app", "a": "update", "target": [
879+
{"val": "guid3"}, {"val": "guid4"},
880+
],
881+
},
882+
{"ver": "1.0", "typ": "ns:subscription", "a": "read", "target": [
883+
{"val": "guid5"}, {"val": "guid6"},
884+
],
885+
},
886+
], "Constraint format is correct") # MSAL actually accepts arbitrary JSON blob
887+
888+
def assertCdt(self, result: dict, constraints: list[dict]) -> None:
889+
self.assertIsNotNone(
890+
result.get("access_token"), "Encountered {}: {}".format(
891+
result.get("error"), result.get("error_description")))
892+
_expectancy = "The return value should look like a Bearer response"
893+
self.assertEqual(result["token_type"], "Bearer", _expectancy)
894+
self.assertNotIn("xms_ds_nonce", result, _expectancy)
895+
headers = json.loads(decode_part(result["access_token"].split(".")[0]))
896+
self.assertEqual(headers.get("typ"), "cdt+jwt", "typ should be cdt+jwt")
897+
payload = json.loads(decode_part(result["access_token"].split(".")[1]))
898+
self.assertIsNotNone(payload.get("t") and payload.get("c"))
899+
cdt_envelope = json.loads(decode_part(payload["c"].split(".")[1]))
900+
self.assertIn("xms_ds_nonce", cdt_envelope)
901+
self.assertEqual(cdt_envelope["constraints"], constraints)
902+
903+
def assertAppObtainsCdt(self, client_app, scopes) -> None:
904+
constraints1 = [self.createConstraint("ns:usr", "create", ["guid1"])]
905+
result = client_app.acquire_token_for_client(
906+
scopes, delegation_constraints=constraints1,
907+
)
908+
self.assertCdt(result, constraints1)
909+
910+
constraints2 = [self.createConstraint("ns:app", "update", ["guid2"])]
911+
result = client_app.acquire_token_for_client(
912+
scopes, delegation_constraints=constraints2,
913+
)
914+
self.assertEqual(result["token_source"], "cache", "App token Should hit cache")
915+
self.assertCdt(result, constraints2)
916+
917+
result = client_app.acquire_token_for_client(
918+
scopes, delegation_constraints=constraints2,
919+
delegation_confirmation_key=client_app._get_rsa_key("new"),
920+
)
921+
self.assertEqual(
922+
result["token_source"], "identity_provider",
923+
"Different key should result in a new app token")
924+
self.assertCdt(result, constraints2)
925+
926+
@patch("msal.authority.tenant_discovery", new=Mock(return_value={
927+
"authorization_endpoint": "https://contoso.com/placeholder",
928+
"token_endpoint": "https://contoso.com/placeholder",
929+
}))
930+
def test_acquire_token_for_client_should_return_a_cdt(self):
931+
app = msal.ConfidentialClientApplication("id", client_credential="secret")
932+
with patch.object(app.http_client, "post", return_value=MinimalResponse(
933+
status_code=200, text=json.dumps({
934+
"token_type": "Bearer",
935+
"access_token": "app token",
936+
"expires_in": 3600,
937+
"xms_ds_nonce": "nonce",
938+
}))) as mocked_post:
939+
self.assertAppObtainsCdt(app, ["scope1", "scope2"])
940+
mocked_post.assert_called_once()
941+

tests/test_crypto.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from unittest import TestCase
2+
3+
from msal.crypto import _generate_rsa_key, _convert_rsa_keys
4+
5+
6+
class CryptoTestCase(TestCase):
7+
def test_key_generation(self):
8+
key = _generate_rsa_key()
9+
_, jwk = _convert_rsa_keys(key)
10+
self.assertEqual(jwk.get("kty"), "RSA")
11+
self.assertIsNotNone(jwk.get("n") and jwk.get("e"))
12+

tests/test_e2e.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,10 @@
2727

2828
import msal
2929
from tests.http_client import MinimalHttpClient, MinimalResponse
30+
from tests.test_application import CdtTestCase
3031
from msal.oauth2cli import AuthCodeReceiver
3132
from msal.oauth2cli.oidc import decode_part
33+
from msal.application import _build_req_cnf
3234

3335
try:
3436
import pymsalruntime
@@ -533,7 +535,7 @@ def tearDownClass(cls):
533535
cls.session.close()
534536

535537
@classmethod
536-
def get_lab_app_object(cls, client_id=None, **query): # https://msidlab.com/swagger/index.html
538+
def get_lab_app_object(cls, client_id=None, **query) -> dict: # https://msidlab.com/swagger/index.html
537539
url = "https://msidlab.com/api/app/{}".format(client_id or "")
538540
resp = cls.session.get(url, params=query)
539541
result = resp.json()[0]
@@ -791,12 +793,12 @@ def test_user_account(self):
791793
self._test_user_account()
792794

793795

794-
def _data_for_pop(key):
795-
raw_req_cnf = json.dumps({"kid": key, "xms_ksl": "sw"})
796+
def _data_for_pop(key_id):
796797
return { # Sampled from Azure CLI's plugin connectedk8s
797798
'token_type': 'pop',
798-
'key_id': key,
799-
"req_cnf": base64.urlsafe_b64encode(raw_req_cnf.encode('utf-8')).decode('utf-8').rstrip('='),
799+
'key_id': key_id,
800+
"req_cnf": _build_req_cnf(
801+
{"kid": key_id, "xms_ksl": "sw"}, remove_padding=True),
800802
# Note: Sending raw_req_cnf without base64 encoding would result in an http 500 error
801803
} # See also https://github.com/Azure/azure-cli-extensions/blob/main/src/connectedk8s/azext_connectedk8s/_clientproxyutils.py#L86-L92
802804

@@ -817,6 +819,23 @@ def test_user_account(self):
817819
self._test_user_account()
818820

819821

822+
class CdtTestCase(LabBasedTestCase, CdtTestCase):
823+
def test_acquire_token_for_client_should_return_a_cdt(self):
824+
resource = self.get_lab_app_object( # This resource has opted in to CDT
825+
publicClient="no", signinAudience="AzureAdMyOrg")
826+
client_app = msal.ConfidentialClientApplication(
827+
# Any CCA can use a CDT, as long as the resource opted in for a CDT
828+
# Here we use the OBO app which is in same tenant as the resource.
829+
os.getenv("LAB_OBO_CONFIDENTIAL_CLIENT_ID"),
830+
client_credential=os.getenv("LAB_OBO_CLIENT_SECRET"),
831+
authority="{}{}.onmicrosoft.com".format(
832+
resource["authority"],
833+
resource["labName"].lower().rstrip(".com"),
834+
),
835+
)
836+
self.assertAppObtainsCdt(client_app, [f"{resource['appId']}/.default"])
837+
838+
820839
class WorldWideTestCase(LabBasedTestCase):
821840

822841
def test_aad_managed_user(self): # Pure cloud

0 commit comments

Comments
 (0)