Skip to content

Commit 1a02068

Browse files
committed
Refactor env vars and cookie setting
1 parent 127966d commit 1a02068

File tree

7 files changed

+67
-54
lines changed

7 files changed

+67
-54
lines changed

.env

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
CLIENT_ID=eccd08d6736b7999a32a
2-
CLIENT_SECRET=642999c1c5f2b3df8b877afdc78252ef5b594d31
3-
CALLBACK_URL=http://127.0.0.1:8000/auth/callback
4-
REDIRECT_URL=http://127.0.0.1:8000/
1+
OAUTH2_CLIENT_ID=eccd08d6736b7999a32a
2+
OAUTH2_CLIENT_SECRET=642999c1c5f2b3df8b877afdc78252ef5b594d31
3+
OAUTH2_CALLBACK_URL=http://127.0.0.1:8000/auth/callback
4+
OAUTH2_REDIRECT_URL=http://127.0.0.1:8000/
55

66
JWT_SECRET=secret
77
JWT_ALGORITHM=HS256

demo/dependencies.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@
44
from fastapi.openapi.models import OAuthFlows as OAuthFlowsModel
55
from fastapi.security import OAuth2
66
from fastapi.security.utils import get_authorization_scheme_param
7-
from jose import jwt, JWTError
7+
from jose import JWTError
88
from starlette.requests import Request
99
from starlette.status import HTTP_403_FORBIDDEN
1010

11-
from fastapi_oauth2.config import JWT_SECRET, JWT_ALGORITHM
11+
from fastapi_oauth2.utils import jwt_decode
1212

1313

1414
class OAuth2PasswordBearerCookie(OAuth2):
@@ -44,7 +44,7 @@ async def __call__(self, request: Request) -> Optional[str]:
4444

