Skip to content

Commit c063340

Browse files
committed
Combine process_login and verify_and_process
1 parent f8b84e8 commit c063340

File tree

3 files changed

+26
-54
lines changed

3 files changed

+26
-54
lines changed

demo/router.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from fastapi.responses import RedirectResponse
66
from starlette.requests import Request
77

8-
from fastapi_oauth2.github import GithubSSO
8+
from fastapi_oauth2.github import GitHubSSO
99
from .config import (
1010
CLIENT_ID,
1111
CLIENT_SECRET,
@@ -17,7 +17,7 @@
1717
from .utils import create_access_token
1818

1919
router = APIRouter()
20-
sso = GithubSSO(
20+
sso = GitHubSSO(
2121
client_id=CLIENT_ID,
2222
client_secret=CLIENT_SECRET,
2323
redirect_uri=redirect_url,

fastapi_oauth2/base.py

Lines changed: 23 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import json
22
import os
3+
import re
34
from typing import Any, Dict, List, Optional
45

56
import httpx
@@ -18,17 +19,18 @@ class SSOLoginError(HTTPException):
1819
class SSOBase:
1920
"""Base class (mixin) for all SSO providers"""
2021

21-
provider: str = NotImplemented
22-
client_id: str = NotImplemented
23-
client_secret: str = NotImplemented
24-
redirect_uri: Optional[str] = NotImplemented
25-
scope: List[str] = NotImplemented
22+
client_id: str = None
23+
client_secret: str = None
24+
redirect_uri: Optional[str] = None
25+
allow_insecure_http: bool = False
26+
scope: Optional[List[str]] = None
27+
state: Optional[str] = None
2628
_oauth_client: Optional[WebApplicationClient] = None
2729
additional_headers: Optional[Dict[str, Any]] = None
2830

29-
authorization_endpoint: str = NotImplemented
30-
token_endpoint: str = NotImplemented
31-
userinfo_endpoint: str = NotImplemented
31+
authorization_endpoint: str = None
32+
token_endpoint: str = None
33+
userinfo_endpoint: str = None
3234

3335
def __init__(
3436
self,
@@ -45,12 +47,9 @@ def __init__(
4547
if allow_insecure_http:
4648
os.environ["OAUTHLIB_INSECURE_TRANSPORT"] = "1"
4749
self.scope = scope or self.scope
48-
self.state: Optional[str] = None
4950

5051
@property
5152
def oauth_client(self) -> WebApplicationClient:
52-
if self.client_id == NotImplemented:
53-
raise NotImplementedError(f"Provider {self.provider} not supported")
5453
if self._oauth_client is None:
5554
self._oauth_client = WebApplicationClient(self.client_id)
5655
return self._oauth_client
@@ -63,10 +62,6 @@ def access_token(self) -> Optional[str]:
6362
def refresh_token(self) -> Optional[str]:
6463
return self.oauth_client.refresh_token
6564

66-
@classmethod
67-
async def openid_from_response(cls, response: dict) -> dict:
68-
raise NotImplementedError(f"Provider {cls.provider} not supported")
69-
7065
async def get_login_url(
7166
self,
7267
*,
@@ -101,58 +96,36 @@ async def verify_and_process(
10196
headers: Optional[Dict[str, Any]] = None,
10297
redirect_uri: Optional[str] = None,
10398
) -> Optional[dict]:
104-
headers = headers or {}
105-
code = request.query_params.get("code")
106-
if code is None:
99+
params = params or {}
100+
additional_headers = headers or {}
101+
additional_headers.update(self.additional_headers or {})
102+
if not request.query_params.get("code"):
107103
raise SSOLoginError(400, "'code' parameter was not found in callback request")
108104
if self.state != request.query_params.get("state"):
109105
raise SSOLoginError(400, "'state' parameter does not match")
110-
return await self.process_login(
111-
code, request, params=params, additional_headers=headers, redirect_uri=redirect_uri
112-
)
113106

114-
async def process_login(
115-
self,
116-
code: str,
117-
request: Request,
118-
*,
119-
params: Optional[Dict[str, Any]] = None,
120-
additional_headers: Optional[Dict[str, Any]] = None,
121-
redirect_uri: Optional[str] = None,
122-
) -> Optional[dict]:
123-
params = params or {}
124-
additional_headers = additional_headers or {}
125-
additional_headers.update(self.additional_headers or {})
126107
url = request.url
127-
scheme = url.scheme
128-
if not self.allow_insecure_http and scheme != "https":
129-
current_url = str(url).replace("http://", "https://")
130-
scheme = "https"
131-
else:
132-
current_url = str(url)
108+
scheme = "http" if self.allow_insecure_http else "https"
133109
current_path = f"{scheme}://{url.netloc}{url.path}"
110+
current_path = re.sub(r"^https?", scheme, current_path)
111+
current_url = re.sub(r"^https?", scheme, str(url))
134112

135-
token_url, headers, body = self.oauth_client.prepare_token_request(
113+
token_url, headers, content = self.oauth_client.prepare_token_request(
136114
self.token_endpoint,
137115
authorization_response=current_url,
138116
redirect_url=redirect_uri or self.redirect_uri or current_path,
139-
code=code,
117+
code=request.query_params.get("code"),
140118
**params,
141119
)
142120

143-
if token_url is None:
144-
return None
145-
146121
headers.update(additional_headers)
147-
148122
auth = httpx.BasicAuth(self.client_id, self.client_secret)
149123
async with httpx.AsyncClient() as session:
150-
response = await session.post(token_url, headers=headers, content=body, auth=auth)
151-
content = response.json()
152-
self.oauth_client.parse_request_body_response(json.dumps(content))
124+
response = await session.post(token_url, headers=headers, content=content, auth=auth)
125+
self.oauth_client.parse_request_body_response(json.dumps(response.json()))
153126

154-
uri, headers, _ = self.oauth_client.add_token(self.userinfo_endpoint)
155-
response = await session.get(uri, headers=headers)
127+
url, headers, _ = self.oauth_client.add_token(self.userinfo_endpoint)
128+
response = await session.get(url, headers=headers)
156129
content = response.json()
157130

158131
return content

fastapi_oauth2/github.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
from .base import SSOBase
22

33

4-
class GithubSSO(SSOBase):
4+
class GitHubSSO(SSOBase):
55
"""Class providing login via GitHub SSO"""
66

7-
provider = "github"
87
scope = ["user:email"]
98
additional_headers = {"accept": "application/json"}
109

0 commit comments

Comments
 (0)