1
+ import json
1
2
from urllib .parse import parse_qs , urlparse
2
3
3
4
import pytest
6
7
from django .urls import reverse
7
8
8
9
from oauth2_provider .models import get_application_model
9
- from oauth2_provider .oauth2_validators import OAuth2Validator
10
10
11
11
from . import presets
12
12
from .utils import get_basic_auth_header
13
13
14
14
15
- class CorsOAuth2Validator (OAuth2Validator ):
16
- def is_origin_allowed (self , client_id , origin , request , * args , ** kwargs ):
17
- """Enable CORS in OAuthLib"""
18
- return True
19
-
20
-
21
15
Application = get_application_model ()
22
16
UserModel = get_user_model ()
23
17
@@ -50,10 +44,10 @@ def setUp(self):
50
44
client_type = Application .CLIENT_CONFIDENTIAL ,
51
45
authorization_grant_type = Application .GRANT_AUTHORIZATION_CODE ,
52
46
client_secret = CLEARTEXT_SECRET ,
47
+ allowed_origins = CLIENT_URI ,
53
48
)
54
49
55
50
self .oauth2_settings .ALLOWED_REDIRECT_URI_SCHEMES = ["https" ]
56
- self .oauth2_settings .OAUTH2_VALIDATOR_CLASS = CorsOAuth2Validator
57
51
58
52
def tearDown (self ):
59
53
self .application .delete ()
@@ -76,10 +70,42 @@ def test_cors_header(self):
76
70
auth_headers = get_basic_auth_header (self .application .client_id , CLEARTEXT_SECRET )
77
71
auth_headers ["HTTP_ORIGIN" ] = CLIENT_URI
78
72
response = self .client .post (reverse ("oauth2_provider:token" ), data = token_request_data , ** auth_headers )
73
+
74
+ content = json .loads (response .content .decode ("utf-8" ))
75
+
76
+ self .assertEqual (response .status_code , 200 )
77
+ self .assertEqual (response ["Access-Control-Allow-Origin" ], CLIENT_URI )
78
+
79
+ token_request_data = {
80
+ "grant_type" : "refresh_token" ,
81
+ "refresh_token" : content ["refresh_token" ],
82
+ "scope" : content ["scope" ],
83
+ }
84
+ response = self .client .post (reverse ("oauth2_provider:token" ), data = token_request_data , ** auth_headers )
79
85
self .assertEqual (response .status_code , 200 )
80
86
self .assertEqual (response ["Access-Control-Allow-Origin" ], CLIENT_URI )
81
87
82
- def test_no_cors_header (self ):
88
+ def test_no_cors_header_origin_not_allowed (self ):
89
+ """
90
+ Test that /token endpoint does not have Access-Control-Allow-Origin
91
+ when request origin is not in Application.allowed_origins
92
+ """
93
+ authorization_code = self ._get_authorization_code ()
94
+
95
+ # exchange authorization code for a valid access token
96
+ token_request_data = {
97
+ "grant_type" : "authorization_code" ,
98
+ "code" : authorization_code ,
99
+ "redirect_uri" : CLIENT_URI ,
100
+ }
101
+
102
+ auth_headers = get_basic_auth_header (self .application .client_id , CLEARTEXT_SECRET )
103
+ auth_headers ["HTTP_ORIGIN" ] = "another_example.org"
104
+ response = self .client .post (reverse ("oauth2_provider:token" ), data = token_request_data , ** auth_headers )
105
+ self .assertEqual (response .status_code , 200 )
106
+ self .assertFalse (response .has_header ("Access-Control-Allow-Origin" ))
107
+
108
+ def test_no_cors_header_no_origin (self ):
83
109
"""
84
110
Test that /token endpoint does not have Access-Control-Allow-Origin
85
111
"""
0 commit comments