Skip to content

Commit a438ad4

Browse files
committed
Add a new attribute for SSR mode
1 parent aed3fae commit a438ad4

File tree

3 files changed

+16
-3
lines changed

3 files changed

+16
-3
lines changed

src/fastapi_oauth2/config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
class OAuth2Config:
99
"""Configuration class of the authentication middleware."""
1010

11+
enable_ssr: bool
1112
allow_http: bool
1213
jwt_secret: str
1314
jwt_expires: int
@@ -17,6 +18,7 @@ class OAuth2Config:
1718
def __init__(
1819
self,
1920
*,
21+
enable_ssr: bool = True,
2022
allow_http: bool = False,
2123
jwt_secret: str = "",
2224
jwt_expires: Union[int, str] = 900,
@@ -25,6 +27,7 @@ def __init__(
2527
) -> None:
2628
if allow_http:
2729
os.environ["OAUTHLIB_INSECURE_TRANSPORT"] = "1"
30+
self.enable_ssr = enable_ssr
2831
self.allow_http = allow_http
2932
self.jwt_secret = jwt_secret
3033
self.jwt_expires = int(jwt_expires)

src/fastapi_oauth2/middleware.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
class Auth(AuthCredentials):
3232
"""Extended auth credentials schema based on Starlette AuthCredentials."""
3333

34+
ssr: bool
3435
http: bool
3536
secret: str
3637
expires: int
@@ -39,6 +40,10 @@ class Auth(AuthCredentials):
3940
provider: OAuth2Core = None
4041
clients: Dict[str, OAuth2Core] = {}
4142

43+
@classmethod
44+
def set_ssr(cls, ssr: bool) -> None:
45+
cls.ssr = ssr
46+
4247
@classmethod
4348
def set_http(cls, http: bool) -> None:
4449
cls.http = http
@@ -117,6 +122,7 @@ def __init__(
117122
config: OAuth2Config,
118123
callback: Callable[[Auth, User], Union[Awaitable[None], None]] = None,
119124
) -> None:
125+
Auth.set_ssr(config.enable_ssr)
120126
Auth.set_http(config.allow_http)
121127
Auth.set_secret(config.jwt_secret)
122128
Auth.set_expires(config.jwt_expires)

src/fastapi_oauth2/router.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,17 @@
66

77

88
@router.get("/{provider}/auth")
9-
async def login(request: Request, provider: str):
10-
return await request.auth.clients[provider].login_redirect(request)
9+
def authorize(request: Request, provider: str):
10+
if request.auth.ssr:
11+
return request.auth.clients[provider].authorization_redirect(request)
12+
return dict(url=request.auth.clients[provider].authorization_url(request))
1113

1214

1315
@router.get("/{provider}/token")
1416
async def token(request: Request, provider: str):
15-
return await request.auth.clients[provider].token_redirect(request)
17+
if request.auth.ssr:
18+
return await request.auth.clients[provider].token_redirect(request)
19+
return await request.auth.clients[provider].token_data(request)
1620

1721

1822
@router.get("/logout")

0 commit comments

Comments
 (0)