Skip to content

Commit 7c6bbcb

Browse files
daymxnandrewheard
andcommitted
fix(ai): Add retry mechanism to flakey interrupt test (#15421)
Co-authored-by: Andrew Heard <[email protected]>
1 parent fa0f903 commit 7c6bbcb

File tree

2 files changed

+55
-16
lines changed

2 files changed

+55
-16
lines changed

FirebaseAI/Tests/TestApp/Tests/Integration/LiveSessionTests.swift

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -276,38 +276,39 @@ struct LiveSessionTests {
276276
await session.close()
277277
}
278278

279-
// Getting a limited use token adds too much of an overhead; we can't interrupt the model in time
280279
@Test(
281280
arguments: arguments.filter { !$0.0.useLimitedUseAppCheckTokens }
282281
)
282+
// Getting a limited use token adds too much of an overhead; we can't interrupt the model in time
283283
func realtime_interruption(_ config: InstanceConfig, modelName: String) async throws {
284284
let model = FirebaseAI.componentInstance(config).liveModel(
285285
modelName: modelName,
286286
generationConfig: audioConfig
287287
)
288288

289-
let session = try await model.connect()
290-
291289
guard let audioFile = NSDataAsset(name: "hello") else {
292290
Issue.record("Missing audio file 'hello.wav' in Assets")
293291
return
294292
}
295-
await session.sendAudioRealtime(audioFile.data)
296-
await session.sendAudioRealtime(Data(repeating: 0, count: audioFile.data.count))
297293

298-
// wait a second to allow the model to start generating (and cuase a proper interruption)
299-
try await Task.sleep(nanoseconds: oneSecondInNanoseconds)
300-
await session.sendAudioRealtime(audioFile.data)
301-
await session.sendAudioRealtime(Data(repeating: 0, count: audioFile.data.count))
294+
try await retry(times: 3, delayInSeconds: 2.0) {
295+
let session = try await model.connect()
296+
await session.sendAudioRealtime(audioFile.data)
297+
await session.sendAudioRealtime(Data(repeating: 0, count: audioFile.data.count))
302298

303-
for try await content in session.responsesOf(LiveServerContent.self) {
304-
if content.wasInterrupted {
305-
break
306-
}
299+
// wait a second to allow the model to start generating (and cuase a proper interruption)
300+
try await Task.sleep(nanoseconds: oneSecondInNanoseconds)
301+
await session.sendAudioRealtime(audioFile.data)
302+
await session.sendAudioRealtime(Data(repeating: 0, count: audioFile.data.count))
307303

308-
if content.isTurnComplete {
309-
Issue.record("The model never sent an interrupted message.")
310-
return
304+
for try await content in session.responsesOf(LiveServerContent.self) {
305+
if content.wasInterrupted {
306+
break
307+
}
308+
309+
if content.isTurnComplete {
310+
throw NoInterruptionError()
311+
}
311312
}
312313
}
313314
}
@@ -472,6 +473,11 @@ private extension LiveSession {
472473
}
473474
}
474475

476+
private struct NoInterruptionError: Error,
477+
CustomStringConvertible {
478+
var description: String { "The model never sent an interrupted message." }
479+
}
480+
475481
private extension ModelContent {
476482
/// A collection of text from all parts.
477483
///

FirebaseAI/Tests/TestApp/Tests/Utilities/IntegrationTestUtils.swift

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
// limitations under the License.
1414

1515
import Foundation
16+
import Testing
1617
import XCTest
1718

1819
enum IntegrationTestUtils {
@@ -43,3 +44,35 @@ extension Numeric where Self: Strideable, Self.Stride.Magnitude: Comparable {
4344
return distance(to: other).magnitude <= accuracy.magnitude
4445
}
4546
}
47+
48+
/// Retry a flakey test N times before failing.
49+
///
50+
/// - Parameters:
51+
/// - times: The amount of attempts to retry before failing. Must be greater than 0.
52+
/// - delayInSeconds: How long to wait before performing the next attempt.
53+
@discardableResult
54+
func retry<T>(times: Int,
55+
delayInSeconds: TimeInterval = 0.1,
56+
_ test: () async throws -> T) async throws -> T {
57+
if times <= 0 {
58+
precondition(times <= 0, "Times must be greater than 0.")
59+
}
60+
let delayNanos = UInt64(delayInSeconds * 1e+9)
61+
var lastError: Error?
62+
for attempt in 1 ... times {
63+
do { return try await test() }
64+
catch {
65+
lastError = error
66+
// only wait if we have more attempts
67+
if attempt < times {
68+
try? await Task.sleep(nanoseconds: delayNanos)
69+
}
70+
}
71+
}
72+
guard let lastError else {
73+
// should not happen unless we change the above code in some way
74+
fatalError("Internal error: retry loop finished without error")
75+
}
76+
Issue.record("Flaky test failed after \(times) attempt(s): \(String(describing: lastError))")
77+
throw lastError
78+
}

0 commit comments

Comments
 (0)