Skip to content

Commit d0e208e

Browse files
committed
Create the core middleware and integrate
1 parent b8d2c50 commit d0e208e

File tree

5 files changed

+89
-35
lines changed

5 files changed

+89
-35
lines changed

main.py

Lines changed: 11 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,10 @@
22

33
from fastapi import FastAPI, Request, APIRouter
44
from fastapi.responses import HTMLResponse
5-
from fastapi.security.utils import get_authorization_scheme_param
65
from fastapi.templating import Jinja2Templates
7-
from starlette.authentication import AuthenticationBackend
8-
from starlette.middleware.authentication import AuthenticationMiddleware
96

10-
from demo.dependencies import get_current_user
117
from demo.router import router as demo_router
8+
from fastapi_oauth2.middleware import OAuth2Middleware
129
from fastapi_oauth2.router import router as oauth2_router
1310

1411
router = APIRouter()
@@ -24,19 +21,13 @@ async def root(request: Request):
2421
app.include_router(router)
2522
app.include_router(demo_router)
2623
app.include_router(oauth2_router)
27-
28-
29-
class BearerTokenAuthBackend(AuthenticationBackend):
30-
async def authenticate(self, request):
31-
authorization = request.cookies.get("Authorization")
32-
scheme, param = get_authorization_scheme_param(authorization)
33-
34-
if not scheme or not param:
35-
return "", None
36-
37-
return authorization, await get_current_user(param)
38-
39-
40-
@app.on_event('startup')
41-
async def startup():
42-
app.add_middleware(AuthenticationMiddleware, backend=BearerTokenAuthBackend())
24+
app.add_middleware(OAuth2Middleware, config={
25+
"allow_http": True,
26+
"providers": {
27+
"github": {
28+
"client_id": "eccd08d6736b7999a32a",
29+
"client_secret": "642999c1c5f2b3df8b877afdc78252ef5b594d31",
30+
"redirect_uri": "http://127.0.0.1:8000/",
31+
},
32+
}
33+
})

src/fastapi_oauth2/base.py

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from starlette.responses import RedirectResponse
1111

1212
from .config import JWT_EXPIRES, OAUTH2_REDIRECT_URL
13-
from .utils import create_access_token
13+
from .utils import jwt_create
1414

1515

1616
class OAuth2LoginError(HTTPException):
@@ -68,27 +68,22 @@ def refresh_token(self) -> Optional[str]:
6868
async def get_login_url(
6969
self,
7070
*,
71-
redirect_uri: Optional[str] = None,
7271
params: Optional[Dict[str, Any]] = None,
7372
state: Optional[str] = None,
7473
) -> Any:
7574
self.state = state
7675
params = params or {}
77-
redirect_uri = redirect_uri or self.callback_url
78-
if redirect_uri is None:
79-
raise ValueError("callback_url must be provided, either at construction or request time")
8076
return self.oauth_client.prepare_request_uri(
81-
self.authorization_endpoint, redirect_uri=redirect_uri, state=state, scope=self.scope, **params
77+
self.authorization_endpoint, redirect_uri=self.callback_url, state=state, scope=self.scope, **params
8278
)
8379

8480
async def login_redirect(
8581
self,
8682
*,
87-
redirect_uri: Optional[str] = None,
8883
params: Optional[Dict[str, Any]] = None,
8984
state: Optional[str] = None,
9085
) -> RedirectResponse:
91-
login_uri = await self.get_login_url(redirect_uri=redirect_uri, params=params, state=state)
86+
login_uri = await self.get_login_url(params=params, state=state)
9287
return RedirectResponse(login_uri, 303)
9388

