Skip to content

Commit 43999a6

Browse files
authored
Merge pull request #16 from AzureAD/assertion-in-bytes-in-python3
Assertion in bytes in python3
2 parents 17948c1 + 1552449 commit 43999a6

File tree

7 files changed

+47
-30
lines changed

7 files changed

+47
-30
lines changed

msal/application.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
except: # Python 3
55
from urllib.parse import urljoin
66
import logging
7-
from base64 import b64encode
87
import sys
98

109
from .oauth2cli import Client, JwtSigner
@@ -404,9 +403,10 @@ def _acquire_token_by_username_password_federated(
404403
if not grant_type:
405404
raise RuntimeError(
406405
"RSTR returned unknown token type: %s", wstrust_result.get("type"))
406+
self.client.grant_assertion_encoders.setdefault( # Register a non-standard type
407+
grant_type, self.client.encode_saml_assertion)
407408
return self.client.obtain_token_by_assertion(
408-
b64encode(wstrust_result["token"]),
409-
grant_type=grant_type, scope=scopes, **kwargs)
409+
wstrust_result["token"], grant_type, scope=scopes, **kwargs)
410410

411411

412412
class ConfidentialClientApplication(ClientApplication): # server-side web app

msal/oauth2cli/assertion.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import jwt
88

99

10-
logger = logging.getLogger(__file__)
10+
logger = logging.getLogger(__name__)
1111

