Skip to content

Commit 5b04f8d

Browse files
committed
Merge branch 'refactor-confidential-client' into dev
2 parents 4fd1b9f + afaf13b commit 5b04f8d

File tree

4 files changed

+77
-21
lines changed

4 files changed

+77
-21
lines changed

msal/application.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import requests
1111

12-
from .oauth2cli import Client, JwtSigner
12+
from .oauth2cli import Client, JwtAssertionCreator
1313
from .authority import Authority
1414
from .mex import send_request as mex_send_request
1515
from .wstrust_request import send_request as wst_send_request
@@ -154,10 +154,10 @@ def _build_client(self, client_credential, authority):
154154
headers = {}
155155
if 'public_certificate' in client_credential:
156156
headers["x5c"] = extract_certs(client_credential['public_certificate'])
157-
signer = JwtSigner(
157+
assertion = JwtAssertionCreator(
158158
client_credential["private_key"], algorithm="RS256",
159159
sha1_thumbprint=client_credential.get("thumbprint"), headers=headers)
160-
client_assertion = signer.sign_assertion(
160+
client_assertion = assertion.create_regenerative_assertion(
161161
audience=authority.token_endpoint, issuer=self.client_id,
162162
additional_claims=self.client_claims or {})
163163
client_assertion_type = Client.CLIENT_ASSERTION_TYPE_JWT

msal/oauth2cli/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
__version__ = "0.2.0"
1+
__version__ = "0.3.0"
22

33
from .oidc import Client
4-
from .assertion import JwtSigner
4+
from .assertion import JwtAssertionCreator
5+
from .assertion import JwtSigner # Obsolete. For backward compatibility.
56

msal/oauth2cli/assertion.py

Lines changed: 56 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,57 @@
99

1010
logger = logging.getLogger(__name__)
1111

12-
class Signer(object):
13-
def sign_assertion(
14-
self, audience, issuer, subject, expires_at,
12+
class AssertionCreator(object):
13+
def create_normal_assertion(
14+
self, audience, issuer, subject, expires_at=None, expires_in=600,
1515
issued_at=None, assertion_id=None, **kwargs):
16-
# Names are defined in https://tools.ietf.org/html/rfc7521#section-5
16+
"""Create an assertion in bytes, based on the provided claims.
17+
18+
All parameter names are defined in https://tools.ietf.org/html/rfc7521#section-5
19+
except the expires_in is defined here as lifetime-in-seconds,
20+
which will be automatically translated into expires_at in UTC.
21+
"""
1722
raise NotImplementedError("Will be implemented by sub-class")
1823

24+
def create_regenerative_assertion(
25+
self, audience, issuer, subject=None, expires_in=600, **kwargs):
26+
"""Create an assertion as a callable,
27+
which will then compute the assertion later when necessary.
28+
29+
This is a useful optimization to reuse the client assertion.
30+
"""
31+
return AutoRefresher( # Returns a callable
32+
lambda a=audience, i=issuer, s=subject, e=expires_in, kwargs=kwargs:
33+
self.create_normal_assertion(a, i, s, expires_in=e, **kwargs),
34+
expires_in=max(expires_in-60, 0))
35+
36+
37+
class AutoRefresher(object):
38+
"""Cache the output of a factory, and auto-refresh it when necessary. Usage::
1939
20-
class JwtSigner(Signer):
40+
r = AutoRefresher(time.time, expires_in=5)
41+
for i in range(15):
42+
print(r()) # the timestamp change only after every 5 seconds
43+
time.sleep(1)
44+
"""
45+
def __init__(self, factory, expires_in=540):
46+
self._factory = factory
47+
self._expires_in = expires_in
48+
self._buf = {}
49+
def __call__(self):
50+
EXPIRES_AT, VALUE = "expires_at", "value"
51+
now = time.time()
52+
if self._buf.get(EXPIRES_AT, 0) <= now:
53+
logger.debug("Regenerating new assertion")
54+
self._buf = {VALUE: self._factory(), EXPIRES_AT: now + self._expires_in}
55+
else:
56+
logger.debug("Reusing still valid assertion")
57+
return self._buf.get(VALUE)
58+
59+
60+
class JwtAssertionCreator(AssertionCreator):
2161
def __init__(self, key, algorithm, sha1_thumbprint=None, headers=None):
22-
"""Create a signer.
62+
"""Construct a Jwt assertion creator.
2363
2464
Args:
2565
@@ -37,11 +77,11 @@ def __init__(self, key, algorithm, sha1_thumbprint=None, headers=None):
3777
self.headers["x5t"] = base64.urlsafe_b64encode(
3878
binascii.a2b_hex(sha1_thumbprint)).decode()
3979

40-
def sign_assertion(
41-
self, audience, issuer, subject=None, expires_at=None,
80+
def create_normal_assertion(
81+
self, audience, issuer, subject=None, expires_at=None, expires_in=600,
4282
issued_at=None, assertion_id=None, not_before=None,
4383
additional_claims=None, **kwargs):
44-
"""Sign a JWT Assertion.
84+
"""Create a JWT Assertion.
4585
4686
Parameters are defined in https://tools.ietf.org/html/rfc7523#section-3
4787
Key-value pairs in additional_claims will be added into payload as-is.
@@ -51,7 +91,7 @@ def sign_assertion(
5191
'aud': audience,
5292
'iss': issuer,
5393
'sub': subject or issuer,
54-
'exp': expires_at or (now + 10*60), # 10 minutes
94+
'exp': expires_at or (now + expires_in),
5595
'iat': issued_at or now,
5696
'jti': assertion_id or str(uuid.uuid4()),
5797
}
@@ -68,3 +108,9 @@ def sign_assertion(
68108
'See https://pyjwt.readthedocs.io/en/latest/installation.html#cryptographic-dependencies-optional')
69109
raise
70110

111+
112+
# Obsolete. For backward compatibility. They will be removed in future versions.
113+
Signer = AssertionCreator # For backward compatibility
114+
JwtSigner = JwtAssertionCreator # For backward compatibility
115+
JwtSigner.sign_assertion = JwtAssertionCreator.create_normal_assertion # For backward compatibility
116+

msal/oauth2cli/oauth2.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def __init__(
3333
server_configuration, # type: dict
3434
client_id, # type: str
3535
client_secret=None, # type: Optional[str]
36-
client_assertion=None, # type: Optional[bytes]
36+
client_assertion=None, # type: Union[bytes, callable, None]
3737
client_assertion_type=None, # type: Optional[str]
3838
default_headers=None, # type: Optional[dict]
3939
default_body=None, # type: Optional[dict]
@@ -55,10 +55,12 @@ def __init__(
5555
https://example.com/.../.well-known/openid-configuration
5656
client_id (str): The client's id, issued by the authorization server
5757
client_secret (str): Triggers HTTP AUTH for Confidential Client
58-
client_assertion (bytes):
58+
client_assertion (bytes, callable):
5959
The client assertion to authenticate this client, per RFC 7521.
6060
It can be a raw SAML2 assertion (this method will encode it for you),
6161
or a raw JWT assertion.
62+
It can also be a callable (recommended),
63+
so that we will do lazy creation of an assertion.
6264
client_assertion_type (str):
6365
The type of your :attr:`client_assertion` parameter.
6466
It is typically the value of :attr:`CLIENT_ASSERTION_TYPE_SAML2` or
@@ -75,11 +77,9 @@ def __init__(
7577
self.configuration = server_configuration
7678
self.client_id = client_id
7779
self.client_secret = client_secret
80+
self.client_assertion = client_assertion
7881
self.default_body = default_body or {}
79-
if client_assertion is not None and client_assertion_type is not None:
80-
# See https://tools.ietf.org/html/rfc7521#section-4.2
81-
encoder = self.client_assertion_encoders.get(client_assertion_type, lambda a: a)
82-
self.default_body["client_assertion"] = encoder(client_assertion)
82+
if client_assertion_type is not None:
8383
self.default_body["client_assertion_type"] = client_assertion_type
8484
self.logger = logging.getLogger(__name__)
8585
self.session = s = requests.Session()
@@ -114,6 +114,15 @@ def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749
114114
**kwargs # Relay all extra parameters to underlying requests
115115
): # Returns the json object came from the OAUTH2 response
116116
_data = {'client_id': self.client_id, 'grant_type': grant_type}
117+
118+
if self.default_body.get("client_assertion_type") and self.client_assertion:
119+
# See https://tools.ietf.org/html/rfc7521#section-4.2
120+
encoder = self.client_assertion_encoders.get(
121+
self.default_body["client_assertion_type"], lambda a: a)
122+
_data["client_assertion"] = encoder(
123+
self.client_assertion() # Do lazy on-the-fly computation
124+
if callable(self.client_assertion) else self.client_assertion)
125+
117126
_data.update(self.default_body) # It may contain authen parameters
118127
_data.update(data or {}) # So the content in data param prevails
119128
# We don't have to clean up None values here, because requests lib will.

0 commit comments

Comments
 (0)