diff --git a/src/AuthContext.tsx b/src/AuthContext.tsx index 0a8ef77..f2c0e9c 100644 --- a/src/AuthContext.tsx +++ b/src/AuthContext.tsx @@ -7,8 +7,9 @@ import { RedirectToOrgPageOptions, RedirectToSetupSAMLPageOptions, RedirectToSignupOptions, + OrgMemberInfoClass } from "@propelauth/javascript" -import React, { useCallback, useEffect, useReducer } from "react" +import React, { useCallback, useEffect, useReducer, useState, useMemo } from "react" import { loadOrgSelectionFromLocalStorage } from "./hooks/useActiveOrg" import { useClientRef, useClientRefCallback } from "./useClientRef" @@ -41,6 +42,11 @@ export interface InternalAuthState { authUrl: string tokens: Tokens + + activeOrg: OrgMemberInfoClass | undefined + setActiveOrg: (orgId: string) => Promise + removeActiveOrg: () => void + refreshAuthInfo: () => Promise defaultDisplayWhileLoading?: React.ReactElement defaultDisplayIfLoggedOut?: React.ReactElement @@ -56,6 +62,7 @@ export type AuthProviderProps = { getActiveOrgFn?: () => string | null children?: React.ReactNode minSecondsBeforeRefresh?: number + useLocalStorageForActiveOrg?: boolean } export interface RequiredAuthProviderProps @@ -101,6 +108,60 @@ function authInfoStateReducer(_state: AuthInfoState, action: AuthInfoStateAction } } +const ACTIVE_ORG_KEY = 'activeOrgId'; + +const getStoredActiveOrgId = (): string | null => { + try { + return localStorage.getItem(ACTIVE_ORG_KEY); + } catch (error) { + console.warn('Failed to read from localStorage:', error); + return null; + } +}; + +const setStoredActiveOrgId = (orgId: string): void => { + try { + localStorage.setItem(ACTIVE_ORG_KEY, orgId); + } catch (error) { + console.warn('Failed to write to localStorage:', error); + } +}; + +const removeStoredActiveOrgId = (): void => { + try { + localStorage.removeItem(ACTIVE_ORG_KEY); + } catch (error) { + console.warn('Failed to remove from localStorage:', error); + } +}; + +const useLocalStorageSync = (key: string): string | null => { + const [value, setValue] = useState(() => { + try { + return localStorage.getItem(key); + } catch (error) { + console.warn('Failed to read from localStorage:', error); + return null; + } + }); + + useEffect(() => { + const handleStorageChange = (e: StorageEvent) => { + if (e.key === key) { + setValue(e.newValue); + } + }; + + window.addEventListener('storage', handleStorageChange); + + return () => { + window.removeEventListener('storage', handleStorageChange); + }; + }, [key]); + + return value; +}; + export const AuthProvider = (props: AuthProviderProps) => { const { authUrl, @@ -109,8 +170,11 @@ export const AuthProvider = (props: AuthProviderProps) => { children, defaultDisplayWhileLoading, defaultDisplayIfLoggedOut, + useLocalStorageForActiveOrg } = props + const storedActiveOrgId = useLocalStorageSync(ACTIVE_ORG_KEY); const [authInfoState, dispatch] = useReducer(authInfoStateReducer, initialAuthInfoState) + const [activeOrg, setActiveOrgState] = useState(); const { clientRef, accessTokenChangeCounter } = useClientRef({ authUrl, minSecondsBeforeRefresh, @@ -142,6 +206,13 @@ export const AuthProvider = (props: AuthProviderProps) => { } }, [accessTokenChangeCounter]) + // Re-render when stored active org is updated + useEffect(() => { + if (storedActiveOrgId && useLocalStorageForActiveOrg) { + setActiveOrg(storedActiveOrgId) + } + }, [storedActiveOrgId]) + // Deprecation warning useEffect(() => { if (deprecatedGetActiveOrgFn) { @@ -149,6 +220,7 @@ export const AuthProvider = (props: AuthProviderProps) => { } }, []) + const logout = useClientRefCallback(clientRef, (client) => client.logout) const redirectToLoginPage = useClientRefCallback(clientRef, (client) => client.redirectToLoginPage) const redirectToSignupPage = useClientRefCallback(clientRef, (client) => client.redirectToSignupPage) @@ -182,6 +254,56 @@ export const AuthProvider = (props: AuthProviderProps) => { dispatch({ authInfo }) }, [dispatch]) + + const setActiveOrg = async (orgId: string) => { + const userClass = authInfoState?.authInfo?.userClass + if (!userClass) { + return false + } + const org = userClass.getOrg(orgId) + + if (org) { + if (useLocalStorageForActiveOrg) { + setStoredActiveOrgId(orgId); + } + setActiveOrgState(org); + return true + } else { + if (useLocalStorageForActiveOrg) { + removeStoredActiveOrgId(); + } + setActiveOrgState(undefined); + return false + } + }; + + const getActiveOrg = useMemo(() => { + const userClass = authInfoState?.authInfo?.userClass + if (!userClass) { + return undefined + } + + if (!activeOrg && useLocalStorageForActiveOrg) { + const activeOrgIdFromLocalStorage = getStoredActiveOrgId() + if (activeOrgIdFromLocalStorage) { + return userClass.getOrg(activeOrgIdFromLocalStorage) + } + } + if (activeOrg) { + return userClass.getOrg(activeOrg.orgId) + } + return undefined + }, [activeOrg, authInfoState?.authInfo?.userClass]) + + + function removeActiveOrg() { + if (useLocalStorageForActiveOrg) { + removeStoredActiveOrgId(); + } + setActiveOrgState(undefined); + } + + // TODO: Remove this, as both `getActiveOrgFn` and `loadOrgSelectionFromLocalStorage` are deprecated. const deprecatedActiveOrgFn = deprecatedGetActiveOrgFn || loadOrgSelectionFromLocalStorage @@ -206,6 +328,9 @@ export const AuthProvider = (props: AuthProviderProps) => { getSetupSAMLPageUrl, authUrl, refreshAuthInfo, + activeOrg: getActiveOrg, + setActiveOrg, + removeActiveOrg, tokens: { getAccessTokenForOrg, getAccessToken, diff --git a/src/AuthContextForTesting.tsx b/src/AuthContextForTesting.tsx index 980e132..8218088 100644 --- a/src/AuthContextForTesting.tsx +++ b/src/AuthContextForTesting.tsx @@ -79,6 +79,9 @@ export const AuthProviderForTesting = ({ getAccessTokenForOrg: getAccessTokenForOrg, getAccessToken: () => Promise.resolve(userInformation?.accessToken ?? "ACCESS_TOKEN"), }, + activeOrg: undefined, + setActiveOrg: () => Promise.resolve(false), + removeActiveOrg: () => "" } return {children} diff --git a/src/hooks/useAuthInfo.ts b/src/hooks/useAuthInfo.ts index 576f4de..aa8458e 100644 --- a/src/hooks/useAuthInfo.ts +++ b/src/hooks/useAuthInfo.ts @@ -1,4 +1,4 @@ -import { AccessHelper, OrgHelper, User, UserClass } from "@propelauth/javascript" +import { AccessHelper, OrgHelper, OrgMemberInfoClass, User, UserClass } from "@propelauth/javascript" import { useContext } from "react" import { AuthContext, Tokens } from "../AuthContext" @@ -15,6 +15,9 @@ export type UseAuthInfoLoading = { refreshAuthInfo: () => Promise tokens: Tokens accessTokenExpiresAtSeconds: undefined + activeOrg: undefined + setActiveOrg: undefined + removeActiveOrg: undefined } export type UseAuthInfoLoggedInProps = { @@ -30,6 +33,9 @@ export type UseAuthInfoLoggedInProps = { refreshAuthInfo: () => Promise tokens: Tokens accessTokenExpiresAtSeconds: number + activeOrg: OrgMemberInfoClass | undefined + setActiveOrg: (orgId: string) => Promise + removeActiveOrg: () => void } export type UseAuthInfoNotLoggedInProps = { @@ -45,6 +51,9 @@ export type UseAuthInfoNotLoggedInProps = { refreshAuthInfo: () => Promise tokens: Tokens accessTokenExpiresAtSeconds: undefined + activeOrg: undefined + setActiveOrg: undefined + removeActiveOrg: () => void } export type UseAuthInfoProps = UseAuthInfoLoading | UseAuthInfoLoggedInProps | UseAuthInfoNotLoggedInProps @@ -55,7 +64,7 @@ export function useAuthInfo(): UseAuthInfoProps { throw new Error("useAuthInfo must be used within an AuthProvider or RequiredAuthProvider") } - const { loading, authInfo, refreshAuthInfo, tokens } = context + const { loading, authInfo, refreshAuthInfo, tokens, activeOrg, setActiveOrg, removeActiveOrg } = context if (loading) { return { loading: true, @@ -70,6 +79,9 @@ export function useAuthInfo(): UseAuthInfoProps { refreshAuthInfo, tokens, accessTokenExpiresAtSeconds: undefined, + activeOrg: undefined, + setActiveOrg: undefined, + removeActiveOrg: undefined } } else if (authInfo && authInfo.accessToken) { return { @@ -85,6 +97,9 @@ export function useAuthInfo(): UseAuthInfoProps { refreshAuthInfo, tokens, accessTokenExpiresAtSeconds: authInfo.expiresAtSeconds, + activeOrg, + setActiveOrg, + removeActiveOrg, } } return { @@ -100,5 +115,8 @@ export function useAuthInfo(): UseAuthInfoProps { refreshAuthInfo, tokens, accessTokenExpiresAtSeconds: undefined, + activeOrg: undefined, + setActiveOrg: undefined, + removeActiveOrg } } diff --git a/src/withAuthInfo.tsx b/src/withAuthInfo.tsx index 7585057..6b54ce9 100644 --- a/src/withAuthInfo.tsx +++ b/src/withAuthInfo.tsx @@ -1,4 +1,4 @@ -import { AccessHelper, OrgHelper, User, UserClass } from "@propelauth/javascript" +import { AccessHelper, OrgHelper, User, UserClass, OrgMemberInfoClass } from "@propelauth/javascript" import hoistNonReactStatics from "hoist-non-react-statics" import React, { useContext } from "react" import { Subtract } from "utility-types" @@ -17,6 +17,9 @@ export type WithLoggedInAuthInfoProps = { refreshAuthInfo: () => Promise tokens: Tokens accessTokenExpiresAtSeconds: number + activeOrg: OrgMemberInfoClass | undefined + setActiveOrg: (orgId: string) => Promise + removeActiveOrg: () => void } export type WithNotLoggedInAuthInfoProps = { @@ -32,6 +35,9 @@ export type WithNotLoggedInAuthInfoProps = { refreshAuthInfo: () => Promise tokens: Tokens accessTokenExpiresAtSeconds: null + activeOrg: undefined + setActiveOrg: undefined + removeActiveOrg: () => void } export type WithAuthInfoProps = WithLoggedInAuthInfoProps | WithNotLoggedInAuthInfoProps @@ -52,7 +58,7 @@ export function withAuthInfo

( throw new Error("withAuthInfo must be used within an AuthProvider or RequiredAuthProvider") } - const { loading, authInfo, defaultDisplayWhileLoading, refreshAuthInfo, tokens } = context + const { loading, authInfo, defaultDisplayWhileLoading, refreshAuthInfo, tokens, activeOrg, setActiveOrg, removeActiveOrg } = context function displayLoading() { if (args?.displayWhileLoading) { @@ -80,6 +86,9 @@ export function withAuthInfo

( refreshAuthInfo, tokens, accessTokenExpiresAtSeconds: authInfo.expiresAtSeconds, + activeOrg, + setActiveOrg, + removeActiveOrg } return } else { @@ -97,6 +106,9 @@ export function withAuthInfo

( refreshAuthInfo, tokens, accessTokenExpiresAtSeconds: null, + activeOrg: undefined, + setActiveOrg: undefined, + removeActiveOrg } return } diff --git a/src/withRequiredAuthInfo.tsx b/src/withRequiredAuthInfo.tsx index eba4f63..1c6af25 100644 --- a/src/withRequiredAuthInfo.tsx +++ b/src/withRequiredAuthInfo.tsx @@ -22,7 +22,7 @@ export function withRequiredAuthInfo

( throw new Error("withRequiredAuthInfo must be used within an AuthProvider or RequiredAuthProvider") } - const { loading, authInfo, defaultDisplayIfLoggedOut, defaultDisplayWhileLoading, refreshAuthInfo, tokens } = + const { loading, authInfo, defaultDisplayIfLoggedOut, defaultDisplayWhileLoading, refreshAuthInfo, tokens, activeOrg, setActiveOrg, removeActiveOrg } = context function displayLoading() { @@ -60,6 +60,9 @@ export function withRequiredAuthInfo

( refreshAuthInfo, tokens, accessTokenExpiresAtSeconds: authInfo.expiresAtSeconds, + activeOrg, + setActiveOrg, + removeActiveOrg, } return } else {