@@ -56,6 +56,7 @@ class OAuth2Core:
56
56
_oauth_client : Optional [WebApplicationClient ] = None
57
57
_authorization_endpoint : str = None
58
58
_token_endpoint : str = None
59
+ _state : str = None
59
60
60
61
def __init__ (self , client : OAuth2Client ) -> None :
61
62
self .client_id = client .client_id
@@ -83,6 +84,8 @@ def authorization_url(self, request: Request) -> str:
83
84
oauth2_query_params = dict (state = state , scope = self .scope , redirect_uri = redirect_uri )
84
85
oauth2_query_params .update (request .query_params )
85
86
87
+ self ._state = oauth2_query_params .get ("state" )
88
+
86
89
return str (self ._oauth_client .prepare_request_uri (
87
90
self ._authorization_endpoint ,
88
91
** oauth2_query_params ,
@@ -96,6 +99,8 @@ async def token_data(self, request: Request, **httpx_client_args) -> dict:
96
99
raise OAuth2LoginError (400 , "'code' parameter was not found in callback request" )
97
100
if not request .query_params .get ("state" ):
98
101
raise OAuth2LoginError (400 , "'state' parameter was not found in callback request" )
102
+ if request .query_params .get ("state" ) != self ._state :
103
+ raise OAuth2LoginError (400 , "'state' parameter does not match" )
99
104
100
105
redirect_uri = self .get_redirect_uri (request )
101
106
scheme = "http" if request .auth .http else "https"
0 commit comments