4545
async def get_current_user(token: str = Depends(oauth2_scheme)):
4646
try:
47-
return jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM])
47+
return jwt_decode(token)
4848
except JWTError:
4949
raise HTTPException(
5050
status_code=HTTP_403_FORBIDDEN, detail="Could not validate credentials"

fastapi_oauth2/base.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,17 @@
99
from starlette.requests import Request
1010
from starlette.responses import RedirectResponse
1111

12+
from .config import JWT_EXPIRES, OAUTH2_REDIRECT_URL
13+
from .utils import create_access_token
1214

13-
class SSOLoginError(HTTPException):
15+
16+
class OAuth2LoginError(HTTPException):
1417
"""Raised when any login-related error occurs
1518
(such as when user is not verified or if there was an attempt for fake login)
1619
"""
1720

1821

19-
class SSOBase:
22+
class OAuth2Base:
2023
"""Base class (mixin) for all SSO providers"""
2124

2225
client_id: str = None
@@ -78,7 +81,7 @@ async def get_login_url(
7881
self.authorization_endpoint, redirect_uri=redirect_uri, state=state, scope=self.scope, **params
7982
)
8083

81-
async def get_login_redirect(
84+
async def login_redirect(
8285
self,
8386
*,
8487
redirect_uri: Optional[str] = None,
@@ -88,7 +91,7 @@ async def get_login_redirect(
8891
login_uri = await self.get_login_url(redirect_uri=redirect_uri, params=params, state=state)
8992
return RedirectResponse(login_uri, 303)
9093

91-
async def verify_and_process(
94+
async def get_token_data(
9295
self,
9396
request: Request,
9497
*,
@@ -100,9 +103,9 @@ async def verify_and_process(
100103
additional_headers = headers or {}
101104
additional_headers.update(self.additional_headers or {})
102105
if not request.query_params.get("code"):
103-
raise SSOLoginError(400, "'code' parameter was not found in callback request")
106+
raise OAuth2LoginError(400, "'code' parameter was not found in callback request")
104107
if self.state != request.query_params.get("state"):
105-
raise SSOLoginError(400, "'state' parameter does not match")
108+
raise OAuth2LoginError(400, "'state' parameter does not match")
106109

107110
url = request.url
108111
scheme = "http" if self.allow_insecure_http else "https"
@@ -129,3 +132,23 @@ async def verify_and_process(
129132
content = response.json()
130133

131134
return content
135+
136+
async def token_redirect(
137+
self,
138+
request: Request,
139+
*,
140+
params: Optional[Dict[str, Any]] = None,
141+
headers: Optional[Dict[str, Any]] = None,
142+
redirect_uri: Optional[str] = None,
143+
) -> RedirectResponse:
144+
token_data = await self.get_token_data(request, params=params, headers=headers, redirect_uri=redirect_uri)
145+
access_token = create_access_token(token_data)
146+
response = RedirectResponse(OAUTH2_REDIRECT_URL)
147+
response.set_cookie(
148+
"Authorization",
149+
value=f"Bearer {access_token}",
150+
httponly=self.allow_insecure_http,
151+
max_age=JWT_EXPIRES * 60,
152+
expires=JWT_EXPIRES * 60,
153+
)
154+
return response

fastapi_oauth2/config.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@
44

55
load_dotenv()
66

7-
CLIENT_ID = os.getenv("CLIENT_ID")
8-
CLIENT_SECRET = os.getenv("CLIENT_SECRET")
9-
CALLBACK_URL = os.getenv("CALLBACK_URL")
10-
REDIRECT_URL = os.getenv("REDIRECT_URL")
7+
OAUTH2_CLIENT_ID = os.getenv("OAUTH2_CLIENT_ID")
8+
OAUTH2_CLIENT_SECRET = os.getenv("OAUTH2_CLIENT_SECRET")
9+
OAUTH2_CALLBACK_URL = os.getenv("OAUTH2_CALLBACK_URL")
10+
OAUTH2_REDIRECT_URL = os.getenv("OAUTH2_REDIRECT_URL")
1111

1212
JWT_SECRET = os.getenv("JWT_SECRET")
1313
JWT_ALGORITHM = os.getenv("JWT_ALGORITHM")
14-
JWT_EXPIRES = int(os.getenv("JWT_EXPIRES"))
14+
JWT_EXPIRES = int(os.getenv("JWT_EXPIRES", "15"))

fastapi_oauth2/github.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
from .base import SSOBase
1+
from .base import OAuth2Base
22

33

4-
class GitHubSSO(SSOBase):
4+
class GitHubOAuth2(OAuth2Base):
55
"""Class providing login via GitHub SSO"""
66

77
scope = ["user:email"]

fastapi_oauth2/router.py

Lines changed: 12 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,53 +1,36 @@
1-
from datetime import timedelta
2-
31
from fastapi import APIRouter
42
from fastapi.responses import RedirectResponse
53
from starlette.requests import Request
64

7-
from fastapi_oauth2.github import GitHubSSO
5+
from fastapi_oauth2.github import GitHubOAuth2
86
from .config import (
9-
CLIENT_ID,
10-
CLIENT_SECRET,
11-
CALLBACK_URL,
12-
JWT_EXPIRES,
13-
REDIRECT_URL,
7+
OAUTH2_CLIENT_ID,
8+
OAUTH2_CLIENT_SECRET,
9+
OAUTH2_CALLBACK_URL,
10+
OAUTH2_REDIRECT_URL,
1411
)
15-
from .utils import create_access_token
1612

1713
router = APIRouter()
18-
sso = GitHubSSO(
19-
client_id=CLIENT_ID,
20-
client_secret=CLIENT_SECRET,
21-
callback_url=CALLBACK_URL,
14+
oauth2 = GitHubOAuth2(
15+
client_id=OAUTH2_CLIENT_ID,
16+
client_secret=OAUTH2_CLIENT_SECRET,
17+
callback_url=OAUTH2_CALLBACK_URL,
2218
allow_insecure_http=True,
2319
)
2420

2521

2622
@router.get("/auth/login")
2723
async def login():
28-
return await sso.get_login_redirect()
24+
return await oauth2.login_redirect()
2925

3026

3127
@router.get("/auth/callback")
3228
async def callback(request: Request):
33-
user = await sso.verify_and_process(request)
34-
expires_delta = timedelta(minutes=JWT_EXPIRES)
35-
access_token = create_access_token(
36-
data=dict(user), expires_delta=expires_delta
37-
)
38-
response = RedirectResponse(REDIRECT_URL)
39-
response.set_cookie(
40-
"Authorization",
41-
value=f"Bearer {access_token}",
42-
httponly=sso.allow_insecure_http,
43-
max_age=JWT_EXPIRES * 60,
44-
expires=JWT_EXPIRES * 60,
45-
)
46-
return response
29+
return await oauth2.token_redirect(request)
4730

4831

4932
@router.get("/auth/logout")
5033
async def logout():
51-
response = RedirectResponse(REDIRECT_URL)
34+
response = RedirectResponse(OAUTH2_REDIRECT_URL)
5235
response.delete_cookie("Authorization")
5336
return response

fastapi_oauth2/utils.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,18 @@
11
from datetime import datetime, timedelta
2-
from typing import Optional
32

43
from jose import jwt
54

6-
from .config import JWT_SECRET, JWT_ALGORITHM
5+
from .config import JWT_SECRET, JWT_ALGORITHM, JWT_EXPIRES
76

87

9-
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
10-
expire = datetime.utcnow() + expires_delta if expires_delta else timedelta(minutes=15)
11-
return jwt.encode({**data, "exp": expire}, JWT_SECRET, algorithm=JWT_ALGORITHM)
8+
def jwt_encode(data: dict) -> str:
9+
return jwt.encode(data, JWT_SECRET, algorithm=JWT_ALGORITHM)
10+
11+
12+
def jwt_decode(token: str) -> dict:
13+
return jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM])
14+
15+
16+
def create_access_token(token_data: dict) -> str:
17+
expire = datetime.utcnow() + timedelta(minutes=JWT_EXPIRES)
18+
return jwt_encode({**token_data, "exp": expire})

0 commit comments

Comments
 (0)