Skip to content

Commit ca167d2

Browse files
Deduplicate in-flight requests for refreshes and org-scoped refreshes
1 parent 62fce6f commit ca167d2

File tree

2 files changed

+178
-46
lines changed

2 files changed

+178
-46
lines changed

src/client.ts

Lines changed: 69 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,8 @@ interface ClientState {
231231
refreshInterval: number | null
232232
lastRefresh: number | null
233233
accessTokenActiveOrgMap: AccessTokenActiveOrgMap
234+
pendingAuthRequest: Promise<AuthenticationInfo | null> | null
235+
pendingOrgAccessTokenRequests: Map<string, Promise<AccessTokenForActiveOrg>>
234236
readonly authUrl: string
235237
}
236238

@@ -266,6 +268,8 @@ export function createClient(authOptions: IAuthOptions): IAuthClient {
266268
refreshInterval: null,
267269
lastRefresh: null,
268270
accessTokenActiveOrgMap: {},
271+
pendingAuthRequest: null,
272+
pendingOrgAccessTokenRequests: new Map(),
269273
}
270274

271275
// Helper functions
@@ -344,23 +348,35 @@ export function createClient(authOptions: IAuthOptions): IAuthClient {
344348
}
345349

346350
async function forceRefreshToken(returnCached: boolean): Promise<AuthenticationInfo | null> {
347-
try {
348-
// Happy case, we fetch auth info and save it
349-
const authenticationInfo = await runWithRetriesOnAnyError(() =>
350-
fetchAuthenticationInfo(clientState.authUrl)
351-
)
352-
setAuthenticationInfoAndUpdateDownstream(authenticationInfo)
353-
return authenticationInfo
354-
} catch (e) {
355-
// If there was an error, we sometimes still want to return the value we have cached
356-
// (e.g. if we were prefetching), so in those cases we swallow the exception
357-
if (returnCached) {
358-
return clientState.authenticationInfo
359-
} else {
360-
setAuthenticationInfoAndUpdateDownstream(null)
361-
throw e
362-
}
351+
// If there's already an in-flight request, return it to avoid duplicate fetches
352+
if (clientState.pendingAuthRequest) {
353+
return clientState.pendingAuthRequest
363354
}
355+
356+
const request = (async () => {
357+
try {
358+
// Happy case, we fetch auth info and save it
359+
const authenticationInfo = await runWithRetriesOnAnyError(() =>
360+
fetchAuthenticationInfo(clientState.authUrl)
361+
)
362+
setAuthenticationInfoAndUpdateDownstream(authenticationInfo)
363+
return authenticationInfo
364+
} catch (e) {
365+
// If there was an error, we sometimes still want to return the value we have cached
366+
// (e.g. if we were prefetching), so in those cases we swallow the exception
367+
if (returnCached) {
368+
return clientState.authenticationInfo
369+
} else {
370+
setAuthenticationInfoAndUpdateDownstream(null)
371+
throw e
372+
}
373+
} finally {
374+
clientState.pendingAuthRequest = null
375+
}
376+
})()
377+
378+
clientState.pendingAuthRequest = request
379+
return request
364380
}
365381

366382
const getSignupPageUrl = (options?: RedirectToSignupOptions) => {
@@ -542,33 +558,47 @@ export function createClient(authOptions: IAuthOptions): IAuthClient {
542558
}
543559
}
544560
}
545-
// Fetch the access token for the org ID and update.
546-
try {
547-
const authenticationInfo = await runWithRetriesOnAnyError(() =>
548-
fetchAuthenticationInfo(clientState.authUrl, orgId)
549-
)
550-
if (!authenticationInfo) {
551-
// Only null if 401 unauthorized.
561+
562+
// Check for in-flight request for this org to avoid duplicate fetches
563+
const pendingRequest = clientState.pendingOrgAccessTokenRequests.get(orgId)
564+
if (pendingRequest) {
565+
return pendingRequest
566+
}
567+
568+
// Create new request and store it
569+
const request = (async (): Promise<AccessTokenForActiveOrg> => {
570+
try {
571+
const authenticationInfo = await runWithRetriesOnAnyError(() =>
572+
fetchAuthenticationInfo(clientState.authUrl, orgId)
573+
)
574+
if (!authenticationInfo) {
575+
// Only null if 401 unauthorized.
576+
return {
577+
error: "user_not_in_org",
578+
accessToken: null as never,
579+
}
580+
}
581+
const { accessToken } = authenticationInfo
582+
clientState.accessTokenActiveOrgMap[orgId] = {
583+
accessToken,
584+
fetchedAt: currentTimeSecs,
585+
}
586+
return {
587+
accessToken,
588+
error: undefined,
589+
}
590+
} catch (e) {
552591
return {
553-
error: "user_not_in_org",
592+
error: "unexpected_error",
554593
accessToken: null as never,
555594
}
595+
} finally {
596+
clientState.pendingOrgAccessTokenRequests.delete(orgId)
556597
}
557-
const { accessToken } = authenticationInfo
558-
clientState.accessTokenActiveOrgMap[orgId] = {
559-
accessToken,
560-
fetchedAt: currentTimeSecs,
561-
}
562-
return {
563-
accessToken,
564-
error: undefined,
565-
}
566-
} catch (e) {
567-
return {
568-
error: "unexpected_error",
569-
accessToken: null as never,
570-
}
571-
}
598+
})()
599+
600+
clientState.pendingOrgAccessTokenRequests.set(orgId, request)
601+
return request
572602
},
573603

