Skip to content

Commit dca2001

Browse files
committed
Add test for session state on authorization view
1 parent f587442 commit dca2001

File tree

6 files changed

+58
-54
lines changed

6 files changed

+58
-54
lines changed

oauth2_provider/settings.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -173,12 +173,7 @@ def import_from_string(val, setting_name):
173173
try:
174174
return import_string(val)
175175
except ImportError as e:
176-
msg = "Could not import %r for setting %r. %s: %s." % (
177-
val,
178-
setting_name,
179-
e.__class__.__name__,
180-
e,
181-
)
176+
msg = "Could not import %r for setting %r. %s: %s." % (val, setting_name, e.__class__.__name__, e)
182177
raise ImportError(msg)
183178

184179

oauth2_provider/urls.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,7 @@
1717
management_urlpatterns = [
1818
# Application management views
1919
path("applications/", views.ApplicationList.as_view(), name="list"),
20-
path(
21-
"applications/register/",
22-
views.ApplicationRegistration.as_view(),
23-
name="register",
24-
),
20+
path("applications/register/", views.ApplicationRegistration.as_view(), name="register"),
2521
path("applications/<slug:pk>/", views.ApplicationDetail.as_view(), name="detail"),
2622
path("applications/<slug:pk>/delete/", views.ApplicationDelete.as_view(), name="delete"),
2723
path("applications/<slug:pk>/update/", views.ApplicationUpdate.as_view(), name="update"),

oauth2_provider/views/base.py

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -137,10 +137,7 @@ def form_valid(self, form):
137137

138138
try:
139139
uri, headers, body, status = self.create_authorization_response(
140-
request=self.request,
141-
scopes=scopes,
142-
credentials=credentials,
143-
allow=allow,
140+
request=self.request, scopes=scopes, credentials=credentials, allow=allow
144141
)
145142
except OAuthToolkitError as error:
146143
return self.error_response(error, application)
@@ -160,7 +157,7 @@ def form_valid(self, form):
160157
salt = secrets.token_urlsafe(16)
161158
encoded = " ".join(
162159
[
163-
self.client.client_id,
160+
credentials["client_id"],
164161
client_origin,
165162
session_management_state_key(self.request),
166163
salt,
@@ -231,20 +228,15 @@ def get(self, request, *args, **kwargs):
231228
# are already approved.
232229
if application.skip_authorization:
233230
uri, headers, body, status = self.create_authorization_response(
234-
request=self.request,
235-
scopes=" ".join(scopes),
236-
credentials=credentials,
237-
allow=True,
231+
request=self.request, scopes=" ".join(scopes), credentials=credentials, allow=True
238232
)
239233
return self.redirect(uri, application)
240234

241235
elif require_approval == "auto":
242236
tokens = (
243237
get_access_token_model()
244238
.objects.filter(
245-
user=request.user,
246-
application=kwargs["application"],
247-
expires__gt=timezone.now(),
239+
user=request.user, application=kwargs["application"], expires__gt=timezone.now()
248240
)
249241
.all()
250242
)
@@ -253,10 +245,7 @@ def get(self, request, *args, **kwargs):
253245
for token in tokens:
254246
if token.allow_scopes(scopes):
255247
uri, headers, body, status = self.create_authorization_response(
256-
request=self.request,
257-
scopes=" ".join(scopes),
258-
credentials=credentials,
259-
allow=True,
248+
request=self.request, scopes=" ".join(scopes), credentials=credentials, allow=True
260249
)
261250
return self.redirect(uri, application)
262251

oauth2_provider/views/oidc.py

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -83,10 +83,7 @@ def get(self, request, *args, **kwargs):
8383

8484
signing_algorithms = [Application.HS256_ALGORITHM]
8585
if oauth2_settings.OIDC_RSA_PRIVATE_KEY:
86-
signing_algorithms = [
87-
Application.RS256_ALGORITHM,
88-
Application.HS256_ALGORITHM,
89-
]
86+
signing_algorithms = [Application.RS256_ALGORITHM, Application.HS256_ALGORITHM]
9087

9188
validator_class = oauth2_settings.OAUTH2_VALIDATOR_CLASS
9289
validator = validator_class()
@@ -251,10 +248,7 @@ class RPInitiatedLogoutView(OIDCLogoutOnlyMixin, FormView):
251248
form_class = ConfirmLogoutForm
252249
# Only delete tokens for Application whose client type and authorization
253250
# grant type are in the respective lists.
254-
token_deletion_client_types = [
255-
Application.CLIENT_PUBLIC,
256-
Application.CLIENT_CONFIDENTIAL,
257-
]
251+
token_deletion_client_types = [Application.CLIENT_PUBLIC, Application.CLIENT_CONFIDENTIAL]
258252
token_deletion_grant_types = [
259253
Application.GRANT_AUTHORIZATION_CODE,
260254
Application.GRANT_IMPLICIT,
@@ -458,13 +452,7 @@ def must_prompt(self, token_user):
458452
""" We didn't find a reason to prompt the user """
459453
return False
460454

461-
def do_logout(
462-
self,
463-
application=None,
464-
post_logout_redirect_uri=None,
465-
state=None,
466-
token_user=None,
467-
):
455+
def do_logout(self, application=None, post_logout_redirect_uri=None, state=None, token_user=None):
468456
user = token_user or self.request.user
469457
# Delete Access Tokens if a user was found
470458
if oauth2_settings.OIDC_RP_INITIATED_LOGOUT_DELETE_TOKENS and not isinstance(user, AnonymousUser):
@@ -501,8 +489,7 @@ def do_logout(
501489
return OAuth2ResponseRedirect(post_logout_redirect_uri, application.get_allowed_schemes())
502490
else:
503491
return OAuth2ResponseRedirect(
504-
self.request.build_absolute_uri("/"),
505-
oauth2_settings.ALLOWED_REDIRECT_URI_SCHEMES,
492+
self.request.build_absolute_uri("/"), oauth2_settings.ALLOWED_REDIRECT_URI_SCHEMES
506493
)
507494

508495
def error_response(self, error):

tests/presets.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@
3737
OIDC_SETTINGS_RP_LOGOUT_DENY_EXPIRED["OIDC_RP_INITIATED_LOGOUT_ACCEPT_EXPIRED_TOKENS"] = False
3838
OIDC_SETTINGS_RP_LOGOUT_KEEP_TOKENS = deepcopy(OIDC_SETTINGS_RP_LOGOUT)
3939
OIDC_SETTINGS_RP_LOGOUT_KEEP_TOKENS["OIDC_RP_INITIATED_LOGOUT_DELETE_TOKENS"] = False
40+
OIDC_SETTINGS_SESSION_MANAGEMENT = deepcopy(OIDC_SETTINGS_RW)
41+
OIDC_SETTINGS_SESSION_MANAGEMENT["OIDC_SESSION_MANAGEMENT_ENABLED"] = True
4042
REST_FRAMEWORK_SCOPES = {
4143
"SCOPES": {
4244
"read": "Read scope",

tests/test_oidc_views.py

Lines changed: 45 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import pytest
2-
from django.contrib.auth import get_user
2+
from django.contrib.auth import get_user, get_user_model
33
from django.contrib.auth.models import AnonymousUser
44
from django.test import RequestFactory
55
from django.urls import reverse
@@ -12,7 +12,12 @@
1212
InvalidOIDCClientError,
1313
InvalidOIDCRedirectURIError,
1414
)
15-
from oauth2_provider.models import get_access_token_model, get_id_token_model, get_refresh_token_model
15+
from oauth2_provider.models import (
16+
get_access_token_model,
17+
get_application_model,
18+
get_id_token_model,
19+
get_refresh_token_model,
20+
)
1621
from oauth2_provider.oauth2_validators import OAuth2Validator
1722
from oauth2_provider.settings import oauth2_settings
1823
from oauth2_provider.views.oidc import RPInitiatedLogoutView, _load_id_token, _validate_claims
@@ -206,6 +211,42 @@ def test_get_jwks_info_multiple_rsa_keys(self):
206211
assert response.json() == expected_response
207212

208213

214+
@pytest.mark.usefixtures("oauth2_settings")
215+
@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_SESSION_MANAGEMENT)
216+
class TestAuthorizationView(TestCase):
217+
def test_session_state_is_present_in_url(self):
218+
User = get_user_model()
219+
Application = get_application_model()
220+
221+
User.objects.create_user("test_user", "[email protected]", "123456")
222+
dev_user = User.objects.create_user("dev_user", "[email protected]", "123456")
223+
224+
application = Application.objects.create(
225+
name="Test Application",
226+
redirect_uris=(
227+
"http://localhost http://example.com http://example.org custom-scheme://example.com"
228+
),
229+
user=dev_user,
230+
client_type=Application.CLIENT_CONFIDENTIAL,
231+
authorization_grant_type=Application.GRANT_AUTHORIZATION_CODE,
232+
client_secret="1234567890qwertyuiop",
233+
)
234+
self.client.login(username="test_user", password="123456")
235+
response = self.client.post(
236+
reverse("oauth2_provider:authorize"),
237+
{
238+
"client_id": application.client_id,
239+
"response_type": "code",
240+
"state": "random_state_string",
241+
"scope": "read write",
242+
"redirect_uri": "http://example.org",
243+
"allow": True,
244+
},
245+
)
246+
self.assertEqual(response.status_code, 302)
247+
self.assertTrue("session_state" in response["Location"])
248+
249+
209250
def mock_request():
210251
"""
211252
Dummy request with an AnonymousUser attached.
@@ -467,10 +508,7 @@ def test_rp_initiated_logout_expired_tokens_accept(logged_in_client, application
467508
# Accepting expired (but otherwise valid and signed by us) tokens is enabled. Logout should go through.
468509
rsp = logged_in_client.get(
469510
reverse("oauth2_provider:rp-initiated-logout"),
470-
data={
471-
"id_token_hint": expired_id_token,
472-
"client_id": application.client_id,
473-
},
511+
data={"id_token_hint": expired_id_token, "client_id": application.client_id},
474512
)
475513
assert rsp.status_code == 302
476514
assert not is_logged_in(logged_in_client)
@@ -482,10 +520,7 @@ def test_rp_initiated_logout_expired_tokens_deny(logged_in_client, application,
482520
# Expired tokens should not be accepted by default.
483521
rsp = logged_in_client.get(
484522
reverse("oauth2_provider:rp-initiated-logout"),
485-
data={
486-
"id_token_hint": expired_id_token,
487-
"client_id": application.client_id,
488-
},
523+
data={"id_token_hint": expired_id_token, "client_id": application.client_id},
489524
)
490525
assert rsp.status_code == 400
491526
assert is_logged_in(logged_in_client)

0 commit comments

Comments
 (0)