Skip to content

Commit 7b18c0a

Browse files
committed
GH-22: Handle possible exceptions in token data obtaining flow
1 parent 23658d3 commit 7b18c0a

File tree

1 file changed

+15
-8
lines changed

1 file changed

+15
-8
lines changed

src/fastapi_oauth2/core.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,19 @@
99
from urllib.parse import urljoin
1010

1111
import httpx
12+
from oauthlib.oauth2 import OAuth2Error
1213
from oauthlib.oauth2 import WebApplicationClient
13-
from oauthlib.oauth2.rfc6749.errors import CustomOAuth2Error
1414
from social_core.backends.oauth import BaseOAuth2
15+
from social_core.exceptions import AuthException
1516
from social_core.strategy import BaseStrategy
1617
from starlette.requests import Request
1718
from starlette.responses import RedirectResponse
1819

1920
from .claims import Claims
2021
from .client import OAuth2Client
21-
from .exceptions import OAuth2LoginError
22+
from .exceptions import OAuth2AuthenticationError
23+
from .exceptions import OAuth2BadCredentialsError
24+
from .exceptions import OAuth2InvalidRequestError
2225

2326

2427
class OAuth2Strategy(BaseStrategy):
@@ -92,11 +95,11 @@ def authorization_redirect(self, request: Request) -> RedirectResponse:
9295

9396
async def token_data(self, request: Request, **httpx_client_args) -> dict:
9497
if not request.query_params.get("code"):
95-
raise OAuth2LoginError(400, "'code' parameter was not found in callback request")
98+
raise OAuth2InvalidRequestError(400, "'code' parameter was not found in callback request")
9699
if not request.query_params.get("state"):
97-
raise OAuth2LoginError(400, "'state' parameter was not found in callback request")
100+
raise OAuth2InvalidRequestError(400, "'state' parameter was not found in callback request")
98101
if request.query_params.get("state") != self._state:
99-
raise OAuth2LoginError(400, "'state' parameter does not match")
102+
raise OAuth2InvalidRequestError(400, "'state' parameter does not match")
100103

101104
redirect_uri = self.get_redirect_uri(request)
102105
scheme = "http" if request.auth.http else "https"
@@ -113,12 +116,16 @@ async def token_data(self, request: Request, **httpx_client_args) -> dict:
113116
headers.update({"Accept": "application/json"})
114117
auth = httpx.BasicAuth(self.client_id, self.client_secret)
115118
async with httpx.AsyncClient(auth=auth, **httpx_client_args) as session:
116-
response = await session.post(token_url, headers=headers, content=content)
117119
try:
120+
response = await session.post(token_url, headers=headers, content=content)
118121
self._oauth_client.parse_request_body_response(json.dumps(response.json()))
119122
return self.standardize(self.backend.user_data(self.access_token))
120-
except (CustomOAuth2Error, Exception) as e:
121-
raise OAuth2LoginError(400, str(e))
123+
except OAuth2Error as e:
124+
raise OAuth2InvalidRequestError(400, str(e))
125+
except httpx.HTTPError as e:
126+
raise OAuth2BadCredentialsError(400, str(e))
127+
except (AuthException, Exception) as e:
128+
raise OAuth2AuthenticationError(401, str(e))
122129

123130
async def token_redirect(self, request: Request, **kwargs) -> RedirectResponse:
124131
access_token = request.auth.jwt_create(await self.token_data(request, **kwargs))

0 commit comments

Comments
 (0)