9489
async def get_token_data(
@@ -97,7 +92,6 @@ async def get_token_data(
9792
*,
9893
params: Optional[Dict[str, Any]] = None,
9994
headers: Optional[Dict[str, Any]] = None,
100-
redirect_uri: Optional[str] = None,
10195
) -> Optional[Dict[str, Any]]:
10296
params = params or {}
10397
additional_headers = headers or {}
@@ -116,7 +110,7 @@ async def get_token_data(
116110
token_url, headers, content = self.oauth_client.prepare_token_request(
117111
self.token_endpoint,
118112
authorization_response=current_url,
119-
redirect_url=redirect_uri or self.callback_url or current_path,
113+
redirect_url=self.callback_url or current_path,
120114
code=request.query_params.get("code"),
121115
**params,
122116
)
@@ -131,18 +125,17 @@ async def get_token_data(
131125
response = await session.get(url, headers=headers)
132126
content = response.json()
133127

134-
return content
128+
return {**content, "scope": self.scope}
135129

136130
async def token_redirect(
137131
self,
138132
request: Request,
139133
*,
140134
params: Optional[Dict[str, Any]] = None,
141135
headers: Optional[Dict[str, Any]] = None,
142-
redirect_uri: Optional[str] = None,
143136
) -> RedirectResponse:
144-
token_data = await self.get_token_data(request, params=params, headers=headers, redirect_uri=redirect_uri)
145-
access_token = create_access_token(token_data)
137+
token_data = await self.get_token_data(request, params=params, headers=headers)
138+
access_token = jwt_create(token_data)
146139
response = RedirectResponse(OAUTH2_REDIRECT_URL)
147140
response.set_cookie(
148141
"Authorization",

src/fastapi_oauth2/middleware.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
from typing import Optional, Tuple, Union
2+
3+
from fastapi.security.utils import get_authorization_scheme_param
4+
from starlette.authentication import AuthenticationBackend, AuthCredentials
5+
from starlette.middleware.authentication import AuthenticationMiddleware
6+
from starlette.requests import Request
7+
from starlette.types import Send, Receive, Scope, ASGIApp
8+
9+
from .types import Config
10+
from .types import ConfigParams
11+
from .utils import jwt_decode
12+
13+
14+
class OAuth2Backend(AuthenticationBackend):
15+
async def authenticate(self, request: Request) -> Optional[Tuple["AuthCredentials", Optional[dict]]]:
16+
authorization = request.cookies.get("Authorization")
17+
scheme, param = get_authorization_scheme_param(authorization)
18+
19+
if not scheme or not param:
20+
return AuthCredentials(), None
21+
22+
access_token = jwt_decode(param)
23+
scope = access_token.pop("scope")
24+
return AuthCredentials(scope), access_token
25+
26+
27+
class OAuth2Middleware:
28+
def __init__(self, app: ASGIApp, config: Union[Config, ConfigParams]) -> None:
29+
if isinstance(config, Config):
30+
self.config = config
31+
elif isinstance(config, dict):
32+
self.config = Config(**config)
33+
else:
34+
raise ValueError("config does not contain valid parameters")
35+
self.auth_middleware = AuthenticationMiddleware(app, OAuth2Backend())
36+
37+
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
38+
await self.auth_middleware(scope, receive, send)

src/fastapi_oauth2/types.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
from enum import Enum
2+
from typing import Dict, TypedDict
3+
4+
5+
class OAuth2Provider(str, Enum):
6+
github = "github"
7+
8+
9+
class OAuth2Client(Dict[str, str]):
10+
client_id: str
11+
client_secret: str
12+
redirect_uri: str
13+
14+
15+
class ConfigParams(TypedDict):
16+
allow_http: bool
17+
jwt_secret: str
18+
jwt_expires: int
19+
jwt_algorithm: str
20+
providers: Dict[OAuth2Provider, OAuth2Client]
21+
22+
23+
class Config:
24+
allow_http: bool = False
25+
jwt_secret: str = ""
26+
jwt_expires: int = 900
27+
jwt_algorithm: str = "HS256"
28+
providers: Dict[OAuth2Provider, OAuth2Client] = {}
29+
30+
def __init__(self, **kwargs):
31+
for key, value in kwargs.items():
32+
setattr(self, key, value)

src/fastapi_oauth2/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,6 @@ def jwt_decode(token: str) -> dict:
1313
return jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM])
1414

1515

16-
def create_access_token(token_data: dict) -> str:
16+
def jwt_create(token_data: dict) -> str:
1717
expire = datetime.utcnow() + timedelta(minutes=JWT_EXPIRES)
1818
return jwt_encode({**token_data, "exp": expire})

0 commit comments

Comments
 (0)