Skip to content

Commit d8da427

Browse files
bilalabbadogenstad
andauthored
Implement redirection to initial URL after SSO login (#4557)
* Add final_url on oauth2 * Add final_url to oidc * handle final_url on frontend --------- Co-authored-by: Patrick Ogenstad <[email protected]>
1 parent 1ef9a8e commit d8da427

File tree

5 files changed

+49
-25
lines changed

5 files changed

+49
-25
lines changed

backend/infrahub/api/oauth2.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,7 @@ def _get_redirect_url(request: Request, provider_name: str) -> str:
3131

3232

3333
@router.get("/{provider_name:str}/authorize")
34-
async def authorize(
35-
request: Request,
36-
provider_name: str,
37-
) -> Response:
34+
async def authorize(request: Request, provider_name: str, final_url: str | None = None) -> Response:
3835
provider = config.SETTINGS.security.get_oauth2_provider(provider=provider_name)
3936
client = AsyncOAuth2Client(
4037
client_id=provider.client_id,
@@ -43,15 +40,16 @@ async def authorize(
4340
)
4441

4542
redirect_uri = _get_redirect_url(request=request, provider_name=provider_name)
43+
final_url = final_url or config.SETTINGS.dev.frontend_url or str(request.base_url)
4644

4745
authorization_uri, state = client.create_authorization_url(
48-
url=provider.authorization_url, redirect_uri=redirect_uri, scope=provider.scopes
46+
url=provider.authorization_url, redirect_uri=redirect_uri, scope=provider.scopes, final_url=final_url
4947
)
5048

5149
service: InfrahubServices = request.app.state.service
5250

5351
await service.cache.set(
54-
key=f"security:oauth2:provider:{provider_name}:state:{state}", value=state, expires=KVTTL.TWO_HOURS
52+
key=f"security:oauth2:provider:{provider_name}:state:{state}", value=final_url, expires=KVTTL.TWO_HOURS
5553
)
5654

5755
if config.SETTINGS.dev.frontend_redirect_sso:
@@ -68,16 +66,16 @@ async def token(
6866
state: str,
6967
code: str,
7068
db: InfrahubDatabase = Depends(get_db),
71-
) -> models.UserToken:
69+
) -> models.UserTokenWithUrl:
7270
provider = config.SETTINGS.security.get_oauth2_provider(provider=provider_name)
7371

7472
service: InfrahubServices = request.app.state.service
7573

7674
cache_key = f"security:oauth2:provider:{provider_name}:state:{state}"
77-
stored_state = await service.cache.get(key=cache_key)
75+
stored_final_url = await service.cache.get(key=cache_key)
7876
await service.cache.delete(key=cache_key)
7977

80-
if state != stored_state:
78+
if not stored_final_url:
8179
raise ProcessingError(message="Invalid 'state' parameter")
8280

8381
token_data = {
@@ -109,7 +107,9 @@ async def token(
109107
max_age=config.SETTINGS.security.refresh_token_lifetime,
110108
)
111109

112-
return user_token
110+
return models.UserTokenWithUrl(
111+
access_token=user_token.access_token, refresh_token=user_token.refresh_token, final_url=stored_final_url
112+
)
113113

114114

115115
def _validate_response(response: httpx.Response) -> None:

backend/infrahub/api/oidc.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,7 @@ def _get_redirect_url(request: Request, provider_name: str) -> str:
6060

6161

6262
@router.get("/{provider_name:str}/authorize")
63-
async def authorize(
64-
request: Request,
65-
provider_name: str,
66-
) -> Response:
63+
async def authorize(request: Request, provider_name: str, final_url: str | None = None) -> Response:
6764
provider = config.SETTINGS.security.get_oidc_provider(provider=provider_name)
6865
service: InfrahubServices = request.app.state.service
6966

@@ -78,13 +75,14 @@ async def authorize(
7875
)
7976

8077
redirect_uri = _get_redirect_url(request=request, provider_name=provider_name)
78+
final_url = final_url or config.SETTINGS.dev.frontend_url or str(request.base_url)
8179

8280
authorization_uri, state = client.create_authorization_url(
8381
url=str(oidc_config.authorization_endpoint), redirect_uri=redirect_uri, scope=provider.scopes
8482
)
8583

8684
await service.cache.set(
87-
key=f"security:oidc:provider:{provider_name}:state:{state}", value=state, expires=KVTTL.TWO_HOURS
85+
key=f"security:oidc:provider:{provider_name}:state:{state}", value=final_url, expires=KVTTL.TWO_HOURS
8886
)
8987

9088
if config.SETTINGS.dev.frontend_redirect_sso:
@@ -101,16 +99,16 @@ async def token(
10199
state: str,
102100
code: str,
103101
db: InfrahubDatabase = Depends(get_db),
104-
) -> models.UserToken:
102+
) -> models.UserTokenWithUrl:
105103
provider = config.SETTINGS.security.get_oidc_provider(provider=provider_name)
106104

107105
service: InfrahubServices = request.app.state.service
108106

109107
cache_key = f"security:oidc:provider:{provider_name}:state:{state}"
110-
stored_state = await service.cache.get(key=cache_key)
108+
stored_final_url = await service.cache.get(key=cache_key)
111109
await service.cache.delete(key=cache_key)
112110

113-
if state != stored_state:
111+
if not stored_final_url:
114112
raise ProcessingError(message="Invalid 'state' parameter")
115113

116114
token_data = {
@@ -148,7 +146,9 @@ async def token(
148146
max_age=config.SETTINGS.security.refresh_token_lifetime,
149147
)
150148

151-
return user_token
149+
return models.UserTokenWithUrl(
150+
access_token=user_token.access_token, refresh_token=user_token.refresh_token, final_url=stored_final_url
151+
)
152152

153153

154154
def _validate_response(response: httpx.Response) -> None:

backend/infrahub/models.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@ class UserToken(BaseModel):
1313
refresh_token: str = Field(..., description="JWT refresh_token")
1414

1515

16+
class UserTokenWithUrl(UserToken):
17+
final_url: str = Field(..., description="The final url after logged in")
18+
19+
1620
class AccessTokenResponse(BaseModel):
1721
access_token: str = Field(..., description="JWT access_token")
1822

frontend/app/src/pages/auth-callback.tsx

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,19 @@
11
import { INFRAHUB_API_SERVER_URL } from "@/config/config";
22
import { useAuth } from "@/hooks/useAuth";
3+
import LoadingScreen from "@/screens/loading-screen/loading-screen";
34
import { configState } from "@/state/atoms/config.atom";
45
import { fetchUrl } from "@/utils/fetch";
56
import { useAtomValue } from "jotai";
6-
import { useEffect } from "react";
7+
import { useEffect, useState } from "react";
78
import { Navigate, useParams, useSearchParams } from "react-router-dom";
89

910
function AuthCallback() {
1011
const { protocol, provider } = useParams();
1112
const config = useAtomValue(configState);
1213
const [searchParams] = useSearchParams();
1314
const { isAuthenticated, setToken } = useAuth();
15+
const [redirectTo, setRedirectTo] = useState("/");
16+
1417
const code = searchParams.get("code");
1518
const state = searchParams.get("state");
1619
const error = searchParams.get("error");
@@ -26,6 +29,7 @@ function AuthCallback() {
2629
const { token_path } = currentAuthProvider;
2730
fetchUrl(`${INFRAHUB_API_SERVER_URL}${token_path}?code=${code}&state=${state}`).then(
2831
(result) => {
32+
setRedirectTo(result.final_url);
2933
setToken(result);
3034
}
3135
);
@@ -40,10 +44,14 @@ function AuthCallback() {
4044
}
4145

4246
if (isAuthenticated) {
43-
return <Navigate to="/" replace />;
47+
return <Navigate to={redirectTo} replace />;
4448
}
4549

46-
return null;
50+
return (
51+
<div className="w-screen h-screen flex items-center justify-center">
52+
<LoadingScreen />
53+
</div>
54+
);
4755
}
4856

4957
export const Component = AuthCallback;

frontend/app/src/screens/authentification/sign-in-sso-buttons.tsx

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,34 @@
11
import { INFRAHUB_API_SERVER_URL } from "@/config/config";
22
import { Provider } from "@/state/atoms/config.atom";
33
import { Icon } from "@iconify-icon/react";
4+
import { useLocation } from "react-router-dom";
45

56
export const SignInWithSSOButtons = ({ providers }: { providers: Array<Provider> }) => {
7+
let location = useLocation();
8+
const redirectTo: string =
9+
(location.state?.from?.pathname || "/") + (location.state?.from?.search ?? "");
10+
611
return (
712
<div className="flex flex-col space-y-1 w-full">
813
{providers.map((provider) => (
9-
<ProviderButton key={provider.name + provider.protocol} provider={provider} />
14+
<ProviderButton
15+
key={provider.name + provider.protocol}
16+
provider={provider}
17+
redirectTo={redirectTo}
18+
/>
1019
))}
1120
</div>
1221
);
1322
};
1423

15-
export const ProviderButton = ({ provider }: { provider: Provider }) => {
24+
export const ProviderButton = ({
25+
provider,
26+
redirectTo = "/",
27+
}: { provider: Provider; redirectTo?: string }) => {
1628
return (
1729
<a
1830
className="h-9 px-4 py-2 inline-flex items-center justify-center whitespace-nowrap rounded-md text-sm font-medium disabled:opacity-60 disabled:cursor-not-allowed border bg-custom-white shadow-sm hover:bg-gray-100"
19-
href={INFRAHUB_API_SERVER_URL + provider.authorize_path}
31+
href={`${INFRAHUB_API_SERVER_URL + provider.authorize_path}?final_url=${redirectTo}`}
2032
>
2133
<Icon icon={provider.icon} />
2234
<span className="ml-2">Sign in with {provider.display_label}</span>

0 commit comments

Comments
 (0)