diff --git a/package.json b/package.json index cec739d..27ad889 100644 --- a/package.json +++ b/package.json @@ -5,7 +5,7 @@ "type": "git", "url": "https://github.com/PropelAuth/javascript" }, - "version": "2.0.23", + "version": "2.0.24", "keywords": [ "auth", "user", diff --git a/src/client.ts b/src/client.ts index b51d488..7bbab67 100644 --- a/src/client.ts +++ b/src/client.ts @@ -163,6 +163,11 @@ export interface IAuthClient { * Cleanup the auth client if you no longer need it. */ destroy(): void + + /** + * Returns the auth options with all default values filled in. + */ + getAuthOptions(): IResolvedAuthOptions } export interface IAuthOptions { @@ -195,12 +200,20 @@ export interface IAuthOptions { /** * If true, disables the token refresh on initial page load. * Can help reduce duplicate token refresh requests. - * + * * Default false */ skipInitialFetch?: boolean } +export interface IResolvedAuthOptions { + authUrl: string + enableBackgroundTokenRefresh: boolean + minSecondsBeforeRefresh: number + disableRefreshOnFocus: boolean + skipInitialFetch: boolean +} + interface AccessTokenActiveOrgMap { [orgId: string]: { accessToken: string @@ -218,6 +231,8 @@ interface ClientState { refreshInterval: number | null lastRefresh: number | null accessTokenActiveOrgMap: AccessTokenActiveOrgMap + pendingAuthRequest: Promise | null + pendingOrgAccessTokenRequests: Map> readonly authUrl: string } @@ -253,6 +268,8 @@ export function createClient(authOptions: IAuthOptions): IAuthClient { refreshInterval: null, lastRefresh: null, accessTokenActiveOrgMap: {}, + pendingAuthRequest: null, + pendingOrgAccessTokenRequests: new Map(), } // Helper functions @@ -331,23 +348,35 @@ export function createClient(authOptions: IAuthOptions): IAuthClient { } async function forceRefreshToken(returnCached: boolean): Promise { - try { - // Happy case, we fetch auth info and save it - const authenticationInfo = await runWithRetriesOnAnyError(() => - fetchAuthenticationInfo(clientState.authUrl) - ) - setAuthenticationInfoAndUpdateDownstream(authenticationInfo) - return authenticationInfo - } catch (e) { - // If there was an error, we sometimes still want to return the value we have cached - // (e.g. if we were prefetching), so in those cases we swallow the exception - if (returnCached) { - return clientState.authenticationInfo - } else { - setAuthenticationInfoAndUpdateDownstream(null) - throw e - } + // If there's already an in-flight request, return it to avoid duplicate fetches + if (clientState.pendingAuthRequest) { + return clientState.pendingAuthRequest } + + const request = (async () => { + try { + // Happy case, we fetch auth info and save it + const authenticationInfo = await runWithRetriesOnAnyError(() => + fetchAuthenticationInfo(clientState.authUrl) + ) + setAuthenticationInfoAndUpdateDownstream(authenticationInfo) + return authenticationInfo + } catch (e) { + // If there was an error, we sometimes still want to return the value we have cached + // (e.g. if we were prefetching), so in those cases we swallow the exception + if (returnCached) { + return clientState.authenticationInfo + } else { + setAuthenticationInfoAndUpdateDownstream(null) + throw e + } + } finally { + clientState.pendingAuthRequest = null + } + })() + + clientState.pendingAuthRequest = request + return request } const getSignupPageUrl = (options?: RedirectToSignupOptions) => { @@ -529,33 +558,47 @@ export function createClient(authOptions: IAuthOptions): IAuthClient { } } } - // Fetch the access token for the org ID and update. - try { - const authenticationInfo = await runWithRetriesOnAnyError(() => - fetchAuthenticationInfo(clientState.authUrl, orgId) - ) - if (!authenticationInfo) { - // Only null if 401 unauthorized. + + // Check for in-flight request for this org to avoid duplicate fetches + const pendingRequest = clientState.pendingOrgAccessTokenRequests.get(orgId) + if (pendingRequest) { + return pendingRequest + } + + // Create new request and store it + const request = (async (): Promise => { + try { + const authenticationInfo = await runWithRetriesOnAnyError(() => + fetchAuthenticationInfo(clientState.authUrl, orgId) + ) + if (!authenticationInfo) { + // Only null if 401 unauthorized. + return { + error: "user_not_in_org", + accessToken: null as never, + } + } + const { accessToken } = authenticationInfo + clientState.accessTokenActiveOrgMap[orgId] = { + accessToken, + fetchedAt: currentTimeSecs, + } + return { + accessToken, + error: undefined, + } + } catch (e) { return { - error: "user_not_in_org", + error: "unexpected_error", accessToken: null as never, } + } finally { + clientState.pendingOrgAccessTokenRequests.delete(orgId) } - const { accessToken } = authenticationInfo - clientState.accessTokenActiveOrgMap[orgId] = { - accessToken, - fetchedAt: currentTimeSecs, - } - return { - accessToken, - error: undefined, - } - } catch (e) { - return { - error: "unexpected_error", - accessToken: null as never, - } - } + })() + + clientState.pendingOrgAccessTokenRequests.set(orgId, request) + return request }, getSignupPageUrl(options?: RedirectToSignupOptions): string { @@ -626,6 +669,16 @@ export function createClient(authOptions: IAuthOptions): IAuthClient { clearInterval(clientState.refreshInterval) } }, + + getAuthOptions(): IResolvedAuthOptions { + return { + authUrl: clientState.authUrl, + enableBackgroundTokenRefresh: authOptions.enableBackgroundTokenRefresh!, + minSecondsBeforeRefresh: minSecondsBeforeRefresh, + disableRefreshOnFocus: authOptions.disableRefreshOnFocus ?? false, + skipInitialFetch: authOptions.skipInitialFetch ?? false, + } + }, } const onStorageChange = async function () { diff --git a/src/index.ts b/src/index.ts index 8a095ec..110b95b 100644 --- a/src/index.ts +++ b/src/index.ts @@ -5,6 +5,7 @@ export type { AccessTokenForActiveOrg, IAuthClient, IAuthOptions, + IResolvedAuthOptions, RedirectToAccountOptions, RedirectToCreateOrgOptions, RedirectToLoginOptions, diff --git a/src/tests/index.test.ts b/src/tests/index.test.ts index c98e584..4cd9981 100644 --- a/src/tests/index.test.ts +++ b/src/tests/index.test.ts @@ -1,10 +1,10 @@ /** * @jest-environment jsdom */ +import { DEFAULT_RETRIES } from "../fetch_retries" import { createClient } from "../index" import { OrgIdToOrgMemberInfo } from "../org" import { ok, ResponseStatus, setupMockFetch, UnauthorizedResponse, UnknownErrorResponse } from "./mockfetch.test" -import {DEFAULT_RETRIES} from "../fetch_retries"; const INITIAL_TIME_MILLIS = 1619743452595 const INITIAL_TIME_SECONDS = INITIAL_TIME_MILLIS / 1000 @@ -12,7 +12,7 @@ const INITIAL_TIME_SECONDS = INITIAL_TIME_MILLIS / 1000 beforeAll(() => { jest.useFakeTimers("modern") // @ts-ignore - global.setTimeout = jest.fn(cb => cb()); + global.setTimeout = jest.fn((cb) => cb()) }) beforeEach(() => { @@ -69,6 +69,47 @@ afterAll(() => { jest.useRealTimers() }) +test("getAuthOptions returns defaults when no options provided", () => { + let client = createClient({ authUrl: "https://www.example.com", enableBackgroundTokenRefresh: false }) + + const options = client.getAuthOptions() + + expect(options.authUrl).toBe("https://www.example.com") + expect(options.enableBackgroundTokenRefresh).toBe(false) + expect(options.minSecondsBeforeRefresh).toBe(120) + expect(options.disableRefreshOnFocus).toBe(false) + expect(options.skipInitialFetch).toBe(false) +}) + +test("getAuthOptions returns user-provided values", () => { + let client = createClient({ + authUrl: "https://www.example.com", + enableBackgroundTokenRefresh: false, + minSecondsBeforeRefresh: 300, + disableRefreshOnFocus: true, + skipInitialFetch: true, + }) + + const options = client.getAuthOptions() + + expect(options.authUrl).toBe("https://www.example.com") + expect(options.enableBackgroundTokenRefresh).toBe(false) + expect(options.minSecondsBeforeRefresh).toBe(300) + expect(options.disableRefreshOnFocus).toBe(true) + expect(options.skipInitialFetch).toBe(true) +}) + +test("getAuthOptions returns normalized authUrl", () => { + let client = createClient({ + authUrl: "https://www.example.com/path/to/something", + enableBackgroundTokenRefresh: false, + }) + + const options = client.getAuthOptions() + + expect(options.authUrl).toBe("https://www.example.com") +}) + test("cannot create client without auth url origin", () => { expect(() => { createClient({ authUrl: "" }) @@ -120,7 +161,7 @@ test("client parses org information correctly", async () => { user_role: "Owner", inherited_user_roles_plus_current_role: ["Owner", "Admin", "Member"], user_permissions: ["View", "Edit", "Delete", "ManageAccess"], - legacy_org_id: "ce126279-48a2-4fc4-a9e5-da62a33d1b11" + legacy_org_id: "ce126279-48a2-4fc4-a9e5-da62a33d1b11", }, "fcdb21f0-b1b6-426f-b83c-6cf4b903d737": { org_id: "fcdb21f0-b1b6-426f-b83c-6cf4b903d737", @@ -147,7 +188,7 @@ test("client parses org information correctly", async () => { userAssignedRole: "Owner", userInheritedRolesPlusCurrentRole: ["Owner", "Admin", "Member"], userPermissions: ["View", "Edit", "Delete", "ManageAccess"], - legacyOrgId: "ce126279-48a2-4fc4-a9e5-da62a33d1b11" + legacyOrgId: "ce126279-48a2-4fc4-a9e5-da62a33d1b11", }, "fcdb21f0-b1b6-426f-b83c-6cf4b903d737": { orgId: "fcdb21f0-b1b6-426f-b83c-6cf4b903d737", @@ -185,6 +226,105 @@ test("client returns null on a 401", async () => { expectCorrectEndpointWasHit(mockFetch, "https://www.example.com/api/v1/refresh_token") }) +test("after concurrent getAuthenticationInfoOrNull calls complete, a new call makes a new HTTP request", async () => { + const { mockFetch } = setupMockFetchThatReturnsAccessToken() + let client = createClient({ authUrl: "https://www.example.com", enableBackgroundTokenRefresh: false }) + + // Make concurrent calls and wait for them to complete + await Promise.all([ + client.getAuthenticationInfoOrNull(true), + client.getAuthenticationInfoOrNull(true), + client.getAuthenticationInfoOrNull(true), + ]) + + // First batch should have made 1 request + expect(mockFetch).toBeCalledTimes(1) + + // Now make another call - this should make a new HTTP request since the previous one completed + await client.getAuthenticationInfoOrNull(true) + + // Should now have 2 total requests + expect(mockFetch).toBeCalledTimes(2) +}) + +test("after concurrent getAccessTokenForOrg calls complete, a new call makes a new HTTP request", async () => { + const { mockFetch } = setupMockFetchThatReturnsAccessToken() + let client = createClient({ authUrl: "https://www.example.com", enableBackgroundTokenRefresh: false }) + + const orgId = "test-org-123" + + // Make concurrent calls and wait for them to complete + await Promise.all([ + client.getAccessTokenForOrg(orgId), + client.getAccessTokenForOrg(orgId), + client.getAccessTokenForOrg(orgId), + ]) + + // First batch should have made 1 request + expect(mockFetch).toBeCalledTimes(1) + + // Advance time past the cache expiration so a new request is needed + const newTime = INITIAL_TIME_MILLIS + ACTIVE_ORG_ACCESS_TOKEN_REFRESH_EXPIRATION_SECONDS * 1000 + 1000 + jest.setSystemTime(newTime) + + // Now make another call - this should make a new HTTP request + await client.getAccessTokenForOrg(orgId) + + // Should now have 2 total requests + expect(mockFetch).toBeCalledTimes(2) +}) + +// Constant needed for the test above +const ACTIVE_ORG_ACCESS_TOKEN_REFRESH_EXPIRATION_SECONDS = 60 * 5 + +test("concurrent calls to getAccessTokenForOrg make only one HTTP request per org", async () => { + const { expectedAccessToken, mockFetch } = setupMockFetchThatReturnsAccessToken() + let client = createClient({ authUrl: "https://www.example.com", enableBackgroundTokenRefresh: false }) + + const orgId = "test-org-123" + + // Make 3 concurrent calls for the same org - these should all share the same in-flight request + const promises = [ + client.getAccessTokenForOrg(orgId), + client.getAccessTokenForOrg(orgId), + client.getAccessTokenForOrg(orgId), + ] + const results = await Promise.all(promises) + + // All should return the same result + expect(results[0].accessToken).toBe(expectedAccessToken) + expect(results[1].accessToken).toBe(expectedAccessToken) + expect(results[2].accessToken).toBe(expectedAccessToken) + + // Only one HTTP request should have been made + expect(mockFetch).toBeCalledTimes(1) + expect(mockFetch).toHaveBeenCalledWith( + `https://www.example.com/api/v1/refresh_token?active_org_id=${orgId}`, + expect.objectContaining({ method: "GET" }) + ) +}) + +test("concurrent calls to getAuthenticationInfoOrNull make only one HTTP request", async () => { + const { expectedAccessToken, mockFetch } = setupMockFetchThatReturnsAccessToken() + let client = createClient({ authUrl: "https://www.example.com", enableBackgroundTokenRefresh: false }) + + // Make 3 concurrent calls - these should all share the same in-flight request + const promises = [ + client.getAuthenticationInfoOrNull(), + client.getAuthenticationInfoOrNull(), + client.getAuthenticationInfoOrNull(), + ] + const results = await Promise.all(promises) + + // All should return the same result + expect(results[0]?.accessToken).toBe(expectedAccessToken) + expect(results[1]?.accessToken).toBe(expectedAccessToken) + expect(results[2]?.accessToken).toBe(expectedAccessToken) + + // Only one HTTP request should have been made + expectCorrectEndpointWasHit(mockFetch, "https://www.example.com/api/v1/refresh_token", 1) +}) + test("repeated calls to getAuthenticationInfo do NOT make multiple http requests if the expiration is far in the future", async () => { const { expectedAccessToken, mockFetch } = setupMockFetchThatReturnsAccessToken() let client = createClient({ authUrl: "https://www.example.com", enableBackgroundTokenRefresh: false }) @@ -421,8 +561,8 @@ test("if a new client is created and cannot get an access token, it should trigg const post401AuthenticationInfo0 = await client0.getAuthenticationInfoOrNull() expect(post401AuthenticationInfo0).toBeNull() - // Called 3 times because client0 ends up making 2 requests, 1 when client1 triggers a logout event and 1 when asked - expectCorrectEndpointWasHit(logoutMockFetch, "https://www.example.com/api/v1/refresh_token", 3) + // Called 2 times: 1 from client1, 1 from client0 (the storage-triggered request is reused by the explicit call) + expectCorrectEndpointWasHit(logoutMockFetch, "https://www.example.com/api/v1/refresh_token", 2) }) function expectCorrectEndpointWasHit(mockFetch: any, correctRefreshUrl: string, numSendTimes = 1, method = "get") {