Skip to content

Commit b81411b

Browse files
committed
PoC: req_ds_cnf
1 parent a421b70 commit b81411b

File tree

2 files changed

+59
-4
lines changed

2 files changed

+59
-4
lines changed

msal/application.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1+
import base64
12
import functools
23
import json
34
import time
45
import logging
56
import sys
67
import warnings
78
from threading import Lock
9+
from typing import Optional # Needed in Python 3.7 & 3.8
810
import os
911

1012
from .oauth2cli import Client, JwtAssertionCreator
@@ -164,6 +166,17 @@ def _preferred_browser():
164166
return None
165167

166168

169+
def _build_req_cnf(jwk:dict, remove_padding:bool = False) -> str:
170+
"""req_cnf usually requires base64url encoding.
171+
172+
https://datatracker.ietf.org/doc/html/draft-ietf-oauth-pop-key-distribution-07#section-4.2.1
173+
https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-rdpbcgr/e967ebeb-9e9f-443e-857a-5208802943c2
174+
"""
175+
raw = json.dumps(jwk)
176+
encoded = base64.urlsafe_b64encode(raw.encode('utf-8')).decode('utf-8')
177+
return encoded.rstrip('=') if remove_padding else encoded
178+
179+
167180
class _ClientWithCcsRoutingInfo(Client):
168181

169182
def initiate_auth_code_flow(self, **kwargs):
@@ -2344,6 +2357,9 @@ def _acquire_token_for_client(
23442357
scopes,
23452358
refresh_reason,
23462359
claims_challenge=None,
2360+
*,
2361+
delegation_constraints: Optional[list] = None,
2362+
req_ds_cnf: Optional[dict] = None,
23472363
**kwargs
23482364
):
23492365
if self.authority.tenant.lower() in ["common", "organizations"]:
@@ -2360,9 +2376,15 @@ def _acquire_token_for_client(
23602376
headers=telemetry_context.generate_headers(),
23612377
data=dict(
23622378
kwargs.pop("data", {}),
2379+
req_ds_cnf=_build_req_cnf(req_ds_cnf) if req_ds_cnf else None,
23632380
claims=_merge_claims_challenge_and_capabilities(
23642381
self._client_capabilities, claims_challenge)),
23652382
**kwargs)
2383+
if (
2384+
req_ds_cnf
2385+
and not response.get("error") and not response.get("xms_ds_nonce")
2386+
):
2387+
raise ValueError("Your app shall opt in to xms_ds_cnf claim first")
23662388
telemetry_context.update_telemetry(response)
23672389
return response
23682390

tests/test_e2e.py

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from tests.http_client import MinimalHttpClient, MinimalResponse
3030
from msal.oauth2cli import AuthCodeReceiver
3131
from msal.oauth2cli.oidc import decode_part
32+
from msal.application import _build_req_cnf
3233

3334
try:
3435
import pymsalruntime
@@ -791,12 +792,12 @@ def test_user_account(self):
791792
self._test_user_account()
792793

793794

794-
def _data_for_pop(key):
795-
raw_req_cnf = json.dumps({"kid": key, "xms_ksl": "sw"})
795+
def _data_for_pop(key_id):
796796
return { # Sampled from Azure CLI's plugin connectedk8s
797797
'token_type': 'pop',
798-
'key_id': key,
799-
"req_cnf": base64.urlsafe_b64encode(raw_req_cnf.encode('utf-8')).decode('utf-8').rstrip('='),
798+
'key_id': key_id,
799+
"req_cnf": _build_req_cnf(
800+
{"kid": key_id, "xms_ksl": "sw"}, remove_padding=True),
800801
# Note: Sending raw_req_cnf without base64 encoding would result in an http 500 error
801802
} # See also https://github.com/Azure/azure-cli-extensions/blob/main/src/connectedk8s/azext_connectedk8s/_clientproxyutils.py#L86-L92
802803

@@ -817,6 +818,38 @@ def test_user_account(self):
817818
self._test_user_account()
818819

819820

821+
class CdtTestCase(LabBasedTestCase):
822+
_JWK1 = {"kty":"RSA", "n":"2tNr73xwcj6lH7bqRZrFzgSLj7OeLfbn8216uOMDHuaZ6TEUBDN8Uz0ve8jAlKsP9CQFCSVoSNovdE-fs7c15MxEGHjDcNKLWonznximj8pDGZQjVdfK-7mG6P6z-lgVcLuYu5JcWU_PeEqIKg5llOaz-qeQ4LEDS4T1D2qWRGpAra4rJX1-kmrWmX_XIamq30C9EIO0gGuT4rc2hJBWQ-4-FnE1NXmy125wfT3NdotAJGq5lMIfhjfglDbJCwhc8Oe17ORjO3FsB5CLuBRpYmP7Nzn66lRY3Fe11Xz8AEBl3anKFSJcTvlMnFtu3EpD-eiaHfTgRBU7CztGQqVbiQ", "e":"AQAB"}
823+
def test_service_principal(self):
824+
"""
825+
app = get_lab_app(
826+
authority="https://login.microsoftonline.com/microsoft.onmicrosoft.com"
827+
"?dc=ESTS-PUB-JPELR1-AZ1-FD000-TEST1",
828+
)
829+
"""
830+
app = msal.ConfidentialClientApplication(
831+
os.getenv("RAY_APP_CLIENT_ID"),
832+
client_credential=os.getenv("RAY_APP_CLIENT_SECRET"),
833+
authority="https://login.microsoftonline.com/msidlab4.onmicrosoft.com"
834+
"?dc=ESTS-PUB-JPELR1-AZ1-FD000-TEST1", # Accessible within AzVPN
835+
)
836+
from http.client import HTTPConnection
837+
HTTPConnection.debuglevel = 1
838+
result = app.acquire_token_for_client(
839+
[f"{app.client_id}/.default"],
840+
delegation_constraints=[
841+
{"typ": "usr", "a": "C", "target": ["constraint1", "constraint4"]},
842+
{"typ": "app", "a": "R", "target": ["constraint2", "constraint5"]},
843+
{"typ": "subscription", "a": "U", "target": ["constraint3"]},
844+
],
845+
req_ds_cnf=self._JWK1,
846+
)
847+
self.assertIsNotNone(result.get("access_token"), "Encountered {}: {}".format(
848+
result.get("error"), result.get("error_description")))
849+
print("Test case result:", result)
850+
self.assertIsNotNone(result.get("xms_ds_nonce"))
851+
852+
820853
class WorldWideTestCase(LabBasedTestCase):
821854

822855
def test_aad_managed_user(self): # Pure cloud

0 commit comments

Comments
 (0)