diff --git a/FirebaseAI/Tests/TestApp/Tests/Integration/LiveSessionTests.swift b/FirebaseAI/Tests/TestApp/Tests/Integration/LiveSessionTests.swift index cb341697043..beffcdfa61e 100644 --- a/FirebaseAI/Tests/TestApp/Tests/Integration/LiveSessionTests.swift +++ b/FirebaseAI/Tests/TestApp/Tests/Integration/LiveSessionTests.swift @@ -276,38 +276,39 @@ struct LiveSessionTests { await session.close() } - // Getting a limited use token adds too much of an overhead; we can't interrupt the model in time @Test( arguments: arguments.filter { !$0.0.useLimitedUseAppCheckTokens } ) + // Getting a limited use token adds too much of an overhead; we can't interrupt the model in time func realtime_interruption(_ config: InstanceConfig, modelName: String) async throws { let model = FirebaseAI.componentInstance(config).liveModel( modelName: modelName, generationConfig: audioConfig ) - let session = try await model.connect() - guard let audioFile = NSDataAsset(name: "hello") else { Issue.record("Missing audio file 'hello.wav' in Assets") return } - await session.sendAudioRealtime(audioFile.data) - await session.sendAudioRealtime(Data(repeating: 0, count: audioFile.data.count)) - // wait a second to allow the model to start generating (and cuase a proper interruption) - try await Task.sleep(nanoseconds: oneSecondInNanoseconds) - await session.sendAudioRealtime(audioFile.data) - await session.sendAudioRealtime(Data(repeating: 0, count: audioFile.data.count)) + try await retry(times: 3, delayInSeconds: 2.0) { + let session = try await model.connect() + await session.sendAudioRealtime(audioFile.data) + await session.sendAudioRealtime(Data(repeating: 0, count: audioFile.data.count)) - for try await content in session.responsesOf(LiveServerContent.self) { - if content.wasInterrupted { - break - } + // wait a second to allow the model to start generating (and cuase a proper interruption) + try await Task.sleep(nanoseconds: oneSecondInNanoseconds) + await session.sendAudioRealtime(audioFile.data) + await session.sendAudioRealtime(Data(repeating: 0, count: audioFile.data.count)) - if content.isTurnComplete { - Issue.record("The model never sent an interrupted message.") - return + for try await content in session.responsesOf(LiveServerContent.self) { + if content.wasInterrupted { + break + } + + if content.isTurnComplete { + throw NoInterruptionError() + } } } } @@ -472,6 +473,11 @@ private extension LiveSession { } } +private struct NoInterruptionError: Error, + CustomStringConvertible { + var description: String { "The model never sent an interrupted message." } +} + private extension ModelContent { /// A collection of text from all parts. /// diff --git a/FirebaseAI/Tests/TestApp/Tests/Utilities/IntegrationTestUtils.swift b/FirebaseAI/Tests/TestApp/Tests/Utilities/IntegrationTestUtils.swift index f133207540c..af1ef347c63 100644 --- a/FirebaseAI/Tests/TestApp/Tests/Utilities/IntegrationTestUtils.swift +++ b/FirebaseAI/Tests/TestApp/Tests/Utilities/IntegrationTestUtils.swift @@ -13,6 +13,7 @@ // limitations under the License. import Foundation +import Testing import XCTest enum IntegrationTestUtils { @@ -43,3 +44,35 @@ extension Numeric where Self: Strideable, Self.Stride.Magnitude: Comparable { return distance(to: other).magnitude <= accuracy.magnitude } } + +/// Retry a flakey test N times before failing. +/// +/// - Parameters: +/// - times: The amount of attempts to retry before failing. Must be greater than 0. +/// - delayInSeconds: How long to wait before performing the next attempt. +@discardableResult +func retry(times: Int, + delayInSeconds: TimeInterval = 0.1, + _ test: () async throws -> T) async throws -> T { + if times <= 0 { + precondition(times <= 0, "Times must be greater than 0.") + } + let delayNanos = UInt64(delayInSeconds * 1e+9) + var lastError: Error? + for attempt in 1 ... times { + do { return try await test() } + catch { + lastError = error + // only wait if we have more attempts + if attempt < times { + try? await Task.sleep(nanoseconds: delayNanos) + } + } + } + guard let lastError else { + // should not happen unless we change the above code in some way + fatalError("Internal error: retry loop finished without error") + } + Issue.record("Flaky test failed after \(times) attempt(s): \(String(describing: lastError))") + throw lastError +}