Skip to content

Commit 908c84b

Browse files
committed
Initial changes
1 parent d3464e6 commit 908c84b

File tree

4 files changed

+313
-20
lines changed

4 files changed

+313
-20
lines changed

msal/oauth2cli/authcode.py

Lines changed: 46 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -112,26 +112,56 @@ def do_GET(self):
112112
# For flexibility, we choose to not check self.path matching redirect_uri
113113
#assert self.path.startswith('/THE_PATH_REGISTERED_BY_THE_APP')
114114
qs = parse_qs(urlparse(self.path).query)
115-
if qs.get('code') or qs.get("error"): # So, it is an auth response
116-
auth_response = _qs2kv(qs)
117-
logger.debug("Got auth response: %s", auth_response)
118-
if self.server.auth_state and self.server.auth_state != auth_response.get("state"):
119-
# OAuth2 successful and error responses contain state when it was used
120-
# https://www.rfc-editor.org/rfc/rfc6749#section-4.2.2.1
121-
self._send_full_response("State mismatch") # Possibly an attack
122-
else:
123-
template = (self.server.success_template
124-
if "code" in qs else self.server.error_template)
125-
if _is_html(template.template):
126-
safe_data = _escape(auth_response) # Foiling an XSS attack
127-
else:
128-
safe_data = auth_response
129-
self._send_full_response(template.safe_substitute(**safe_data))
130-
self.server.auth_response = auth_response # Set it now, after the response is likely sent
115+
if qs.get('code') or qs.get("error"): # Auth response via query string is not allowed
116+
logger.error("Received auth response via query string (GET request). "
117+
"This is a security risk. Only form_post (POST) is supported.")
118+
self._send_full_response(
119+
"Authentication method not supported. "
120+
"The application requires response_mode=form_post for security. "
121+
"Please ensure your application registration uses form_post response mode.",
122+
is_ok=False)
131123
else:
132124
self._send_full_response(self.server.welcome_page)
133125
# NOTE: Don't do self.server.shutdown() here. It'll halt the server.
134126

127+
def do_POST(self):
128+
# Handle form_post response mode where auth code is sent via POST body
129+
content_length = int(self.headers.get('Content-Length', 0))
130+
post_data = self.rfile.read(content_length).decode('utf-8')
131+
try:
132+
from urllib.parse import parse_qs as parse_qs_post
133+
except ImportError:
134+
from urlparse import parse_qs as parse_qs_post
135+
136+
qs = parse_qs_post(post_data)
137+
if qs.get('code') or qs.get('error'): # So, it is an auth response
138+
auth_response = _qs2kv(qs)
139+
logger.debug("Got auth response via POST: %s", auth_response)
140+
self._process_auth_response(auth_response)
141+
else:
142+
self._send_full_response("Invalid POST request", is_ok=False)
143+
144+
def _process_auth_response(self, auth_response):
145+
"""Process the auth response from either GET or POST request."""
146+
if self.server.auth_state and self.server.auth_state != auth_response.get("state"):
147+
# OAuth2 successful and error responses contain state when it was used
148+
# https://www.rfc-editor.org/rfc/rfc6749#section-4.2.2.1
149+
self._send_full_response("State mismatch") # Possibly an attack
150+
else:
151+
template = (self.server.success_template
152+
if "code" in auth_response else self.server.error_template)
153+
if _is_html(template.template):
154+
safe_data = _escape(auth_response) # Foiling an XSS attack
155+
else:
156+
safe_data = dict(auth_response) # Make a copy to avoid mutating original
157+
# Provide default values for common OAuth2 response fields
158+
# to avoid showing literal placeholder text like "$error_description"
159+
safe_data.setdefault("error", "")
160+
safe_data.setdefault("error_description", "")
161+
safe_data.setdefault("error_uri", "")
162+
self._send_full_response(template.safe_substitute(**safe_data))
163+
self.server.auth_response = auth_response # Set it now, after the response is likely sent
164+
135165
def _send_full_response(self, body, is_ok=True):
136166
self.send_response(200 if is_ok else 400)
137167
content_type = 'text/html' if _is_html(body) else 'text/plain'

