Skip to content

Commit 53949a9

Browse files
committed
Merge remote branch
2 parents dbf0e10 + 329ba0f commit 53949a9

File tree

4 files changed

+31
-26
lines changed

4 files changed

+31
-26
lines changed

msal/oauth2cli/assertion.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import jwt
88

99

10-
logger = logging.getLogger(__file__)
10+
logger = logging.getLogger(__name__)
1111

1212
class Signer(object):
1313
def sign_assertion(

msal/oauth2cli/authcode.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from .oauth2 import Client
2222

2323

24-
logger = logging.getLogger(__file__)
24+
logger = logging.getLogger(__name__)
2525

2626
def obtain_auth_code(listen_port, auth_uri=None):
2727
"""This function will start a web server listening on http://localhost:port

msal/oauth2cli/oauth2.py

Lines changed: 28 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import logging
1010
import warnings
1111
import time
12+
import base64
1213

1314
import requests
1415

@@ -18,6 +19,15 @@ class BaseClient(object):
1819
# This low-level interface works. Yet you'll find its sub-class
1920
# more friendly to remind you what parameters are needed in each scenario.
2021
# More on Client Types at https://tools.ietf.org/html/rfc6749#section-2.1
22+
23+
@staticmethod
24+
def encode_saml_assertion(assertion):
25+
return base64.urlsafe_b64encode(assertion).rstrip(b'=') # Per RFC 7522
26+
27+
CLIENT_ASSERTION_TYPE_JWT = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
28+
CLIENT_ASSERTION_TYPE_SAML2 = "urn:ietf:params:oauth:client-assertion-type:saml2-bearer"
29+
client_assertion_encoders = {CLIENT_ASSERTION_TYPE_SAML2: encode_saml_assertion}
30+
2131
def __init__(
2232
self,
2333
server_configuration, # type: dict
@@ -47,14 +57,12 @@ def __init__(
4757
client_secret (str): Triggers HTTP AUTH for Confidential Client
4858
client_assertion (bytes):
4959
The client assertion to authenticate this client, per RFC 7521.
50-
If it is a SAML assertion, you need to encode it beforehand, by:
51-
base64.urlsafe_b64encode(assertion).strip(b'=')
60+
It can be a raw SAML2 assertion (this method will encode it for you),
61+
or a raw JWT assertion.
5262
client_assertion_type (str):
53-
The format of the client_assertion.
54-
If you leave it as the default None, this method will try to make
55-
a guess between SAML2 (RFC 7522) and JWT (RFC 7523),
56-
the only two profiles defined in RFC 7521.
57-
But you can also explicitly provide a value, if needed.
63+
The type of your :attr:`client_assertion` parameter.
64+
It is typically the value of :attr:`CLIENT_ASSERTION_TYPE_SAML2` or
65+
:attr:`CLIENT_ASSERTION_TYPE_JWT`, the only two defined in RFC 7521.
5866
default_headers (dict):
5967
A dict to be sent in each request header.
6068
It is not required by OAuth2 specs, but you may use it for telemetry.
@@ -68,12 +76,10 @@ def __init__(
6876
self.client_id = client_id
6977
self.client_secret = client_secret
7078
self.default_body = default_body or {}
71-
if client_assertion is not None: # See https://tools.ietf.org/html/rfc7521#section-4.2
72-
if client_assertion_type is None: # RFC7521 defines only 2 profiles
73-
TYPE_JWT = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
74-
TYPE_SAML2 = "urn:ietf:params:oauth:client-assertion-type:saml2-bearer"
75-
client_assertion_type = TYPE_JWT if b"." in client_assertion else TYPE_SAML2
76-
self.default_body["client_assertion"] = client_assertion
79+
if client_assertion is not None and client_assertion_type is not None:
80+
# See https://tools.ietf.org/html/rfc7521#section-4.2
81+
encoder = self.client_assertion_encoders.get(client_assertion_type, lambda a: a)
82+
self.default_body["client_assertion"] = encoder(client_assertion)
7783
self.default_body["client_assertion_type"] = client_assertion_type
7884
self.logger = logging.getLogger(__name__)
7985
self.session = s = requests.Session()
@@ -174,6 +180,8 @@ class Client(BaseClient): # We choose to implement all 4 grants in 1 class
174180
DEVICE_FLOW_RETRIABLE_ERRORS = ("authorization_pending", "slow_down")
175181
GRANT_TYPE_SAML2 = "urn:ietf:params:oauth:grant-type:saml2-bearer" # RFC7522
176182
GRANT_TYPE_JWT = "urn:ietf:params:oauth:grant-type:jwt-bearer" # RFC7523
183+
grant_assertion_encoders = {GRANT_TYPE_SAML2: BaseClient.encode_saml_assertion}
184+
177185

178186
def initiate_device_flow(self, scope=None, timeout=None, **kwargs):
179187
# type: (list, **dict) -> dict
@@ -411,24 +419,20 @@ def obtain_token_by_refresh_token(self, token_item, scope=None,
411419
raise ValueError("token_item should not be a type %s" % type(token_item))
412420

413421
def obtain_token_by_assertion(
414-
self, assertion, grant_type=None, scope=None, **kwargs):
422+
self, assertion, grant_type, scope=None, **kwargs):
415423
# type: (bytes, Union[str, None], Union[str, list, set, tuple]) -> dict
416424
"""This method implements Assertion Framework for OAuth2 (RFC 7521).
417425
See details at https://tools.ietf.org/html/rfc7521#section-4.1
418426
419-
:param assertion: The assertion bytes which will be sent on wire as-is.
420-
If it is a SAML assertion, you need to encode it beforehand, by:
421-
base64.urlsafe_b64encode(assertion).strip(b'=')
427+
:param assertion:
428+
The assertion bytes can be a raw SAML2 assertion, or a JWT assertion.
422429
:param grant_type:
423-
If you leave it as the default None, this method will try to make
424-
a guess between SAML2 (RFC 7522) and JWT (RFC 7523),
425-
the only two profiles defined in RFC 7521.
426-
But you can also explicitly provide a value, if needed.
430+
It is typically either the value of :attr:`GRANT_TYPE_SAML2`,
431+
or :attr:`GRANT_TYPE_JWT`, the only two profiles defined in RFC 7521.
427432
:param scope: Optional. It must be a subset of previously granted scopes.
428433
"""
429-
if grant_type is None:
430-
grant_type = self.GRANT_TYPE_JWT if b"." in assertion else self.GRANT_TYPE_SAML2
434+
encoder = self.grant_assertion_encoders.get(grant_type, lambda a: a)
431435
data = kwargs.pop("data", {})
432-
data.update(scope=scope, assertion=assertion)
436+
data.update(scope=scope, assertion=encoder(assertion))
433437
return self._obtain_token(grant_type, data=data, **kwargs)
434438

tests/test_client.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ def setUpClass(cls):
9898
audience=CONFIG["openid_configuration"]["token_endpoint"],
9999
issuer=CONFIG["client_id"],
100100
),
101+
client_assertion_type=Client.CLIENT_ASSERTION_TYPE_JWT,
101102
)
102103
else:
103104
cls.client = Client(

0 commit comments

Comments
 (0)