Skip to content

Commit bf0d143

Browse files
authored
Merge pull request #515 from escattone/PKCE-in-session-refresh-middleware
add PKCE to SessionRefresh middleware
2 parents f75ff62 + 8bf691f commit bf0d143

File tree

2 files changed

+151
-1
lines changed

2 files changed

+151
-1
lines changed

mozilla_django_oidc/middleware.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from mozilla_django_oidc.utils import (
1616
absolutify,
1717
add_state_and_verifier_and_nonce_to_session,
18+
generate_code_challenge,
1819
import_from_settings,
1920
)
2021

@@ -152,7 +153,35 @@ def process_request(self, request):
152153
nonce = get_random_string(self.OIDC_NONCE_SIZE)
153154
params.update({"nonce": nonce})
154155

155-
add_state_and_verifier_and_nonce_to_session(request, state, params)
156+
if self.get_settings("OIDC_USE_PKCE", False):
157+
code_verifier_length = self.get_settings("OIDC_PKCE_CODE_VERIFIER_SIZE", 64)
158+
# Check that code_verifier_length is between the min and max length
159+
# defined in https://datatracker.ietf.org/doc/html/rfc7636#section-4.1
160+
if not (43 <= code_verifier_length <= 128):
161+
raise ValueError("code_verifier_length must be between 43 and 128")
162+
163+
# Generate code_verifier and code_challenge pair
164+
code_verifier = get_random_string(code_verifier_length)
165+
code_challenge_method = self.get_settings(
166+
"OIDC_PKCE_CODE_CHALLENGE_METHOD", "S256"
167+
)
168+
code_challenge = generate_code_challenge(
169+
code_verifier, code_challenge_method
170+
)
171+
172+
# Append code_challenge to authentication request parameters
173+
params.update(
174+
{
175+
"code_challenge": code_challenge,
176+
"code_challenge_method": code_challenge_method,
177+
}
178+
)
179+
else:
180+
code_verifier = None
181+
182+
add_state_and_verifier_and_nonce_to_session(
183+
request, state, params, code_verifier
184+
)
156185

157186
request.session["oidc_login_next"] = request.get_full_path()
158187

tests/test_middleware.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,36 @@ def test_is_ajax(self, mock_middleware_random):
7878
json_payload = json.loads(response.content.decode("utf-8"))
7979
self.assertEqual(json_payload["refresh_url"], response["refresh_url"])
8080

81+
@override_settings(OIDC_USE_PKCE=True)
82+
def test_is_ajax_with_pkce(self, mock_middleware_random):
83+
mock_middleware_random.return_value = "examplestring"
84+
85+
request = self.factory.get("/foo", HTTP_X_REQUESTED_WITH="XMLHttpRequest")
86+
request.session = {}
87+
request.user = self.user
88+
89+
response = self.middleware.process_request(request)
90+
self.assertEqual(response.status_code, 403)
91+
# The URL to go to is available both as a header and as a key
92+
# in the JSON response.
93+
self.assertTrue(response["refresh_url"])
94+
url, qs = response["refresh_url"].split("?")
95+
self.assertEqual(url, "http://example.com/authorize")
96+
expected_query = {
97+
"response_type": ["code"],
98+
"redirect_uri": ["http://testserver/callback/"],
99+
"client_id": ["foo"],
100+
"nonce": ["examplestring"],
101+
"prompt": ["none"],
102+
"scope": ["openid email"],
103+
"state": ["examplestring"],
104+
"code_challenge_method": ["S256"],
105+
"code_challenge": ["m8yog7rVNdOd7hYIoUg6yl5mk_IYauWdSIBUjoPJHB0"],
106+
}
107+
self.assertEqual(expected_query, parse_qs(qs))
108+
json_payload = json.loads(response.content.decode("utf-8"))
109+
self.assertEqual(json_payload["refresh_url"], response["refresh_url"])
110+
81111
def test_no_oidc_token_expiration_forces_renewal(self, mock_middleware_random):
82112
mock_middleware_random.return_value = "examplestring"
83113

@@ -101,6 +131,34 @@ def test_no_oidc_token_expiration_forces_renewal(self, mock_middleware_random):
101131
}
102132
self.assertEqual(expected_query, parse_qs(qs))
103133