msal/oauth2cli/oauth2.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,16 @@ def _build_auth_request_params(self, response_type, **kwargs):
176176
response_type = self._stringify(response_type)
177177

178178
params = {'client_id': self.client_id, 'response_type': response_type}
179-
params.update(kwargs) # Note: None values will override params
179+
# Strictly enforce form_post for security - query string is not allowed
180+
params['response_mode'] = 'form_post'
181+
if 'response_mode' in kwargs and kwargs['response_mode'] != 'form_post':
182+
import warnings
183+
warnings.warn(
184+
"response_mode='{}' is not supported for security reasons. "
185+
"Using form_post instead. Query string transmission of authorization "
186+
"codes is insecure and has been disabled.".format(kwargs['response_mode']),
187+
UserWarning)
188+
params.update({k: v for k, v in kwargs.items() if k != 'response_mode'}) # Exclude response_mode from kwargs
180189
params = {k: v for k, v in params.items() if v is not None} # clean up
181190
if params.get('scope'):
182191
params['scope'] = self._stringify(params['scope'])

tests/test_authcode.py

Lines changed: 145 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,17 +26,159 @@ def test_no_two_concurrent_receivers_can_listen_on_same_port(self):
2626
pass
2727

2828
def test_template_should_escape_input(self):
29+
"""Test that POST request with HTML in error is properly escaped"""
2930
with AuthCodeReceiver() as receiver:
3031
receiver._scheduled_actions = [( # Injection happens here when the port is known
3132
1, # Delay it until the receiver is activated by get_auth_response()
3233
lambda: self.assertEqual(
3334
"<html>&lt;tag&gt;foo&lt;/tag&gt;</html>",
34-
requests.get("http://localhost:{}?error=<tag>foo</tag>".format(
35-
receiver.get_port())).text,
36-
"Unsafe data in HTML should be escaped",
35+
requests.post(
36+
"http://localhost:{}".format(receiver.get_port()),
37+
data={"error": "<tag>foo</tag>"},
38+
headers={'Content-Type': 'application/x-www-form-urlencoded'}
39+
).text,
3740
))]
3841
receiver.get_auth_response( # Starts server and hang until timeout
3942
timeout=3,
4043
error_template="<html>$error</html>",
4144
)
4245