574604
getSignupPageUrl(options?: RedirectToSignupOptions): string {

src/tests/index.test.ts

Lines changed: 109 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,18 @@
11
/**
22
* @jest-environment jsdom
33
*/
4+
import { DEFAULT_RETRIES } from "../fetch_retries"
45
import { createClient } from "../index"
56
import { OrgIdToOrgMemberInfo } from "../org"
67
import { ok, ResponseStatus, setupMockFetch, UnauthorizedResponse, UnknownErrorResponse } from "./mockfetch.test"
7-
import {DEFAULT_RETRIES} from "../fetch_retries";
88

99
const INITIAL_TIME_MILLIS = 1619743452595
1010
const INITIAL_TIME_SECONDS = INITIAL_TIME_MILLIS / 1000
1111

1212
beforeAll(() => {
1313
jest.useFakeTimers("modern")
1414
// @ts-ignore
15-
global.setTimeout = jest.fn(cb => cb());
15+
global.setTimeout = jest.fn((cb) => cb())
1616
})
1717

1818
beforeEach(() => {
@@ -100,7 +100,10 @@ test("getAuthOptions returns user-provided values", () => {
100100
})
101101

102102
test("getAuthOptions returns normalized authUrl", () => {
103-
let client = createClient({ authUrl: "https://www.example.com/path/to/something", enableBackgroundTokenRefresh: false })
103+
let client = createClient({
104+
authUrl: "https://www.example.com/path/to/something",
105+
enableBackgroundTokenRefresh: false,
106+
})
104107

105108
const options = client.getAuthOptions()
106109

@@ -158,7 +161,7 @@ test("client parses org information correctly", async () => {
158161
user_role: "Owner",
159162
inherited_user_roles_plus_current_role: ["Owner", "Admin", "Member"],
160163
user_permissions: ["View", "Edit", "Delete", "ManageAccess"],
161-
legacy_org_id: "ce126279-48a2-4fc4-a9e5-da62a33d1b11"
164+
legacy_org_id: "ce126279-48a2-4fc4-a9e5-da62a33d1b11",
162165
},
163166
"fcdb21f0-b1b6-426f-b83c-6cf4b903d737": {
164167
org_id: "fcdb21f0-b1b6-426f-b83c-6cf4b903d737",
@@ -185,7 +188,7 @@ test("client parses org information correctly", async () => {
185188
userAssignedRole: "Owner",
186189
userInheritedRolesPlusCurrentRole: ["Owner", "Admin", "Member"],
187190
userPermissions: ["View", "Edit", "Delete", "ManageAccess"],
188-
legacyOrgId: "ce126279-48a2-4fc4-a9e5-da62a33d1b11"
191+
legacyOrgId: "ce126279-48a2-4fc4-a9e5-da62a33d1b11",
189192
},
190193
"fcdb21f0-b1b6-426f-b83c-6cf4b903d737": {
191194
orgId: "fcdb21f0-b1b6-426f-b83c-6cf4b903d737",
@@ -223,6 +226,105 @@ test("client returns null on a 401", async () => {
223226
expectCorrectEndpointWasHit(mockFetch, "https://www.example.com/api/v1/refresh_token")
224227
})
225228

229+
test("after concurrent getAuthenticationInfoOrNull calls complete, a new call makes a new HTTP request", async () => {
230+
const { mockFetch } = setupMockFetchThatReturnsAccessToken()
231+
let client = createClient({ authUrl: "https://www.example.com", enableBackgroundTokenRefresh: false })
232+
233+
// Make concurrent calls and wait for them to complete
234+
await Promise.all([
235+
client.getAuthenticationInfoOrNull(true),
236+
client.getAuthenticationInfoOrNull(true),
237+
client.getAuthenticationInfoOrNull(true),
238+
])
239+
240+
// First batch should have made 1 request
241+
expect(mockFetch).toBeCalledTimes(1)
242+
243+
// Now make another call - this should make a new HTTP request since the previous one completed
244+
await client.getAuthenticationInfoOrNull(true)
245+
246+
// Should now have 2 total requests
247+
expect(mockFetch).toBeCalledTimes(2)
248+
})
249+
250+
test("after concurrent getAccessTokenForOrg calls complete, a new call makes a new HTTP request", async () => {
251+
const { mockFetch } = setupMockFetchThatReturnsAccessToken()
252+
let client = createClient({ authUrl: "https://www.example.com", enableBackgroundTokenRefresh: false })
253+
254+
const orgId = "test-org-123"
255+
256+
// Make concurrent calls and wait for them to complete
257+
await Promise.all([
258+
client.getAccessTokenForOrg(orgId),
259+
client.getAccessTokenForOrg(orgId),
260+
client.getAccessTokenForOrg(orgId),
261+
])
262+
263+
// First batch should have made 1 request
264+
expect(mockFetch).toBeCalledTimes(1)
265+
266+
// Advance time past the cache expiration so a new request is needed
267+
const newTime = INITIAL_TIME_MILLIS + ACTIVE_ORG_ACCESS_TOKEN_REFRESH_EXPIRATION_SECONDS * 1000 + 1000
268+
jest.setSystemTime(newTime)
269+
270+
// Now make another call - this should make a new HTTP request
271+
await client.getAccessTokenForOrg(orgId)
272+
273+
// Should now have 2 total requests
274+
expect(mockFetch).toBeCalledTimes(2)
275+
})
276+
277+
// Constant needed for the test above
278+
const ACTIVE_ORG_ACCESS_TOKEN_REFRESH_EXPIRATION_SECONDS = 60 * 5
279+
280+
test("concurrent calls to getAccessTokenForOrg make only one HTTP request per org", async () => {
281+
const { expectedAccessToken, mockFetch } = setupMockFetchThatReturnsAccessToken()
282+
let client = createClient({ authUrl: "https://www.example.com", enableBackgroundTokenRefresh: false })
283+
284+
const orgId = "test-org-123"
285+
286+
// Make 3 concurrent calls for the same org - these should all share the same in-flight request
287+
const promises = [
288+
client.getAccessTokenForOrg(orgId),
289+
client.getAccessTokenForOrg(orgId),
290+
client.getAccessTokenForOrg(orgId),
291+
]
292+
const results = await Promise.all(promises)
293+
294+
// All should return the same result
295+
expect(results[0].accessToken).toBe(expectedAccessToken)
296+
expect(results[1].accessToken).toBe(expectedAccessToken)
297+
expect(results[2].accessToken).toBe(expectedAccessToken)
298+
299+
// Only one HTTP request should have been made
300+
expect(mockFetch).toBeCalledTimes(1)
301+
expect(mockFetch).toHaveBeenCalledWith(
302+
`https://www.example.com/api/v1/refresh_token?active_org_id=${orgId}`,
303+
expect.objectContaining({ method: "GET" })
304+
)
305+
})
306+
307+
test("concurrent calls to getAuthenticationInfoOrNull make only one HTTP request", async () => {
308+
const { expectedAccessToken, mockFetch } = setupMockFetchThatReturnsAccessToken()
309+
let client = createClient({ authUrl: "https://www.example.com", enableBackgroundTokenRefresh: false })
310+
311+
// Make 3 concurrent calls - these should all share the same in-flight request
312+
const promises = [
313+
client.getAuthenticationInfoOrNull(),
314+
client.getAuthenticationInfoOrNull(),
315+
client.getAuthenticationInfoOrNull(),
316+
]
317+
const results = await Promise.all(promises)
318+
319+
// All should return the same result
320+
expect(results[0]?.accessToken).toBe(expectedAccessToken)
321+
expect(results[1]?.accessToken).toBe(expectedAccessToken)
322+
expect(results[2]?.accessToken).toBe(expectedAccessToken)
323+
324+
// Only one HTTP request should have been made
325+
expectCorrectEndpointWasHit(mockFetch, "https://www.example.com/api/v1/refresh_token", 1)
326+
})
327+
226328
test("repeated calls to getAuthenticationInfo do NOT make multiple http requests if the expiration is far in the future", async () => {
227329
const { expectedAccessToken, mockFetch } = setupMockFetchThatReturnsAccessToken()
228330
let client = createClient({ authUrl: "https://www.example.com", enableBackgroundTokenRefresh: false })
@@ -459,8 +561,8 @@ test("if a new client is created and cannot get an access token, it should trigg
459561
const post401AuthenticationInfo0 = await client0.getAuthenticationInfoOrNull()
460562
expect(post401AuthenticationInfo0).toBeNull()
461563

462-
// Called 3 times because client0 ends up making 2 requests, 1 when client1 triggers a logout event and 1 when asked
463-
expectCorrectEndpointWasHit(logoutMockFetch, "https://www.example.com/api/v1/refresh_token", 3)
564+
// Called 2 times: 1 from client1, 1 from client0 (the storage-triggered request is reused by the explicit call)
565+
expectCorrectEndpointWasHit(logoutMockFetch, "https://www.example.com/api/v1/refresh_token", 2)
464566
})
465567

466568
function expectCorrectEndpointWasHit(mockFetch: any, correctRefreshUrl: string, numSendTimes = 1, method = "get") {

0 commit comments

Comments
 (0)