9
9
from urllib .parse import urljoin
10
10
11
11
import httpx
12
+ from oauthlib .oauth2 import OAuth2Error
12
13
from oauthlib .oauth2 import WebApplicationClient
13
- from oauthlib .oauth2 .rfc6749 .errors import CustomOAuth2Error
14
14
from social_core .backends .oauth import BaseOAuth2
15
+ from social_core .exceptions import AuthException
15
16
from social_core .strategy import BaseStrategy
16
17
from starlette .requests import Request
17
18
from starlette .responses import RedirectResponse
18
19
19
20
from .claims import Claims
20
21
from .client import OAuth2Client
21
- from .exceptions import OAuth2LoginError
22
+ from .exceptions import OAuth2AuthenticationError
23
+ from .exceptions import OAuth2BadCredentialsError
24
+ from .exceptions import OAuth2InvalidRequestError
22
25
23
26
24
27
class OAuth2Strategy (BaseStrategy ):
@@ -92,11 +95,11 @@ def authorization_redirect(self, request: Request) -> RedirectResponse:
92
95
93
96
async def token_data (self , request : Request , ** httpx_client_args ) -> dict :
94
97
if not request .query_params .get ("code" ):
95
- raise OAuth2LoginError (400 , "'code' parameter was not found in callback request" )
98
+ raise OAuth2InvalidRequestError (400 , "'code' parameter was not found in callback request" )
96
99
if not request .query_params .get ("state" ):
97
- raise OAuth2LoginError (400 , "'state' parameter was not found in callback request" )
100
+ raise OAuth2InvalidRequestError (400 , "'state' parameter was not found in callback request" )
98
101
if request .query_params .get ("state" ) != self ._state :
99
- raise OAuth2LoginError (400 , "'state' parameter does not match" )
102
+ raise OAuth2InvalidRequestError (400 , "'state' parameter does not match" )
100
103
101
104
redirect_uri = self .get_redirect_uri (request )
102
105
scheme = "http" if request .auth .http else "https"
@@ -113,12 +116,16 @@ async def token_data(self, request: Request, **httpx_client_args) -> dict:
113
116
headers .update ({"Accept" : "application/json" })
114
117
auth = httpx .BasicAuth (self .client_id , self .client_secret )
115
118
async with httpx .AsyncClient (auth = auth , ** httpx_client_args ) as session :
116
- response = await session .post (token_url , headers = headers , content = content )
117
119
try :
120
+ response = await session .post (token_url , headers = headers , content = content )
118
121
self ._oauth_client .parse_request_body_response (json .dumps (response .json ()))
119
122
return self .standardize (self .backend .user_data (self .access_token ))
120
- except (CustomOAuth2Error , Exception ) as e :
121
- raise OAuth2LoginError (400 , str (e ))
123
+ except OAuth2Error as e :
124
+ raise OAuth2InvalidRequestError (400 , str (e ))
125
+ except httpx .HTTPError as e :
126
+ raise OAuth2BadCredentialsError (400 , str (e ))
127
+ except (AuthException , Exception ) as e :
128
+ raise OAuth2AuthenticationError (401 , str (e ))
122
129
123
130
async def token_redirect (self , request : Request , ** kwargs ) -> RedirectResponse :
124
131
access_token = request .auth .jwt_create (await self .token_data (request , ** kwargs ))
0 commit comments