134+
@override_settings(OIDC_USE_PKCE=True)
135+
def test_no_oidc_token_expiration_forces_renewal_with_pkce(
136+
self, mock_middleware_random
137+
):
138+
mock_middleware_random.return_value = "examplestring"
139+
140+
request = self.factory.get("/foo")
141+
request.user = self.user
142+
request.session = {}
143+
144+
response = self.middleware.process_request(request)
145+
146+
self.assertEqual(response.status_code, 302)
147+
url, qs = response.url.split("?")
148+
self.assertEqual(url, "http://example.com/authorize")
149+
expected_query = {
150+
"response_type": ["code"],
151+
"redirect_uri": ["http://testserver/callback/"],
152+
"client_id": ["foo"],
153+
"nonce": ["examplestring"],
154+
"prompt": ["none"],
155+
"scope": ["openid email"],
156+
"state": ["examplestring"],
157+
"code_challenge_method": ["S256"],
158+
"code_challenge": ["m8yog7rVNdOd7hYIoUg6yl5mk_IYauWdSIBUjoPJHB0"],
159+
}
160+
self.assertEqual(expected_query, parse_qs(qs))
161+
104162
def test_expired_token_forces_renewal(self, mock_middleware_random):
105163
mock_middleware_random.return_value = "examplestring"
106164

@@ -124,6 +182,32 @@ def test_expired_token_forces_renewal(self, mock_middleware_random):
124182
}
125183
self.assertEqual(expected_query, parse_qs(qs))
126184

185+
@override_settings(OIDC_USE_PKCE=True)
186+
def test_expired_token_forces_renewal_with_pkce(self, mock_middleware_random):
187+
mock_middleware_random.return_value = "examplestring"
188+
189+
request = self.factory.get("/foo")
190+
request.user = self.user
191+
request.session = {"oidc_id_token_expiration": time.time() - 10}
192+
193+
response = self.middleware.process_request(request)
194+
195+
self.assertEqual(response.status_code, 302)
196+
url, qs = response.url.split("?")
197+
self.assertEqual(url, "http://example.com/authorize")
198+
expected_query = {
199+
"response_type": ["code"],
200+
"redirect_uri": ["http://testserver/callback/"],
201+
"client_id": ["foo"],
202+
"nonce": ["examplestring"],
203+
"prompt": ["none"],
204+
"scope": ["openid email"],
205+
"state": ["examplestring"],
206+
"code_challenge_method": ["S256"],
207+
"code_challenge": ["m8yog7rVNdOd7hYIoUg6yl5mk_IYauWdSIBUjoPJHB0"],
208+
}
209+
self.assertEqual(expected_query, parse_qs(qs))
210+
127211

128212
# This adds a "home page" we can test against.
129213
def fakeview(req):
@@ -306,6 +390,43 @@ def test_expired_token_redirects_to_sso(self, mock_middleware_random):
306390
}
307391
self.assertEqual(expected_query, parse_qs(qs))
308392

393+
@override_settings(OIDC_OP_AUTHORIZATION_ENDPOINT="http://example.com/authorize")
394+
@override_settings(OIDC_RP_CLIENT_ID="foo")
395+
@override_settings(OIDC_RENEW_ID_TOKEN_EXPIRY_SECONDS=120)
396+
@override_settings(OIDC_USE_PKCE=True)
397+
@patch("mozilla_django_oidc.middleware.get_random_string")
398+
def test_expired_token_redirects_to_sso_with_pkce(self, mock_middleware_random):
399+
mock_middleware_random.return_value = "examplestring"
400+
401+
client = ClientWithUser()
402+
client.login(username=self.user.username, password="password")
403+
404+
# Set expiration to some time in the past
405+
session = client.session
406+
session["oidc_id_token_expiration"] = time.time() - 100
407+
session[
408+
"_auth_user_backend"
409+
] = "mozilla_django_oidc.auth.OIDCAuthenticationBackend"
410+
session.save()
411+
412+
resp = client.get("/mdo_fake_view/")
413+
self.assertEqual(resp.status_code, 302)
414+
415+
url, qs = resp.url.split("?")
416+
self.assertEqual(url, "http://example.com/authorize")
417+
expected_query = {
418+
"response_type": ["code"],
419+
"redirect_uri": ["http://testserver/callback/"],
420+
"client_id": ["foo"],
421+
"nonce": ["examplestring"],
422+
"prompt": ["none"],
423+
"scope": ["openid email"],
424+
"state": ["examplestring"],
425+
"code_challenge_method": ["S256"],
426+
"code_challenge": ["m8yog7rVNdOd7hYIoUg6yl5mk_IYauWdSIBUjoPJHB0"],
427+
}
428+
self.assertEqual(expected_query, parse_qs(qs))
429+
309430
@override_settings(OIDC_OP_AUTHORIZATION_ENDPOINT="http://example.com/authorize")
310431
@override_settings(OIDC_RP_CLIENT_ID="foo")
311432
@override_settings(OIDC_RENEW_ID_TOKEN_EXPIRY_SECONDS=120)

0 commit comments

Comments
 (0)