diff --git a/fastapi_oauth20/callback.py b/fastapi_oauth20/callback.py index 5eb190c..7c66fbf 100644 --- a/fastapi_oauth20/callback.py +++ b/fastapi_oauth20/callback.py @@ -6,7 +6,7 @@ from fastapi import HTTPException, Request -from fastapi_oauth20.errors import AccessTokenError, HTTPXOAuth20Error, OAuth20BaseError +from fastapi_oauth20.errors import OAuth20BaseError, OAuth20RequestError from fastapi_oauth20.oauth20 import OAuth20Base @@ -79,7 +79,7 @@ async def __call__( redirect_uri=self.redirect_uri, code_verifier=code_verifier, ) - except (HTTPXOAuth20Error, AccessTokenError) as e: + except OAuth20RequestError as e: raise OAuth20AuthorizeCallbackError( status_code=500, detail=e.msg, diff --git a/fastapi_oauth20/clients/feishu.py b/fastapi_oauth20/clients/feishu.py index cd4b001..08e35b9 100644 --- a/fastapi_oauth20/clients/feishu.py +++ b/fastapi_oauth20/clients/feishu.py @@ -1,8 +1,6 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -import httpx -from fastapi_oauth20.errors import GetUserInfoError from fastapi_oauth20.oauth20 import OAuth20Base @@ -22,23 +20,10 @@ def __init__(self, client_id: str, client_secret: str): authorize_endpoint='https://passport.feishu.cn/suite/passport/oauth/authorize', access_token_endpoint='https://passport.feishu.cn/suite/passport/oauth/token', refresh_token_endpoint='https://passport.feishu.cn/suite/passport/oauth/authorize', + userinfo_endpoint='https://passport.feishu.cn/suite/passport/oauth/userinfo', default_scopes=[ 'contact:user.employee_id:readonly', 'contact:user.base:readonly', 'contact:user.email:readonly', ], ) - - async def get_userinfo(self, access_token: str) -> dict: - """ - Retrieve user information from FeiShu API. - - :param access_token: Valid FeiShu access token with contact:user scopes. - :return: - """ - headers = {'Authorization': f'Bearer {access_token}'} - async with httpx.AsyncClient() as client: - response = await client.get('https://passport.feishu.cn/suite/passport/oauth/userinfo', headers=headers) - self.raise_httpx_oauth20_errors(response) - result = self.get_json_result(response, err_class=GetUserInfoError) - return result diff --git a/fastapi_oauth20/clients/gitee.py b/fastapi_oauth20/clients/gitee.py index eac8c0a..63720a5 100644 --- a/fastapi_oauth20/clients/gitee.py +++ b/fastapi_oauth20/clients/gitee.py @@ -1,8 +1,6 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -import httpx -from fastapi_oauth20.errors import GetUserInfoError from fastapi_oauth20.oauth20 import OAuth20Base @@ -22,19 +20,6 @@ def __init__(self, client_id: str, client_secret: str): authorize_endpoint='https://gitee.com/oauth/authorize', access_token_endpoint='https://gitee.com/oauth/token', refresh_token_endpoint='https://gitee.com/oauth/token', + userinfo_endpoint='https://gitee.com/api/v5/user', default_scopes=['user_info'], ) - - async def get_userinfo(self, access_token: str) -> dict: - """ - Retrieve user information from Gitee API. - - :param access_token: Valid Gitee access token with user_info scope. - :return: - """ - headers = {'Authorization': f'Bearer {access_token}'} - async with httpx.AsyncClient() as client: - response = await client.get('https://gitee.com/api/v5/user', headers=headers) - self.raise_httpx_oauth20_errors(response) - result = self.get_json_result(response, err_class=GetUserInfoError) - return result diff --git a/fastapi_oauth20/clients/github.py b/fastapi_oauth20/clients/github.py index 49b95c9..68b54f7 100644 --- a/fastapi_oauth20/clients/github.py +++ b/fastapi_oauth20/clients/github.py @@ -1,5 +1,7 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +from typing import Any + import httpx from fastapi_oauth20.errors import GetUserInfoError @@ -21,10 +23,11 @@ def __init__(self, client_id: str, client_secret: str): client_secret=client_secret, authorize_endpoint='https://github.com/login/oauth/authorize', access_token_endpoint='https://github.com/login/oauth/access_token', + userinfo_endpoint='https://api.github.com/user', default_scopes=['user', 'user:email'], ) - async def get_userinfo(self, access_token: str) -> dict: + async def get_userinfo(self, access_token: str) -> dict[str, Any]: """ Retrieve user information from GitHub API. @@ -33,13 +36,13 @@ async def get_userinfo(self, access_token: str) -> dict: """ headers = {'Authorization': f'Bearer {access_token}'} async with httpx.AsyncClient(headers=headers) as client: - response = await client.get('https://api.github.com/user') + response = await client.get(self.userinfo_endpoint) self.raise_httpx_oauth20_errors(response) result = self.get_json_result(response, err_class=GetUserInfoError) email = result.get('email') if email is None: - response = await client.get('https://api.github.com/user/emails') + response = await client.get(f'{self.userinfo_endpoint}/emails') self.raise_httpx_oauth20_errors(response) emails = self.get_json_result(response, err_class=GetUserInfoError) email = next((email['email'] for email in emails if email.get('primary')), emails[0]['email']) diff --git a/fastapi_oauth20/clients/google.py b/fastapi_oauth20/clients/google.py index 0a7a148..c19cce0 100644 --- a/fastapi_oauth20/clients/google.py +++ b/fastapi_oauth20/clients/google.py @@ -1,8 +1,6 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -import httpx -from fastapi_oauth20.errors import GetUserInfoError from fastapi_oauth20.oauth20 import OAuth20Base @@ -23,19 +21,6 @@ def __init__(self, client_id: str, client_secret: str): access_token_endpoint='https://oauth2.googleapis.com/token', refresh_token_endpoint='https://oauth2.googleapis.com/token', revoke_token_endpoint='https://accounts.google.com/o/oauth2/revoke', + userinfo_endpoint='https://www.googleapis.com/oauth2/v1/userinfo', default_scopes=['email', 'openid', 'profile'], ) - - async def get_userinfo(self, access_token: str) -> dict: - """ - Retrieve user information from Google OAuth2 API. - - :param access_token: Valid Google access token with appropriate scopes. - :return: - """ - headers = {'Authorization': f'Bearer {access_token}'} - async with httpx.AsyncClient() as client: - response = await client.get('https://www.googleapis.com/oauth2/v1/userinfo', headers=headers) - self.raise_httpx_oauth20_errors(response) - result = self.get_json_result(response, err_class=GetUserInfoError) - return result diff --git a/fastapi_oauth20/clients/linuxdo.py b/fastapi_oauth20/clients/linuxdo.py index 6adb0e7..690cef6 100644 --- a/fastapi_oauth20/clients/linuxdo.py +++ b/fastapi_oauth20/clients/linuxdo.py @@ -1,8 +1,6 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -import httpx -from fastapi_oauth20.errors import GetUserInfoError from fastapi_oauth20.oauth20 import OAuth20Base @@ -22,19 +20,6 @@ def __init__(self, client_id: str, client_secret: str): authorize_endpoint='https://connect.linux.do/oauth2/authorize', access_token_endpoint='https://connect.linux.do/oauth2/token', refresh_token_endpoint='https://connect.linux.do/oauth2/token', + userinfo_endpoint='https://connect.linux.do/api/user', token_endpoint_basic_auth=True, ) - - async def get_userinfo(self, access_token: str) -> dict: - """ - Retrieve user information from Linux.do API. - - :param access_token: Valid Linux.do access token. - :return: - """ - headers = {'Authorization': f'Bearer {access_token}'} - async with httpx.AsyncClient() as client: - response = await client.get('https://connect.linux.do/api/user', headers=headers) - self.raise_httpx_oauth20_errors(response) - result = self.get_json_result(response, err_class=GetUserInfoError) - return result diff --git a/fastapi_oauth20/clients/oschina.py b/fastapi_oauth20/clients/oschina.py index 91edfb7..6a60289 100644 --- a/fastapi_oauth20/clients/oschina.py +++ b/fastapi_oauth20/clients/oschina.py @@ -1,8 +1,6 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -import httpx -from fastapi_oauth20.errors import GetUserInfoError from fastapi_oauth20.oauth20 import OAuth20Base @@ -22,18 +20,5 @@ def __init__(self, client_id: str, client_secret: str): authorize_endpoint='https://www.oschina.net/action/oauth2/authorize', access_token_endpoint='https://www.oschina.net/action/openapi/token', refresh_token_endpoint='https://www.oschina.net/action/openapi/token', + userinfo_endpoint='https://www.oschina.net/action/openapi/user', ) - - async def get_userinfo(self, access_token: str) -> dict: - """ - Retrieve user information from OSChina API. - - :param access_token: Valid OSChina access token. - :return: - """ - headers = {'Authorization': f'Bearer {access_token}'} - async with httpx.AsyncClient() as client: - response = await client.get('https://www.oschina.net/action/openapi/user', headers=headers) - self.raise_httpx_oauth20_errors(response) - result = self.get_json_result(response, err_class=GetUserInfoError) - return result diff --git a/fastapi_oauth20/oauth20.py b/fastapi_oauth20/oauth20.py index 9b8ac91..255fc22 100644 --- a/fastapi_oauth20/oauth20.py +++ b/fastapi_oauth20/oauth20.py @@ -1,9 +1,7 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -import abc import json -from abc import ABC from typing import Any, Literal, cast from urllib.parse import urlencode @@ -11,6 +9,7 @@ from fastapi_oauth20.errors import ( AccessTokenError, + GetUserInfoError, HTTPXOAuth20Error, OAuth20RequestError, RefreshTokenError, @@ -18,7 +17,7 @@ ) -class OAuth20Base(ABC): +class OAuth20Base: def __init__( self, client_id: str, @@ -26,6 +25,7 @@ def __init__( *, authorize_endpoint: str, access_token_endpoint: str, + userinfo_endpoint: str, refresh_token_endpoint: str | None = None, revoke_token_endpoint: str | None = None, default_scopes: list[str] | None = None, @@ -39,6 +39,7 @@ def __init__( :param client_secret: The client secret provided by the OAuth2 provider. :param authorize_endpoint: The authorization endpoint URL where users are redirected to grant access. :param access_token_endpoint: The token endpoint URL for exchanging authorization codes for access tokens. + :param userinfo_endpoint: The endpoint URL for retrieving user information using access token. :param refresh_token_endpoint: The token endpoint URL for refreshing expired access tokens using refresh tokens. :param revoke_token_endpoint: The endpoint URL for revoking access tokens or refresh tokens. :param default_scopes: Default list of OAuth scopes to request if none are specified. @@ -51,6 +52,7 @@ def __init__( self.access_token_endpoint = access_token_endpoint self.refresh_token_endpoint = refresh_token_endpoint self.revoke_token_endpoint = revoke_token_endpoint + self.userinfo_endpoint = userinfo_endpoint self.default_scopes = default_scopes self.token_endpoint_basic_auth = token_endpoint_basic_auth self.revoke_token_endpoint_basic_auth = revoke_token_endpoint_basic_auth @@ -227,7 +229,6 @@ def get_json_result(response: httpx.Response, *, err_class: type[OAuth20RequestE except json.JSONDecodeError as e: raise err_class('Result serialization failed.', response) from e - @abc.abstractmethod async def get_userinfo(self, access_token: str) -> dict[str, Any]: """ Retrieve user information from the OAuth2 provider. @@ -235,4 +236,9 @@ async def get_userinfo(self, access_token: str) -> dict[str, Any]: :param access_token: Valid access token to authenticate the request to the provider's user info endpoint. :return: """ - raise NotImplementedError() + headers = {'Authorization': f'Bearer {access_token}'} + async with httpx.AsyncClient() as client: + response = await client.get(self.userinfo_endpoint, headers=headers) + self.raise_httpx_oauth20_errors(response) + result = self.get_json_result(response, err_class=GetUserInfoError) + return result