diff --git a/firebase-ai/src/main/kotlin/com/google/firebase/ai/common/APIController.kt b/firebase-ai/src/main/kotlin/com/google/firebase/ai/common/APIController.kt index e992f92e674..96d0a751ad9 100644 --- a/firebase-ai/src/main/kotlin/com/google/firebase/ai/common/APIController.kt +++ b/firebase-ai/src/main/kotlin/com/google/firebase/ai/common/APIController.kt @@ -236,8 +236,13 @@ internal constructor( "wss://firebasevertexai.googleapis.com/ws/google.firebase.vertexai.v1beta.GenerativeService/BidiGenerateContent?key=$key" } - suspend fun getWebSocketSession(location: String): DefaultClientWebSocketSession = - client.webSocketSession(getBidiEndpoint(location)) { applyCommonHeaders() } + suspend fun getWebSocketSession(location: String): DefaultClientWebSocketSession { + val headers = resolveHeaders() + return client.webSocketSession(getBidiEndpoint(location)) { + applyCommonHeaders() + headers.forEach { (key, value) -> header(key, value) } + } + } fun generateContentStream( request: GenerateContentRequest @@ -285,16 +290,17 @@ internal constructor( } private suspend fun HttpRequestBuilder.applyHeaderProvider() { - if (headerProvider != null) { - try { - withTimeout(headerProvider.timeout) { - for ((tag, value) in headerProvider.generateHeaders()) { - header(tag, value) - } - } - } catch (e: TimeoutCancellationException) { - Log.w(TAG, "HeaderProvided timed out without generating headers, ignoring") - } + val headers = resolveHeaders() + headers.forEach { (tag, value) -> header(tag, value) } + } + + private suspend fun resolveHeaders(): Map { + if (headerProvider == null) return emptyMap() + return try { + withTimeout(headerProvider.timeout) { headerProvider.generateHeaders() } + } catch (e: TimeoutCancellationException) { + Log.w(TAG, "HeaderProvided timed out without generating headers, ignoring") + emptyMap() } }