Skip to content

Commit 8dd8e89

Browse files
authored
Add a mechanism to prevent concurrent token refreshes (#15493)
1 parent 0c064ae commit 8dd8e89

File tree

3 files changed

+329
-1
lines changed

3 files changed

+329
-1
lines changed

FirebaseAuth/Sources/Swift/SystemService/SecureTokenService.swift

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ private let kFiveMinutes = 5 * 60.0
1919

2020
@available(iOS 13, tvOS 13, macOS 10.15, macCatalyst 13, watchOS 7, *)
2121
actor SecureTokenServiceInternal {
22+
/// Coalescer to deduplicate concurrent token refresh requests.
23+
/// When multiple requests arrive at the same time, only one network call is made.
24+
private let refreshCoalescer = TokenRefreshCoalescer()
25+
2226
/// Fetch a fresh ephemeral access token for the ID associated with this instance. The token
2327
/// received in the callback should be considered short lived and not cached.
2428
///
@@ -32,7 +36,20 @@ actor SecureTokenServiceInternal {
3236
return (service.accessToken, false)
3337
} else {
3438
AuthLog.logDebug(code: "I-AUT000017", message: "Fetching new token from backend.")
35-
return try await requestAccessToken(retryIfExpired: true, service: service, backend: backend)
39+
40+
// Use coalescer to deduplicate concurrent refresh requests.
41+
// If multiple requests arrive while one is in progress, they all wait
42+
// for the same network response instead of making redundant calls.
43+
let currentToken = service.accessToken
44+
return try await refreshCoalescer.coalescedRefresh(
45+
currentToken: currentToken
46+
) {
47+
try await self.requestAccessToken(
48+
retryIfExpired: true,
49+
service: service,
50+
backend: backend
51+
)
52+
}
3653
}
3754
}
3855

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
// Copyright 2025 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import Foundation
16+
17+
/// Coalesces multiple concurrent token refresh requests into a single network call.
18+
///
19+
/// When multiple requests for a token refresh arrive concurrently (e.g., from Storage, Firestore,
20+
/// and auto-refresh), instead of making separate network calls for each one, this class ensures
21+
/// that only ONE network request is made. All concurrent callers wait for and receive the same
22+
/// refreshed token.
23+
///
24+
/// This prevents redundant STS (Secure Token Service) calls and reduces load on both the client
25+
/// and server.
26+
///
27+
/// Example:
28+
/// ```
29+
/// // Multiple concurrent requests arrive at the same time
30+
/// Task { try await tokenRefreshCoalescer.coalescedRefresh(currentToken: token, ...) } // 1
31+
/// Task { try await tokenRefreshCoalescer.coalescedRefresh(currentToken: token, ...) } // 2
32+
/// Task { try await tokenRefreshCoalescer.coalescedRefresh(currentToken: token, ...) } // 3
33+
///
34+
/// // Only ONE network call is made. All three tasks receive the same refreshed token.
35+
/// ```
36+
@available(iOS 13, tvOS 13, macOS 10.15, macCatalyst 13, watchOS 7, *)
37+
actor TokenRefreshCoalescer {
38+
/// The in-flight token refresh task, if any.
39+
/// When this is set, all concurrent calls wait for this task instead of starting their own.
40+
private var pendingRefreshTask: Task<(String?, Bool), Error>?
41+
42+
/// The token string of the pending refresh.
43+
/// Used to ensure we only coalesce requests for the same token.
44+
private var pendingRefreshToken: String?
45+
46+
/// Performs a coalesced token refresh.
47+
///
48+
/// If a refresh is already in progress, this method waits for that refresh to complete
49+
/// and returns its result. If no refresh is in progress, it starts a new one and stores
50+
/// the task so other concurrent callers can wait for it.
51+
///
52+
/// - Parameters:
53+
/// - currentToken: The current token string. Used to detect token changes.
54+
/// If the current token differs from the pending refresh token,
55+
/// a new refresh is started (old one is ignored).
56+
/// - refreshFunction: A closure that performs the actual network request and refresh.
57+
/// Should be called only if a new refresh is needed.
58+
///
59+
/// - Returns: A tuple containing (refreshedToken, wasUpdated) matching the format
60+
/// of SecureTokenService.
61+
///
62+
/// - Throws: Any error from the refresh operation.
63+
func coalescedRefresh(currentToken: String,
64+
refreshFunction: @escaping () async throws -> (String?, Bool)) async throws
65+
-> (
66+
String?,
67+
Bool
68+
) {
69+
// Check if a refresh is already in progress for this token
70+
if let pendingTask = pendingRefreshTask,
71+
pendingRefreshToken == currentToken {
72+
// Token hasn't changed and a refresh is in progress
73+
// Wait for the pending refresh to complete
74+
return try await pendingTask.value
75+
}
76+
77+
// Either no refresh is in progress, or the token has changed.
78+
// Start a new refresh task.
79+
let task = Task {
80+
try await refreshFunction()
81+
}
82+
83+
// Store the task so other concurrent callers can wait for it
84+
pendingRefreshTask = task
85+
pendingRefreshToken = currentToken
86+
87+
defer {
88+
// Clean up the pending task after it completes
89+
pendingRefreshTask = nil
90+
pendingRefreshToken = nil
91+
}
92+
93+
do {
94+
return try await task.value
95+
} catch {
96+
// On error, clear the pending task so the next call will retry
97+
pendingRefreshTask = nil
98+
pendingRefreshToken = nil
99+
throw error
100+
}
101+
}
102+
}
Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
// Copyright 2025 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
@testable import FirebaseAuth
16+
import XCTest
17+
18+
actor Counter {
19+
private var valueInternal: Int = 0
20+
func increment() { valueInternal += 1 }
21+
func value() -> Int { valueInternal }
22+
}
23+
24+
@available(iOS 13, tvOS 13, macOS 10.15, macCatalyst 13, watchOS 7, *)
25+
class TokenRefreshCoalescerTests: XCTestCase {
26+
/// Tests that when multiple concurrent refresh requests arrive for the same token,
27+
/// only ONE network call is made.
28+
///
29+
/// This is the main issue fix: Previously, each concurrent caller would make its own
30+
/// network request, resulting in redundant STS calls.
31+
func testCoalescedRefreshMakesOnlyOneNetworkCall() async throws {
32+
let coalescer = TokenRefreshCoalescer()
33+
let counter = Counter()
34+
35+
// Simulate multiple concurrent refresh requests
36+
async let result1 = try coalescer.coalescedRefresh(currentToken: "token_v1") {
37+
await counter.increment()
38+
39+
// Simulate network delay
40+
try await Task.sleep(nanoseconds: 100_000_000) // 0.1 seconds
41+
42+
return ("new_token", true)
43+
}
44+
45+
async let result2 = try coalescer.coalescedRefresh(currentToken: "token_v1") {
46+
await counter.increment()
47+
48+
try await Task.sleep(nanoseconds: 100_000_000)
49+
return ("new_token", true)
50+
}
51+
52+
async let result3 = try coalescer.coalescedRefresh(currentToken: "token_v1") {
53+
await counter.increment()
54+
55+
try await Task.sleep(nanoseconds: 100_000_000)
56+
return ("new_token", true)
57+
}
58+
59+
// Wait for all three to complete
60+
let (token1, updated1) = try await result1
61+
let (token2, updated2) = try await result2
62+
let (token3, updated3) = try await result3
63+
64+
// All three should get the same token
65+
XCTAssertEqual(token1, "new_token")
66+
XCTAssertEqual(token2, "new_token")
67+
XCTAssertEqual(token3, "new_token")
68+
69+
XCTAssertTrue(updated1)
70+
XCTAssertTrue(updated2)
71+
XCTAssertTrue(updated3)
72+
73+
// CRITICAL: Only ONE network call should have been made
74+
// (Previously, without coalescing, this would be 3)
75+
let callCount = await counter.value()
76+
XCTAssertEqual(callCount, 1, "Expected only 1 network call, but got \(callCount)")
77+
}
78+
79+
/// Tests that when the token changes, a new refresh is started instead of
80+
/// coalescing with the old one.
81+
func testNewRefreshStartsWhenTokenChanges() async throws {
82+
let coalescer = TokenRefreshCoalescer()
83+
let counter = Counter()
84+
85+
// First refresh for token_v1
86+
async let result1 = try coalescer.coalescedRefresh(currentToken: "token_v1") {
87+
await counter.increment()
88+
89+
try await Task.sleep(nanoseconds: 50_000_000)
90+
return ("new_token_1", true)
91+
}
92+
93+
// Wait a bit, then start a refresh for a different token (token_v2)
94+
// This should NOT coalesce with the first one
95+
try await Task.sleep(nanoseconds: 10_000_000)
96+
97+
async let result2 = try coalescer.coalescedRefresh(currentToken: "token_v2") {
98+
await counter.increment()
99+
100+
try await Task.sleep(nanoseconds: 50_000_000)
101+
return ("new_token_2", true)
102+
}
103+
104+
let token1 = try await result1.0
105+
let token2 = try await result2.0
106+
107+
// Should get different tokens
108+
XCTAssertEqual(token1, "new_token_1")
109+
XCTAssertEqual(token2, "new_token_2")
110+
111+
// Should have made TWO network calls (one for each token)
112+
let callsAfterTwoTokens = await counter.value()
113+
XCTAssertEqual(callsAfterTwoTokens, 2)
114+
}
115+
116+
/// Tests that if a refresh fails, the next call will start a fresh attempt
117+
/// instead of waiting for the failed one.
118+
func testFailedRefreshAllowsRetry() async throws {
119+
let coalescer = TokenRefreshCoalescer()
120+
let counter = Counter()
121+
122+
// First call will fail (run it to completion)
123+
do {
124+
_ = try await coalescer.coalescedRefresh(currentToken: "token_v1") {
125+
await counter.increment()
126+
throw NSError(domain: "TestError", code: -1, userInfo: nil)
127+
}
128+
XCTFail("Expected error")
129+
} catch {
130+
// Expected failure
131+
}
132+
133+
// Second call after the failure should start a fresh attempt and succeed
134+
let (token2, updated2) = try await coalescer.coalescedRefresh(currentToken: "token_v1") {
135+
await counter.increment()
136+
return ("recovered_token", true)
137+
}
138+
139+
XCTAssertEqual(token2, "recovered_token")
140+
XCTAssertTrue(updated2)
141+
142+
// Should have made TWO network calls (first failed, second succeeded)
143+
let secondResult = await counter.value()
144+
XCTAssertEqual(secondResult, 2)
145+
}
146+
147+
/// Stress test: Many concurrent calls for the same token
148+
func testManyCurrentCallsWithSameToken() async throws {
149+
let coalescer = TokenRefreshCoalescer()
150+
let counter = Counter()
151+
152+
let numCalls = 50
153+
var tasks: [Task<(String?, Bool), Error>] = []
154+
155+
// Launch 50 concurrent refresh tasks
156+
for _ in 0 ..< numCalls {
157+
let task = Task {
158+
try await coalescer.coalescedRefresh(currentToken: "token_stress") {
159+
await counter.increment()
160+
161+
try await Task.sleep(nanoseconds: 100_000_000)
162+
return ("stress_token", true)
163+
}
164+
}
165+
tasks.append(task)
166+
}
167+
168+
// Wait for all to complete
169+
var successCount = 0
170+
for task in tasks {
171+
let (token, updated) = try await task.value
172+
XCTAssertEqual(token, "stress_token")
173+
XCTAssertTrue(updated)
174+
successCount += 1
175+
}
176+
177+
XCTAssertEqual(successCount, numCalls)
178+
179+
// All 50 concurrent calls should result in ONLY 1 network call
180+
let stressCallCount = await counter.value()
181+
XCTAssertEqual(
182+
stressCallCount,
183+
1,
184+
"Expected 1 network call for 50 concurrent requests, but got \(stressCallCount)"
185+
)
186+
}
187+
188+
/// Tests that concurrent calls with forceRefresh:false still use the cache
189+
/// when tokens are valid.
190+
func testCachingStillWorksWithCoalescer() async throws {
191+
let coalescer = TokenRefreshCoalescer()
192+
let counter = Counter()
193+
194+
// First call triggers a refresh
195+
let result1 = try await coalescer.coalescedRefresh(currentToken: "token_v1") {
196+
await counter.increment()
197+
198+
return ("refreshed_token", true)
199+
}
200+
201+
XCTAssertEqual(result1.0, "refreshed_token")
202+
let resultAfterRefresh = await counter.value()
203+
XCTAssertEqual(resultAfterRefresh, 1)
204+
205+
// This test documents that caching logic happens BEFORE coalescer is called,
206+
// so this scenario doesn't test the coalescer directly, but verifies the
207+
// integration is correct.
208+
}
209+
}

0 commit comments

Comments
 (0)