Skip to content

Commit dcb6815

Browse files
committed
Refactor req_ds_cnf to delegation_scope_key
1 parent b81411b commit dcb6815

File tree

4 files changed

+49
-10
lines changed

4 files changed

+49
-10
lines changed

msal/application.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2359,7 +2359,8 @@ def _acquire_token_for_client(
23592359
claims_challenge=None,
23602360
*,
23612361
delegation_constraints: Optional[list] = None,
2362-
req_ds_cnf: Optional[dict] = None,
2362+
delegation_scope_key=None, # A Cyprtography's RSAPrivateKey-like object
2363+
# TODO: Support ECC key? https://github.com/pyca/cryptography/issues/4093
23632364
**kwargs
23642365
):
23652366
if self.authority.tenant.lower() in ["common", "organizations"]:
@@ -2371,17 +2372,20 @@ def _acquire_token_for_client(
23712372
telemetry_context = self._build_telemetry_context(
23722373
self.ACQUIRE_TOKEN_FOR_CLIENT_ID, refresh_reason=refresh_reason)
23732374
client = self._regional_client or self.client
2375+
if delegation_constraints:
2376+
from .crypto import _generate_rsa_key, _convert_rsa_keys
2377+
_, jwk = _convert_rsa_keys(delegation_scope_key or _generate_rsa_key())
23742378
response = client.obtain_token_for_client(
23752379
scope=scopes, # This grant flow requires no scope decoration
23762380
headers=telemetry_context.generate_headers(),
23772381
data=dict(
23782382
kwargs.pop("data", {}),
2379-
req_ds_cnf=_build_req_cnf(req_ds_cnf) if req_ds_cnf else None,
2383+
req_ds_cnf=_build_req_cnf(jwk) if delegation_constraints else None,
23802384
claims=_merge_claims_challenge_and_capabilities(
23812385
self._client_capabilities, claims_challenge)),
23822386
**kwargs)
23832387
if (
2384-
req_ds_cnf
2388+
delegation_constraints
23852389
and not response.get("error") and not response.get("xms_ds_nonce")
23862390
):
23872391
raise ValueError("Your app shall opt in to xms_ds_cnf claim first")

msal/crypto.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from base64 import urlsafe_b64encode
2+
3+
from cryptography.hazmat.primitives.asymmetric import rsa
4+
5+
6+
def _urlsafe_b64encode(n:int, bit_size:int) -> str:
7+
return urlsafe_b64encode(n.to_bytes(length=int(bit_size/8))).decode("utf-8")
8+
9+
10+
def _to_jwk(public_key: rsa.RSAPublicKey) -> dict:
11+
numbers = public_key.public_numbers()
12+
return {
13+
"kty": "RSA",
14+
"n": _urlsafe_b64encode(numbers.n, public_key.key_size),
15+
"e": _urlsafe_b64encode(numbers.e, 24), # TODO: TBD
16+
}
17+
18+
def _convert_rsa_keys(private_key: rsa.RSAPrivateKey):
19+
return "pairs.private_bytes()", _to_jwk(private_key.public_key())
20+
21+
def _generate_rsa_key() -> rsa.RSAPrivateKey:
22+
return rsa.generate_private_key(public_exponent=65537, key_size=2048)
23+

tests/test_crypto.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from unittest import TestCase
2+
3+
from msal.crypto import _generate_rsa_key, _convert_rsa_keys
4+
5+
6+
class CryptoTestCase(TestCase):
7+
def test_key_generation(self):
8+
key = _generate_rsa_key()
9+
_, jwk = _convert_rsa_keys(key)
10+
self.assertEqual(jwk.get("kty"), "RSA")
11+
self.assertIsNotNone(jwk.get("n") and jwk.get("e"))
12+

tests/test_e2e.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -819,7 +819,7 @@ def test_user_account(self):
819819

820820

821821
class CdtTestCase(LabBasedTestCase):
822-
_JWK1 = {"kty":"RSA", "n":"2tNr73xwcj6lH7bqRZrFzgSLj7OeLfbn8216uOMDHuaZ6TEUBDN8Uz0ve8jAlKsP9CQFCSVoSNovdE-fs7c15MxEGHjDcNKLWonznximj8pDGZQjVdfK-7mG6P6z-lgVcLuYu5JcWU_PeEqIKg5llOaz-qeQ4LEDS4T1D2qWRGpAra4rJX1-kmrWmX_XIamq30C9EIO0gGuT4rc2hJBWQ-4-FnE1NXmy125wfT3NdotAJGq5lMIfhjfglDbJCwhc8Oe17ORjO3FsB5CLuBRpYmP7Nzn66lRY3Fe11Xz8AEBl3anKFSJcTvlMnFtu3EpD-eiaHfTgRBU7CztGQqVbiQ", "e":"AQAB"}
822+
#_JWK1 = {"kty":"RSA", "n":"2tNr73xwcj6lH7bqRZrFzgSLj7OeLfbn8216uOMDHuaZ6TEUBDN8Uz0ve8jAlKsP9CQFCSVoSNovdE-fs7c15MxEGHjDcNKLWonznximj8pDGZQjVdfK-7mG6P6z-lgVcLuYu5JcWU_PeEqIKg5llOaz-qeQ4LEDS4T1D2qWRGpAra4rJX1-kmrWmX_XIamq30C9EIO0gGuT4rc2hJBWQ-4-FnE1NXmy125wfT3NdotAJGq5lMIfhjfglDbJCwhc8Oe17ORjO3FsB5CLuBRpYmP7Nzn66lRY3Fe11Xz8AEBl3anKFSJcTvlMnFtu3EpD-eiaHfTgRBU7CztGQqVbiQ", "e":"AQAB"}
823823
def test_service_principal(self):
824824
"""
825825
app = get_lab_app(
@@ -835,14 +835,14 @@ def test_service_principal(self):
835835
)
836836
from http.client import HTTPConnection
837837
HTTPConnection.debuglevel = 1
838+
delegation_constraints = [
839+
{"typ": "usr", "a": "C", "target": ["constraint1", "constraint4"]},
840+
{"typ": "app", "a": "R", "target": ["constraint2", "constraint5"]},
841+
{"typ": "subscription", "a": "U", "target": ["constraint3"]},
842+
]
838843
result = app.acquire_token_for_client(
839844
[f"{app.client_id}/.default"],
840-
delegation_constraints=[
841-
{"typ": "usr", "a": "C", "target": ["constraint1", "constraint4"]},
842-
{"typ": "app", "a": "R", "target": ["constraint2", "constraint5"]},
843-
{"typ": "subscription", "a": "U", "target": ["constraint3"]},
844-
],
845-
req_ds_cnf=self._JWK1,
845+
delegation_constraints=delegation_constraints,
846846
)
847847
self.assertIsNotNone(result.get("access_token"), "Encountered {}: {}".format(
848848
result.get("error"), result.get("error_description")))

0 commit comments

Comments
 (0)