46+
def test_get_request_with_auth_code_is_rejected(self):
47+
"""Test that GET request with auth code is rejected for security"""
48+
with AuthCodeReceiver() as receiver:
49+
test_code = "test_auth_code_12345"
50+
test_state = "test_state_67890"
51+
receiver._scheduled_actions = [(
52+
1,
53+
lambda: self._verify_get_rejection(
54+
receiver.get_port(),
55+
code=test_code,
56+
state=test_state
57+
)
58+
)]
59+
result = receiver.get_auth_response(
60+
timeout=3,
61+
state=test_state,
62+
)
63+
# Should not receive auth response via GET
64+
self.assertIsNone(result)
65+
66+
def _verify_get_rejection(self, port, **params):
67+
"""Helper to verify GET requests with auth codes are rejected"""
68+
try:
69+
from urllib.parse import urlencode
70+
except ImportError:
71+
from urllib import urlencode
72+
73+
response = requests.get(
74+
"http://localhost:{}?{}".format(port, urlencode(params))
75+
)
76+
# Verify error message about unsupported method
77+
self.assertIn("not supported", response.text.lower())
78+
self.assertEqual(response.status_code, 400)
79+
80+
def test_post_request_with_auth_code(self):
81+
"""Test that POST request with auth code is handled correctly (form_post response mode)"""
82+
with AuthCodeReceiver() as receiver:
83+
test_code = "test_auth_code_12345"
84+
test_state = "test_state_67890"
85+
receiver._scheduled_actions = [(
86+
1,
87+
lambda: self._send_post_auth_response(
88+
receiver.get_port(),
89+
code=test_code,
90+
state=test_state
91+
)
92+
)]
93+
result = receiver.get_auth_response(
94+
timeout=3,
95+
state=test_state,
96+
success_template="Success: Got code $code",
97+
)
98+
self.assertIsNotNone(result)
99+
self.assertEqual(result.get("code"), test_code)
100+
self.assertEqual(result.get("state"), test_state)
101+
102+
def test_post_request_with_error(self):
103+
"""Test that POST request with error is handled correctly"""
104+
with AuthCodeReceiver() as receiver:
105+
test_error = "access_denied"
106+
test_error_description = "User denied access"
107+
receiver._scheduled_actions = [(
108+
1,
109+
lambda: self._send_post_auth_response(
110+
receiver.get_port(),
111+
error=test_error,
112+
error_description=test_error_description
113+
)
114+
)]
115+
result = receiver.get_auth_response(
116+
timeout=3,
117+
error_template="Error: $error - $error_description",
118+
)
119+
self.assertIsNotNone(result)
120+
self.assertEqual(result.get("error"), test_error)
121+
self.assertEqual(result.get("error_description"), test_error_description)
122+
123+
def test_post_request_state_mismatch(self):
124+
"""Test that POST request with mismatched state is rejected"""
125+
with AuthCodeReceiver() as receiver:
126+
test_code = "test_auth_code_12345"
127+
wrong_state = "wrong_state"
128+
expected_state = "expected_state"
129+
receiver._scheduled_actions = [(
130+
1,
131+
lambda: self._send_post_auth_response(
132+
receiver.get_port(),
133+
code=test_code,
134+
state=wrong_state
135+
)
136+
)]
137+
result = receiver.get_auth_response(
138+
timeout=3,
139+
state=expected_state,
140+
)
141+
# When state mismatches, the response should not be set
142+
self.assertIsNone(result)
143+
144+
def test_post_request_escapes_html(self):
145+
"""Test that POST request with HTML in error is properly escaped"""
146+
with AuthCodeReceiver() as receiver:
147+
malicious_error = "<script>alert('xss')</script>"
148+
receiver._scheduled_actions = [(
149+
1,
150+
lambda: self._verify_post_escaping(receiver.get_port(), malicious_error)
151+
)]
152+
receiver.get_auth_response(
153+
timeout=3,
154+
error_template="<html>$error</html>",
155+
)
156+
157+
def _send_post_auth_response(self, port, **params):
158+
"""Helper to send POST request with auth response"""
159+
try:
160+
from urllib.parse import urlencode
161+
except ImportError:
162+
from urllib import urlencode
163+
164+
response = requests.post(
165+
"http://localhost:{}".format(port),
166+
data=params,
167+
headers={'Content-Type': 'application/x-www-form-urlencoded'}
168+
)
169+
return response
170+
171+
def _verify_post_escaping(self, port, malicious_error):
172+
"""Helper to verify HTML escaping in POST requests"""
173+
response = self._send_post_auth_response(port, error=malicious_error)
174+
# Verify that the malicious script is escaped
175+
self.assertIn("&lt;script&gt;", response.text)
176+
self.assertNotIn("<script>", response.text)
177+
# Note: The escape function also escapes single quotes to &#x27;
178+
expected = "<html>&lt;script&gt;alert(&#x27;xss&#x27;)&lt;/script&gt;</html>"
179+
self.assertEqual(
180+
expected,
181+
response.text,
182+
"HTML in POST data should be escaped to prevent XSS"
183+
)
184+

