Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 126 additions & 1 deletion src/AuthContext.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@ import {
RedirectToOrgPageOptions,
RedirectToSetupSAMLPageOptions,
RedirectToSignupOptions,
OrgMemberInfoClass
} from "@propelauth/javascript"
import React, { useCallback, useEffect, useReducer } from "react"
import React, { useCallback, useEffect, useReducer, useState, useMemo } from "react"
import { loadOrgSelectionFromLocalStorage } from "./hooks/useActiveOrg"
import { useClientRef, useClientRefCallback } from "./useClientRef"

Expand Down Expand Up @@ -41,6 +42,11 @@ export interface InternalAuthState {
authUrl: string

tokens: Tokens

activeOrg: OrgMemberInfoClass | undefined
setActiveOrg: (orgId: string) => Promise<boolean>
removeActiveOrg: () => void

refreshAuthInfo: () => Promise<void>
defaultDisplayWhileLoading?: React.ReactElement
defaultDisplayIfLoggedOut?: React.ReactElement
Expand All @@ -56,6 +62,7 @@ export type AuthProviderProps = {
getActiveOrgFn?: () => string | null
children?: React.ReactNode
minSecondsBeforeRefresh?: number
useLocalStorageForActiveOrg?: boolean
}

export interface RequiredAuthProviderProps
Expand Down Expand Up @@ -101,6 +108,60 @@ function authInfoStateReducer(_state: AuthInfoState, action: AuthInfoStateAction
}
}

const ACTIVE_ORG_KEY = 'activeOrgId';

const getStoredActiveOrgId = (): string | null => {
try {
return localStorage.getItem(ACTIVE_ORG_KEY);
} catch (error) {
console.warn('Failed to read from localStorage:', error);
return null;
}
};

const setStoredActiveOrgId = (orgId: string): void => {
try {
localStorage.setItem(ACTIVE_ORG_KEY, orgId);
} catch (error) {
console.warn('Failed to write to localStorage:', error);
}
};

const removeStoredActiveOrgId = (): void => {
try {
localStorage.removeItem(ACTIVE_ORG_KEY);
} catch (error) {
console.warn('Failed to remove from localStorage:', error);
}
};

const useLocalStorageSync = (key: string): string | null => {
const [value, setValue] = useState<string | null>(() => {
try {
return localStorage.getItem(key);
} catch (error) {
console.warn('Failed to read from localStorage:', error);
return null;
}
});

useEffect(() => {
const handleStorageChange = (e: StorageEvent) => {
if (e.key === key) {
setValue(e.newValue);
}
};

window.addEventListener('storage', handleStorageChange);

return () => {
window.removeEventListener('storage', handleStorageChange);
};
}, [key]);

return value;
};