1212
class Signer(object):
1313
def sign_assertion(

msal/oauth2cli/authcode.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from .oauth2 import Client
2222

2323

24-
logger = logging.getLogger(__file__)
24+
logger = logging.getLogger(__name__)
2525

2626
def obtain_auth_code(listen_port, auth_uri=None):
2727
"""This function will start a web server listening on http://localhost:port

msal/oauth2cli/oauth2.py

Lines changed: 31 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import logging
1010
import warnings
1111
import time
12+
import base64
1213

1314
import requests
1415

@@ -18,12 +19,21 @@ class BaseClient(object):
1819
# This low-level interface works. Yet you'll find its sub-class
1920
# more friendly to remind you what parameters are needed in each scenario.
2021
# More on Client Types at https://tools.ietf.org/html/rfc6749#section-2.1
22+
23+
@staticmethod
24+
def encode_saml_assertion(assertion):
25+
return base64.urlsafe_b64encode(assertion).rstrip(b'=') # Per RFC 7522
26+
27+
CLIENT_ASSERTION_TYPE_JWT = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
28+
CLIENT_ASSERTION_TYPE_SAML2 = "urn:ietf:params:oauth:client-assertion-type:saml2-bearer"
29+
client_assertion_encoders = {CLIENT_ASSERTION_TYPE_SAML2: encode_saml_assertion}
30+
2131
def __init__(
2232
self,
2333
server_configuration, # type: dict
2434
client_id, # type: str
2535
client_secret=None, # type: Optional[str]
26-
client_assertion=None, # type: Optional[str]
36+
client_assertion=None, # type: Optional[bytes]
2737
client_assertion_type=None, # type: Optional[str]
2838
default_headers=None, # type: Optional[dict]
2939
default_body=None, # type: Optional[dict]
@@ -45,14 +55,14 @@ def __init__(
4555
https://example.com/.../.well-known/openid-configuration
4656
client_id (str): The client's id, issued by the authorization server
4757
client_secret (str): Triggers HTTP AUTH for Confidential Client
48-
client_assertion (str):
58+
client_assertion (bytes):
4959
The client assertion to authenticate this client, per RFC 7521.
60+
It can be a raw SAML2 assertion (this method will encode it for you),
61+
or a raw JWT assertion.
5062
client_assertion_type (str):
51-
The format of the client_assertion.
52-
If you leave it as the default None, this method will try to make
53-
a guess between SAML2 (RFC 7522) and JWT (RFC 7523),
54-
the only two profiles defined in RFC 7521.
55-
But you can also explicitly provide a value, if needed.
63+
The type of your :attr:`client_assertion` parameter.
64+
It is typically the value of :attr:`CLIENT_ASSERTION_TYPE_SAML2` or
65+
:attr:`CLIENT_ASSERTION_TYPE_JWT`, the only two defined in RFC 7521.
5666
default_headers (dict):
5767
A dict to be sent in each request header.
5868
It is not required by OAuth2 specs, but you may use it for telemetry.
@@ -66,12 +76,10 @@ def __init__(
6676
self.client_id = client_id
6777
self.client_secret = client_secret
6878
self.default_body = default_body or {}
69-
if client_assertion is not None: # See https://tools.ietf.org/html/rfc7521#section-4.2
70-
if client_assertion_type is None: # RFC7521 defines only 2 profiles
71-
TYPE_JWT = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
72-
TYPE_SAML2 = "urn:ietf:params:oauth:client-assertion-type:saml2-bearer"
73-
client_assertion_type = TYPE_JWT if "." in client_assertion else TYPE_SAML2
74-
self.default_body["client_assertion"] = client_assertion
79+
if client_assertion is not None and client_assertion_type is not None:
80+
# See https://tools.ietf.org/html/rfc7521#section-4.2
81+
encoder = self.client_assertion_encoders.get(client_assertion_type, lambda a: a)
82+
self.default_body["client_assertion"] = encoder(client_assertion)
7583
self.default_body["client_assertion_type"] = client_assertion_type
7684
self.logger = logging.getLogger(__name__)
7785
self.session = s = requests.Session()
@@ -172,6 +180,8 @@ class Client(BaseClient): # We choose to implement all 4 grants in 1 class
172180
DEVICE_FLOW_RETRIABLE_ERRORS = ("authorization_pending", "slow_down")
173181
GRANT_TYPE_SAML2 = "urn:ietf:params:oauth:grant-type:saml2-bearer" # RFC7522
174182
GRANT_TYPE_JWT = "urn:ietf:params:oauth:grant-type:jwt-bearer" # RFC7523
183+
grant_assertion_encoders = {GRANT_TYPE_SAML2: BaseClient.encode_saml_assertion}
184+
175185

176186
def initiate_device_flow(self, scope=None, timeout=None, **kwargs):
177187
# type: (list, **dict) -> dict
@@ -409,22 +419,20 @@ def obtain_token_by_refresh_token(self, token_item, scope=None,
409419
raise ValueError("token_item should not be a type %s" % type(token_item))
410420

411421
def obtain_token_by_assertion(
412-
self, assertion, grant_type=None, scope=None, **kwargs):
413-
# type: (str, Union[str, None], Union[str, list, set, tuple]) -> dict
422+
self, assertion, grant_type, scope=None, **kwargs):
423+
# type: (bytes, Union[str, None], Union[str, list, set, tuple]) -> dict
414424
"""This method implements Assertion Framework for OAuth2 (RFC 7521).
415425
See details at https://tools.ietf.org/html/rfc7521#section-4.1
416426
417-
:param assertion: The assertion string which will be sent on wire as-is
427+
:param assertion:
428+
The assertion bytes can be a raw SAML2 assertion, or a JWT assertion.
418429
:param grant_type:
419-
If you leave it as the default None, this method will try to make
420-
a guess between SAML2 (RFC 7522) and JWT (RFC 7523),
421-
the only two profiles defined in RFC 7521.
422-
But you can also explicitly provide a value, if needed.
430+
It is typically either the value of :attr:`GRANT_TYPE_SAML2`,
431+
or :attr:`GRANT_TYPE_JWT`, the only two profiles defined in RFC 7521.
423432
:param scope: Optional. It must be a subset of previously granted scopes.
424433
"""
425-
if grant_type is None:
426-
grant_type = self.GRANT_TYPE_JWT if "." in assertion else self.GRANT_TYPE_SAML2
434+
encoder = self.grant_assertion_encoders.get(grant_type, lambda a: a)
427435
data = kwargs.pop("data", {})
428-
data.update(scope=scope, assertion=assertion)
436+
data.update(scope=scope, assertion=encoder(assertion))
429437
return self._obtain_token(grant_type, data=data, **kwargs)
430438

msal/token_cache.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from .authority import canonicalize
88

99

10+
logger = logging.getLogger(__name__)
11+
1012
def is_subdict_of(small, big):
1113
return dict(big, **small) == big
1214

@@ -46,7 +48,13 @@ def add(self, event):
4648
# type: (dict) -> None
4749
# event typically contains: client_id, scope, token_endpoint,
4850
# resposne, params, data, grant_type
49-
logging.debug("event=%s", json.dumps(event, indent=4))
51+
for sensitive in ("password", "client_secret"):
52+
if sensitive in event.get("data", {}):
53+
# Hide them from accidental exposure in logging
54+
event["data"][sensitive] = "********"
55+
logger.debug("event=%s", json.dumps(event, indent=4, sort_keys=True,
56+
default=str, # A workaround when assertion is in bytes in Python 3
57+
))
5058
response = event.get("response", {})
5159
access_token = response.get("access_token", {})
5260
refresh_token = response.get("refresh_token", {})

msal/wstrust_request.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
from .wstrust_response import parse_response
3737

3838

39-
logger = logging.getLogger(__file__)
39+
logger = logging.getLogger(__name__)
4040

4141
def send_request(
4242
username, password, cloud_audience_urn, endpoint_address, soap_action,

tests/test_client.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ def setUpClass(cls):
9797
audience=CONFIG["openid_configuration"]["token_endpoint"],
9898
issuer=CONFIG["client_id"],
9999
),
100+
client_assertion_type=Client.CLIENT_ASSERTION_TYPE_JWT,
100101
)
101102
else:
102103
cls.client = Client(

0 commit comments

Comments
 (0)