tests/test_response_mode.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
"""Tests for form_post response mode in authorization code flow"""
2+
import unittest
3+
import warnings
4+
try:
5+
from urllib.parse import urlparse, parse_qs
6+
except ImportError:
7+
from urlparse import urlparse, parse_qs
8+
9+
from msal.oauth2cli import Client
10+
11+
12+
class TestResponseMode(unittest.TestCase):
13+
"""Test response_mode parameter in authorization code flow"""
14+
15+
def setUp(self):
16+
self.client = Client(
17+
{
18+
"authorization_endpoint": "https://example.com/authorize",
19+
"token_endpoint": "https://example.com/token"
20+
},
21+
"test_client_id"
22+
)
23+
24+
def test_default_response_mode_is_form_post(self):
25+
"""Test that response_mode defaults to form_post for security"""
26+
flow = self.client.initiate_auth_code_flow(
27+
scope=["openid", "profile"],
28+
redirect_uri="http://localhost:8080"
29+
)
30+
31+
# Parse the auth_uri to check query parameters
32+
parsed = urlparse(flow["auth_uri"])
33+
params = parse_qs(parsed.query)
34+
35+
# Verify response_mode is set to form_post
36+
self.assertIn("response_mode", params)
37+
self.assertEqual(params["response_mode"][0], "form_post")
38+
39+
def test_explicit_query_mode_shows_security_warning(self):
40+
"""Test that attempting to use query mode raises a security warning"""
41+
with warnings.catch_warnings(record=True) as w:
42+
warnings.simplefilter("always")
43+
flow = self.client.initiate_auth_code_flow(
44+
scope=["openid", "profile"],
45+
redirect_uri="http://localhost:8080",
46+
response_mode="query"
47+
)
48+
49+
# Verify a warning was raised
50+
self.assertEqual(len(w), 1)
51+
self.assertTrue(issubclass(w[0].category, Warning))
52+
self.assertIn("security", str(w[0].message).lower())
53+
54+
# Verify form_post is still enforced despite explicit query request
55+
parsed = urlparse(flow["auth_uri"])
56+
params = parse_qs(parsed.query)
57+
self.assertEqual(params["response_mode"][0], "form_post")
58+
59+
def test_explicit_fragment_mode_shows_security_warning(self):
60+
"""Test that attempting to use fragment mode raises a security warning"""
61+
with warnings.catch_warnings(record=True) as w:
62+
warnings.simplefilter("always")
63+
flow = self.client.initiate_auth_code_flow(
64+
scope=["openid", "profile"],
65+
redirect_uri="http://localhost:8080",
66+
response_mode="fragment"
67+
)
68+
69+
# Verify a warning was raised
70+
self.assertEqual(len(w), 1)
71+
self.assertIn("fragment", str(w[0].message))
72+
73+
# Verify form_post is still enforced
74+
parsed = urlparse(flow["auth_uri"])
75+
params = parse_qs(parsed.query)
76+
self.assertEqual(params["response_mode"][0], "form_post")
77+
78+
def test_build_auth_request_params_enforces_form_post(self):
79+
"""Test that _build_auth_request_params enforces form_post"""
80+
params = self.client._build_auth_request_params(
81+
response_type="code",
82+
redirect_uri="http://localhost:8080",
83+
scope=["openid", "profile"],
84+
state="test_state"
85+
)
86+
87+
# Verify response_mode is form_post
88+
self.assertIn("response_mode", params)
89+
self.assertEqual(params["response_mode"], "form_post")
90+
91+
def test_build_auth_request_params_ignores_explicit_query_mode(self):
92+
"""Test that _build_auth_request_params ignores explicit query mode"""
93+
with warnings.catch_warnings(record=True) as w:
94+
warnings.simplefilter("always")
95+
params = self.client._build_auth_request_params(
96+
response_type="code",
97+
redirect_uri="http://localhost:8080",
98+
scope=["openid", "profile"],
99+
state="test_state",
100+
response_mode="query"
101+
)
102+
103+
# Verify warning was raised
104+
self.assertGreater(len(w), 0)
105+
106+
# Verify form_post is enforced despite explicit request
107+
self.assertIn("response_mode", params)
108+
self.assertEqual(params["response_mode"], "form_post")
109+
110+
111+
if __name__ == '__main__':
112+
unittest.main()

0 commit comments

Comments
 (0)