1+ import json
12from urllib .parse import parse_qs , urlparse
23
34import pytest
67from django .urls import reverse
78
89from oauth2_provider .models import get_application_model
9- from oauth2_provider .oauth2_validators import OAuth2Validator
1010
1111from . import presets
1212from .utils import get_basic_auth_header
1313
1414
15- class CorsOAuth2Validator (OAuth2Validator ):
16- def is_origin_allowed (self , client_id , origin , request , * args , ** kwargs ):
17- """Enable CORS in OAuthLib"""
18- return True
19-
20-
2115Application = get_application_model ()
2216UserModel = get_user_model ()
2317
@@ -50,10 +44,10 @@ def setUp(self):
5044 client_type = Application .CLIENT_CONFIDENTIAL ,
5145 authorization_grant_type = Application .GRANT_AUTHORIZATION_CODE ,
5246 client_secret = CLEARTEXT_SECRET ,
47+ allowed_origins = CLIENT_URI ,
5348 )
5449
5550 self .oauth2_settings .ALLOWED_REDIRECT_URI_SCHEMES = ["https" ]
56- self .oauth2_settings .OAUTH2_VALIDATOR_CLASS = CorsOAuth2Validator
5751
5852 def tearDown (self ):
5953 self .application .delete ()
@@ -76,10 +70,42 @@ def test_cors_header(self):
7670 auth_headers = get_basic_auth_header (self .application .client_id , CLEARTEXT_SECRET )
7771 auth_headers ["HTTP_ORIGIN" ] = CLIENT_URI
7872 response = self .client .post (reverse ("oauth2_provider:token" ), data = token_request_data , ** auth_headers )
73+
74+ content = json .loads (response .content .decode ("utf-8" ))
75+
76+ self .assertEqual (response .status_code , 200 )
77+ self .assertEqual (response ["Access-Control-Allow-Origin" ], CLIENT_URI )
78+
79+ token_request_data = {
80+ "grant_type" : "refresh_token" ,
81+ "refresh_token" : content ["refresh_token" ],
82+ "scope" : content ["scope" ],
83+ }
84+ response = self .client .post (reverse ("oauth2_provider:token" ), data = token_request_data , ** auth_headers )
7985 self .assertEqual (response .status_code , 200 )
8086 self .assertEqual (response ["Access-Control-Allow-Origin" ], CLIENT_URI )
8187
82- def test_no_cors_header (self ):
88+ def test_no_cors_header_origin_not_allowed (self ):
89+ """
90+ Test that /token endpoint does not have Access-Control-Allow-Origin
91+ when request origin is not in Application.allowed_origins
92+ """
93+ authorization_code = self ._get_authorization_code ()
94+
95+ # exchange authorization code for a valid access token
96+ token_request_data = {
97+ "grant_type" : "authorization_code" ,
98+ "code" : authorization_code ,
99+ "redirect_uri" : CLIENT_URI ,
100+ }
101+
102+ auth_headers = get_basic_auth_header (self .application .client_id , CLEARTEXT_SECRET )
103+ auth_headers ["HTTP_ORIGIN" ] = "another_example.org"
104+ response = self .client .post (reverse ("oauth2_provider:token" ), data = token_request_data , ** auth_headers )
105+ self .assertEqual (response .status_code , 200 )
106+ self .assertFalse (response .has_header ("Access-Control-Allow-Origin" ))
107+
108+ def test_no_cors_header_no_origin (self ):
83109 """
84110 Test that /token endpoint does not have Access-Control-Allow-Origin
85111 """
0 commit comments