export const AuthProvider = (props: AuthProviderProps) => {
const {
authUrl,
Expand All @@ -109,8 +170,11 @@ export const AuthProvider = (props: AuthProviderProps) => {
children,
defaultDisplayWhileLoading,
defaultDisplayIfLoggedOut,
useLocalStorageForActiveOrg
} = props
const storedActiveOrgId = useLocalStorageSync(ACTIVE_ORG_KEY);
const [authInfoState, dispatch] = useReducer(authInfoStateReducer, initialAuthInfoState)
const [activeOrg, setActiveOrgState] = useState<OrgMemberInfoClass | undefined>();
const { clientRef, accessTokenChangeCounter } = useClientRef({
authUrl,
minSecondsBeforeRefresh,
Expand Down Expand Up @@ -142,13 +206,21 @@ export const AuthProvider = (props: AuthProviderProps) => {
}
}, [accessTokenChangeCounter])

// Re-render when stored active org is updated
useEffect(() => {
if (storedActiveOrgId && useLocalStorageForActiveOrg) {
setActiveOrg(storedActiveOrgId)
}
}, [storedActiveOrgId])

// Deprecation warning
useEffect(() => {
if (deprecatedGetActiveOrgFn) {
console.warn("The `getActiveOrgFn` prop is deprecated.")
}
}, [])


const logout = useClientRefCallback(clientRef, (client) => client.logout)
const redirectToLoginPage = useClientRefCallback(clientRef, (client) => client.redirectToLoginPage)
const redirectToSignupPage = useClientRefCallback(clientRef, (client) => client.redirectToSignupPage)
Expand Down Expand Up @@ -182,6 +254,56 @@ export const AuthProvider = (props: AuthProviderProps) => {
dispatch({ authInfo })
}, [dispatch])


const setActiveOrg = async (orgId: string) => {
const userClass = authInfoState?.authInfo?.userClass
if (!userClass) {
return false
}
const org = userClass.getOrg(orgId)

if (org) {
if (useLocalStorageForActiveOrg) {
setStoredActiveOrgId(orgId);
}
setActiveOrgState(org);
return true
} else {
if (useLocalStorageForActiveOrg) {
removeStoredActiveOrgId();
}
setActiveOrgState(undefined);
return false
}
};

const getActiveOrg = useMemo(() => {
const userClass = authInfoState?.authInfo?.userClass
if (!userClass) {
return undefined
}

if (!activeOrg && useLocalStorageForActiveOrg) {
const activeOrgIdFromLocalStorage = getStoredActiveOrgId()
if (activeOrgIdFromLocalStorage) {
return userClass.getOrg(activeOrgIdFromLocalStorage)
}
}
if (activeOrg) {
return userClass.getOrg(activeOrg.orgId)
}
return undefined
}, [activeOrg, authInfoState?.authInfo?.userClass])


function removeActiveOrg() {
if (useLocalStorageForActiveOrg) {
removeStoredActiveOrgId();
}
setActiveOrgState(undefined);
}


// TODO: Remove this, as both `getActiveOrgFn` and `loadOrgSelectionFromLocalStorage` are deprecated.
const deprecatedActiveOrgFn = deprecatedGetActiveOrgFn || loadOrgSelectionFromLocalStorage

Expand All @@ -206,6 +328,9 @@ export const AuthProvider = (props: AuthProviderProps) => {
getSetupSAMLPageUrl,
authUrl,
refreshAuthInfo,
activeOrg: getActiveOrg,
setActiveOrg,
removeActiveOrg,
tokens: {
getAccessTokenForOrg,
getAccessToken,
Expand Down
3 changes: 3 additions & 0 deletions src/AuthContextForTesting.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ export const AuthProviderForTesting = ({
getAccessTokenForOrg: getAccessTokenForOrg,
getAccessToken: () => Promise.resolve(userInformation?.accessToken ?? "ACCESS_TOKEN"),
},
activeOrg: undefined,
setActiveOrg: () => Promise.resolve(false),
removeActiveOrg: () => ""
}

return <AuthContext.Provider value={contextValue}>{children}</AuthContext.Provider>
Expand Down
22 changes: 20 additions & 2 deletions src/hooks/useAuthInfo.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { AccessHelper, OrgHelper, User, UserClass } from "@propelauth/javascript"
import { AccessHelper, OrgHelper, OrgMemberInfoClass, User, UserClass } from "@propelauth/javascript"
import { useContext } from "react"
import { AuthContext, Tokens } from "../AuthContext"

Expand All @@ -15,6 +15,9 @@ export type UseAuthInfoLoading = {
refreshAuthInfo: () => Promise<void>
tokens: Tokens
accessTokenExpiresAtSeconds: undefined
activeOrg: undefined
setActiveOrg: undefined
removeActiveOrg: undefined
}

export type UseAuthInfoLoggedInProps = {
Expand All @@ -30,6 +33,9 @@ export type UseAuthInfoLoggedInProps = {
refreshAuthInfo: () => Promise<void>
tokens: Tokens
accessTokenExpiresAtSeconds: number
activeOrg: OrgMemberInfoClass | undefined
setActiveOrg: (orgId: string) => Promise<boolean>
removeActiveOrg: () => void
}

export type UseAuthInfoNotLoggedInProps = {
Expand All @@ -45,6 +51,9 @@ export type UseAuthInfoNotLoggedInProps = {
refreshAuthInfo: () => Promise<void>
tokens: Tokens
accessTokenExpiresAtSeconds: undefined
activeOrg: undefined
setActiveOrg: undefined
removeActiveOrg: () => void
}

export type UseAuthInfoProps = UseAuthInfoLoading | UseAuthInfoLoggedInProps | UseAuthInfoNotLoggedInProps
Expand All @@ -55,7 +64,7 @@ export function useAuthInfo(): UseAuthInfoProps {
throw new Error("useAuthInfo must be used within an AuthProvider or RequiredAuthProvider")
}

const { loading, authInfo, refreshAuthInfo, tokens } = context
const { loading, authInfo, refreshAuthInfo, tokens, activeOrg, setActiveOrg, removeActiveOrg } = context
if (loading) {
return {
loading: true,
Expand All @@ -70,6 +79,9 @@ export function useAuthInfo(): UseAuthInfoProps {
refreshAuthInfo,
tokens,
accessTokenExpiresAtSeconds: undefined,
activeOrg: undefined,
setActiveOrg: undefined,
removeActiveOrg: undefined
}
} else if (authInfo && authInfo.accessToken) {
return {
Expand All @@ -85,6 +97,9 @@ export function useAuthInfo(): UseAuthInfoProps {
refreshAuthInfo,
tokens,
accessTokenExpiresAtSeconds: authInfo.expiresAtSeconds,
activeOrg,
setActiveOrg,
removeActiveOrg,
}
}
return {
Expand All @@ -100,5 +115,8 @@ export function useAuthInfo(): UseAuthInfoProps {
refreshAuthInfo,
tokens,
accessTokenExpiresAtSeconds: undefined,
activeOrg: undefined,
setActiveOrg: undefined,
removeActiveOrg
}
}
16 changes: 14 additions & 2 deletions src/withAuthInfo.tsx
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { AccessHelper, OrgHelper, User, UserClass } from "@propelauth/javascript"
import { AccessHelper, OrgHelper, User, UserClass, OrgMemberInfoClass } from "@propelauth/javascript"
import hoistNonReactStatics from "hoist-non-react-statics"
import React, { useContext } from "react"
import { Subtract } from "utility-types"
Expand All @@ -17,6 +17,9 @@ export type WithLoggedInAuthInfoProps = {
refreshAuthInfo: () => Promise<void>
tokens: Tokens
accessTokenExpiresAtSeconds: number
activeOrg: OrgMemberInfoClass | undefined
setActiveOrg: (orgId: string) => Promise<boolean>
removeActiveOrg: () => void
}

export type WithNotLoggedInAuthInfoProps = {
Expand All @@ -32,6 +35,9 @@ export type WithNotLoggedInAuthInfoProps = {
refreshAuthInfo: () => Promise<void>
tokens: Tokens
accessTokenExpiresAtSeconds: null
activeOrg: undefined
setActiveOrg: undefined
removeActiveOrg: () => void
}

export type WithAuthInfoProps = WithLoggedInAuthInfoProps | WithNotLoggedInAuthInfoProps
Expand All @@ -52,7 +58,7 @@ export function withAuthInfo<P extends WithAuthInfoProps>(
throw new Error("withAuthInfo must be used within an AuthProvider or RequiredAuthProvider")
}

const { loading, authInfo, defaultDisplayWhileLoading, refreshAuthInfo, tokens } = context
const { loading, authInfo, defaultDisplayWhileLoading, refreshAuthInfo, tokens, activeOrg, setActiveOrg, removeActiveOrg } = context

function displayLoading() {
if (args?.displayWhileLoading) {
Expand Down Expand Up @@ -80,6 +86,9 @@ export function withAuthInfo<P extends WithAuthInfoProps>(
refreshAuthInfo,
tokens,
accessTokenExpiresAtSeconds: authInfo.expiresAtSeconds,
activeOrg,
setActiveOrg,
removeActiveOrg
}
return <Component {...loggedInProps} />
} else {
Expand All @@ -97,6 +106,9 @@ export function withAuthInfo<P extends WithAuthInfoProps>(
refreshAuthInfo,
tokens,
accessTokenExpiresAtSeconds: null,
activeOrg: undefined,
setActiveOrg: undefined,
removeActiveOrg
}
return <Component {...notLoggedInProps} />
}
Expand Down
5 changes: 4 additions & 1 deletion src/withRequiredAuthInfo.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ export function withRequiredAuthInfo<P extends WithLoggedInAuthInfoProps>(
throw new Error("withRequiredAuthInfo must be used within an AuthProvider or RequiredAuthProvider")
}

const { loading, authInfo, defaultDisplayIfLoggedOut, defaultDisplayWhileLoading, refreshAuthInfo, tokens } =
const { loading, authInfo, defaultDisplayIfLoggedOut, defaultDisplayWhileLoading, refreshAuthInfo, tokens, activeOrg, setActiveOrg, removeActiveOrg } =
context

function displayLoading() {
Expand Down Expand Up @@ -60,6 +60,9 @@ export function withRequiredAuthInfo<P extends WithLoggedInAuthInfoProps>(
refreshAuthInfo,
tokens,
accessTokenExpiresAtSeconds: authInfo.expiresAtSeconds,
activeOrg,
setActiveOrg,
removeActiveOrg,
}
return <Component {...loggedInProps} />
} else {
Expand Down