Skip to content

Commit 5971e53

Browse files
committed
Merge branch 'confidential-client' into dev
2 parents 2349f5d + 671495c commit 5971e53

File tree

2 files changed

+42
-42
lines changed

2 files changed

+42
-42
lines changed

msal/application.py

Lines changed: 23 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from oauth2cli import Client
1010
from .authority import Authority
11-
from .assertion import create_jwt_assertion
11+
from oauth2cli.assertion import JwtSigner
1212
import mex
1313
import wstrust_request
1414
from .wstrust_response import SAML_TOKEN_TYPE_V1, SAML_TOKEN_TYPE_V2
@@ -55,7 +55,7 @@ def __init__(
5555
or an X509 certificate container in this form:
5656
5757
{
58-
"certificate": "-----BEGIN PRIVATE KEY-----...",
58+
"private_key": "...-----BEGIN PRIVATE KEY-----...",
5959
"thumbprint": "A1B2C3D4E5F6...",
6060
}
6161
"""
@@ -66,36 +66,36 @@ def __init__(
6666
validate_authority)
6767
# Here the self.authority is not the same type as authority in input
6868
self.token_cache = token_cache or TokenCache()
69-
default_body = self._build_auth_parameters(
70-
self.client_credential,
71-
self.authority.token_endpoint, self.client_id)
72-
default_body["client_info"] = 1
73-
self.client = Client(
69+
self.client = self._build_client(client_credential, self.authority)
70+
71+
def _build_client(self, client_credential, authority):
72+
client_assertion = None
73+
default_body = {"client_info": 1}
74+
if isinstance(client_credential, dict):
75+
assert ("private_key" in client_credential
76+
and "thumbprint" in client_credential)
77+
signer = JwtSigner(
78+
client_credential["private_key"], algorithm="RS256",
79+
sha1_thumbprint=client_credential.get("thumbprint"))
80+
client_assertion = signer.sign_assertion(
81+
audience=authority.token_endpoint, issuer=self.client_id)
82+
else:
83+
default_body['client_secret'] = client_credential
84+
return Client(
7485
self.client_id,
7586
configuration={
76-
"token_endpoint": self.authority.token_endpoint,
87+
"authorization_endpoint": authority.authorization_endpoint,
88+
"token_endpoint": authority.token_endpoint,
7789
"device_authorization_endpoint": urljoin(
78-
self.authority.token_endpoint, "devicecode"),
90+
authority.token_endpoint, "devicecode"),
7991
},
8092
default_body=default_body,
93+
client_assertion=client_assertion,
8194
on_obtaining_tokens=self.token_cache.add,
8295
on_removing_rt=self.token_cache.remove_rt,
8396
on_updating_rt=self.token_cache.update_rt,
8497
)
8598

86-
@staticmethod
87-
def _build_auth_parameters(client_credential, token_endpoint, client_id):
88-
if isinstance(client_credential, dict):
89-
type_ = 'urn:ietf:params:oauth:client-assertion-type:jwt-bearer'
90-
assertion = create_jwt_assertion(
91-
client_credential['certificate'],
92-
client_credential['thumbprint'],
93-
audience=token_endpoint, issuer=client_id)
94-
return {
95-
'client_assertion_type': type_, 'client_assertion': assertion}
96-
else:
97-
return {'client_secret': client_credential}
98-
9999
def get_authorization_request_url(
100100
self,
101101
scope,
@@ -218,16 +218,7 @@ def acquire_token_silent(
218218
"home_account_id": (account or {}).get("home_account_id"),
219219
# "realm": the_authority.tenant, # AAD RTs are tenant-independent
220220
})
221-
client = Client(
222-
self.client_id,
223-
configuration={"token_endpoint": the_authority.token_endpoint},
224-
default_body=self._build_auth_parameters(
225-
self.client_credential,
226-
the_authority.token_endpoint, self.client_id),
227-
on_obtaining_tokens=self.token_cache.add,
228-
on_removing_rt=self.token_cache.remove_rt,
229-
on_updating_rt=self.token_cache.update_rt,
230-
)
221+
client = self._build_client(self.client_credential, the_authority)
231222
for entry in matches:
232223
response = client.obtain_token_with_refresh_token(
233224
entry, rt_getter=lambda token_item: token_item["secret"],

tests/test_application.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,11 @@ def assertCacheWorks(self, result_from_wire):
6060
@unittest.skipUnless("client_id" in CONFIG, "client_id missing")
6161
class TestConfidentialClientApplication(unittest.TestCase):
6262

63+
def assertCacheWorks(self, result_from_wire, result_from_cache):
64+
self.assertIsNotNone(result_from_cache)
65+
self.assertEqual(
66+
result_from_wire['access_token'], result_from_cache['access_token'])
67+
6368
@unittest.skipUnless("client_secret" in CONFIG, "Missing client secret")
6469
def test_confidential_client_using_secret(self):
6570
app = ConfidentialClientApplication(
@@ -68,19 +73,23 @@ def test_confidential_client_using_secret(self):
6873
scope = CONFIG.get("scope", [])
6974
result = app.acquire_token_for_client(scope)
7075
self.assertIn('access_token', result)
76+
self.assertCacheWorks(result, app.acquire_token_silent(scope, account=None))
7177

72-
result_from_cache = app.acquire_token_silent(scope, account=None)
73-
self.assertIsNotNone(result_from_cache)
74-
self.assertEqual(result['access_token'], result_from_cache['access_token'])
75-
76-
@unittest.skipUnless("private_key" in CONFIG, "Missing client cert")
78+
@unittest.skipUnless("client_certificate" in CONFIG, "Missing client cert")
7779
def test_confidential_client_using_certificate(self):
78-
private_key = os.path.join(THIS_FOLDER, CONFIG['private_key'])
79-
with open(private_key) as f: pem = f.read()
80-
certificate = {'certificate': pem, "thumbprint": CONFIG['thumbprint']}
81-
app = ConfidentialClientApplication(CONFIG['client_id'], certificate)
82-
result = app.acquire_token_for_client(self.scope)
80+
client_certificate = CONFIG["client_certificate"]
81+
assert ("private_key_path" in client_certificate
82+
and "thumbprint" in client_certificate)
83+
key_path = os.path.join(THIS_FOLDER, client_certificate['private_key_path'])
84+
with open(key_path) as f:
85+
pem = f.read()
86+
app = ConfidentialClientApplication(
87+
CONFIG['client_id'],
88+
{"private_key": pem, "thumbprint": client_certificate["thumbprint"]})
89+
scope = CONFIG.get("scope", [])
90+
result = app.acquire_token_for_client(scope)
8391
self.assertIn('access_token', result)
92+
self.assertCacheWorks(result, app.acquire_token_silent(scope, account=None))
8493

8594

8695
@unittest.skipUnless("client_id" in CONFIG, "client_id missing")

0 commit comments

Comments
 (0)