diff --git a/firebase-vertexai/CHANGELOG.md b/firebase-vertexai/CHANGELOG.md index 0969016889c..0334dc35f6b 100644 --- a/firebase-vertexai/CHANGELOG.md +++ b/firebase-vertexai/CHANGELOG.md @@ -9,6 +9,11 @@ * [feature] Added support for `HarmBlockThreshold.OFF`. See the [model documentation](https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/configure-safety-filters#how_to_configure_content_filters){: .external} for more information. +* [fixed] Improved thread usage when using a `LiveGenerativeModel`. (#6870) +* [fixed] Fixed an issue with `LiveContentResponse` audio data not being present when the model was + interrupted or the turn completed. (#6870) +* [fixed] Fixed an issue with `LiveSession` not converting exceptions to `FirebaseVertexAIException`. (#6870) + # 16.3.0 * [feature] Emits a warning when attempting to use an incompatible model with diff --git a/firebase-vertexai/api.txt b/firebase-vertexai/api.txt index 1ea9d432305..ecc567e537f 100644 --- a/firebase-vertexai/api.txt +++ b/firebase-vertexai/api.txt @@ -629,7 +629,7 @@ package com.google.firebase.vertexai.type { method public suspend Object? send(String text, kotlin.coroutines.Continuation); method public suspend Object? sendFunctionResponse(java.util.List functionList, kotlin.coroutines.Continuation); method public suspend Object? sendMediaStream(java.util.List mediaChunks, kotlin.coroutines.Continuation); - method public suspend Object? startAudioConversation(kotlin.jvm.functions.Function1? functionCallHandler = null, kotlin.coroutines.Continuation); + method @RequiresPermission(android.Manifest.permission.RECORD_AUDIO) public suspend Object? startAudioConversation(kotlin.jvm.functions.Function1? functionCallHandler = null, kotlin.coroutines.Continuation); method public void stopAudioConversation(); method public void stopReceiving(); } diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/FirebaseVertexAI.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/FirebaseVertexAI.kt index 7c90e78c402..c36ec25d078 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/FirebaseVertexAI.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/FirebaseVertexAI.kt @@ -19,7 +19,7 @@ package com.google.firebase.vertexai import android.util.Log import com.google.firebase.Firebase import com.google.firebase.FirebaseApp -import com.google.firebase.annotations.concurrent.Background +import com.google.firebase.annotations.concurrent.Blocking import com.google.firebase.app import com.google.firebase.appcheck.interop.InteropAppCheckTokenProvider import com.google.firebase.auth.internal.InternalAuthProvider @@ -41,7 +41,7 @@ import kotlin.coroutines.CoroutineContext public class FirebaseVertexAI internal constructor( private val firebaseApp: FirebaseApp, - @Background private val backgroundDispatcher: CoroutineContext, + @Blocking private val blockingDispatcher: CoroutineContext, private val location: String, private val appCheckProvider: Provider, private val internalAuthProvider: Provider, @@ -133,7 +133,7 @@ internal constructor( "projects/${firebaseApp.options.projectId}/locations/${location}/publishers/google/models/${modelName}", firebaseApp.options.apiKey, firebaseApp, - backgroundDispatcher, + blockingDispatcher, generationConfig, tools, systemInstruction, diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/FirebaseVertexAIMultiResourceComponent.kt 
b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/FirebaseVertexAIMultiResourceComponent.kt index 1b9cb7a4909..526e1f87be8 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/FirebaseVertexAIMultiResourceComponent.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/FirebaseVertexAIMultiResourceComponent.kt @@ -18,7 +18,7 @@ package com.google.firebase.vertexai import androidx.annotation.GuardedBy import com.google.firebase.FirebaseApp -import com.google.firebase.annotations.concurrent.Background +import com.google.firebase.annotations.concurrent.Blocking import com.google.firebase.appcheck.interop.InteropAppCheckTokenProvider import com.google.firebase.auth.internal.InternalAuthProvider import com.google.firebase.inject.Provider @@ -31,7 +31,7 @@ import kotlin.coroutines.CoroutineContext */ internal class FirebaseVertexAIMultiResourceComponent( private val app: FirebaseApp, - @Background val backgroundDispatcher: CoroutineContext, + @Blocking val blockingDispatcher: CoroutineContext, private val appCheckProvider: Provider, private val internalAuthProvider: Provider, ) { @@ -43,7 +43,7 @@ internal class FirebaseVertexAIMultiResourceComponent( instances[location] ?: FirebaseVertexAI( app, - backgroundDispatcher, + blockingDispatcher, location, appCheckProvider, internalAuthProvider diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/FirebaseVertexAIRegistrar.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/FirebaseVertexAIRegistrar.kt index ff5409567a9..13cb73cdb71 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/FirebaseVertexAIRegistrar.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/FirebaseVertexAIRegistrar.kt @@ -18,7 +18,7 @@ package com.google.firebase.vertexai import androidx.annotation.Keep import com.google.firebase.FirebaseApp -import com.google.firebase.annotations.concurrent.Background +import com.google.firebase.annotations.concurrent.Blocking import com.google.firebase.appcheck.interop.InteropAppCheckTokenProvider import com.google.firebase.auth.internal.InternalAuthProvider import com.google.firebase.components.Component @@ -41,13 +41,13 @@ internal class FirebaseVertexAIRegistrar : ComponentRegistrar { Component.builder(FirebaseVertexAIMultiResourceComponent::class.java) .name(LIBRARY_NAME) .add(Dependency.required(firebaseApp)) - .add(Dependency.required(backgroundDispatcher)) + .add(Dependency.required(blockingDispatcher)) .add(Dependency.optionalProvider(appCheckInterop)) .add(Dependency.optionalProvider(internalAuthProvider)) .factory { container -> FirebaseVertexAIMultiResourceComponent( container[firebaseApp], - container.get(backgroundDispatcher), + container.get(blockingDispatcher), container.getProvider(appCheckInterop), container.getProvider(internalAuthProvider) ) @@ -62,7 +62,7 @@ internal class FirebaseVertexAIRegistrar : ComponentRegistrar { private val firebaseApp = unqualified(FirebaseApp::class.java) private val appCheckInterop = unqualified(InteropAppCheckTokenProvider::class.java) private val internalAuthProvider = unqualified(InternalAuthProvider::class.java) - private val backgroundDispatcher = - Qualified.qualified(Background::class.java, CoroutineDispatcher::class.java) + private val blockingDispatcher = + Qualified.qualified(Blocking::class.java, CoroutineDispatcher::class.java) } } diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/LiveGenerativeModel.kt 
b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/LiveGenerativeModel.kt index e557b694620..d546e09cdd2 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/LiveGenerativeModel.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/LiveGenerativeModel.kt @@ -17,13 +17,14 @@ package com.google.firebase.vertexai import com.google.firebase.FirebaseApp -import com.google.firebase.annotations.concurrent.Background +import com.google.firebase.annotations.concurrent.Blocking import com.google.firebase.appcheck.interop.InteropAppCheckTokenProvider import com.google.firebase.auth.internal.InternalAuthProvider import com.google.firebase.vertexai.common.APIController import com.google.firebase.vertexai.common.AppCheckHeaderProvider -import com.google.firebase.vertexai.type.BidiGenerateContentClientMessage +import com.google.firebase.vertexai.common.JSON import com.google.firebase.vertexai.type.Content +import com.google.firebase.vertexai.type.LiveClientSetupMessage import com.google.firebase.vertexai.type.LiveGenerationConfig import com.google.firebase.vertexai.type.LiveSession import com.google.firebase.vertexai.type.PublicPreviewAPI @@ -38,6 +39,7 @@ import kotlinx.coroutines.channels.ClosedReceiveChannelException import kotlinx.serialization.ExperimentalSerializationApi import kotlinx.serialization.encodeToString import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonObject /** * Represents a multimodal model (like Gemini) capable of real-time content generation based on @@ -47,7 +49,7 @@ import kotlinx.serialization.json.Json public class LiveGenerativeModel internal constructor( private val modelName: String, - @Background private val backgroundDispatcher: CoroutineContext, + @Blocking private val blockingDispatcher: CoroutineContext, private val config: LiveGenerationConfig? = null, private val tools: List? = null, private val systemInstruction: Content? = null, @@ -58,7 +60,7 @@ internal constructor( modelName: String, apiKey: String, firebaseApp: FirebaseApp, - backgroundDispatcher: CoroutineContext, + blockingDispatcher: CoroutineContext, config: LiveGenerationConfig? = null, tools: List? = null, systemInstruction: Content? = null, @@ -68,7 +70,7 @@ internal constructor( internalAuthProvider: InternalAuthProvider? = null, ) : this( modelName, - backgroundDispatcher, + blockingDispatcher, config, tools, systemInstruction, @@ -93,7 +95,7 @@ internal constructor( @OptIn(ExperimentalSerializationApi::class) public suspend fun connect(): LiveSession { val clientMessage = - BidiGenerateContentClientMessage( + LiveClientSetupMessage( modelName, config?.toInternal(), tools?.map { it.toInternal() }, @@ -104,10 +106,11 @@ internal constructor( try { val webSession = controller.getWebSocketSession(location) webSession.send(Frame.Text(data)) - val receivedJson = webSession.incoming.receive().readBytes().toString(Charsets.UTF_8) - // TODO: Try to decode the json instead of string matching. 
- return if (receivedJson.contains("setupComplete")) { - LiveSession(session = webSession, backgroundDispatcher = backgroundDispatcher) + val receivedJsonStr = webSession.incoming.receive().readBytes().toString(Charsets.UTF_8) + val receivedJson = JSON.parseToJsonElement(receivedJsonStr) + + return if (receivedJson is JsonObject && "setupComplete" in receivedJson) { + LiveSession(session = webSession, blockingDispatcher = blockingDispatcher) } else { webSession.close() throw ServiceConnectionHandshakeFailedException("Unable to connect to the server") diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/APIController.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/APIController.kt index da580429f8c..f82d4866cf6 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/APIController.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/APIController.kt @@ -165,6 +165,7 @@ internal constructor( suspend fun getWebSocketSession(location: String): ClientWebSocketSession = client.webSocketSession(getBidiEndpoint(location)) + fun generateContentStream( request: GenerateContentRequest ): Flow = diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/util/android.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/util/android.kt new file mode 100644 index 00000000000..6de0339e032 --- /dev/null +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/util/android.kt @@ -0,0 +1,50 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.firebase.vertexai.common.util + +import android.media.AudioRecord +import kotlinx.coroutines.flow.flow +import kotlinx.coroutines.yield + +/** + * The minimum buffer size for this instance. + * + * The same as calling [AudioRecord.getMinBufferSize], except the params are pre-populated. + */ +internal val AudioRecord.minBufferSize: Int + get() = AudioRecord.getMinBufferSize(sampleRate, channelConfiguration, audioFormat) + +/** + * Reads from this [AudioRecord] and returns the data in a flow. + * + * Will yield when this instance is not recording. 
+ */ +internal fun AudioRecord.readAsFlow() = flow { + val buffer = ByteArray(minBufferSize) + + while (true) { + if (recordingState != AudioRecord.RECORDSTATE_RECORDING) { + yield() + continue + } + + val bytesRead = read(buffer, 0, buffer.size) + if (bytesRead > 0) { + emit(buffer.copyOf(bytesRead)) + } + } +} diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/util/kotlin.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/util/kotlin.kt index bf806528781..05e37e490ab 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/util/kotlin.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/util/kotlin.kt @@ -16,7 +16,16 @@ package com.google.firebase.vertexai.common.util +import java.io.ByteArrayOutputStream import java.lang.reflect.Field +import kotlin.coroutines.EmptyCoroutineContext +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Job +import kotlinx.coroutines.cancel +import kotlinx.coroutines.currentCoroutineContext +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.flow +import kotlinx.coroutines.flow.fold /** * Removes the last character from the [StringBuilder]. @@ -39,3 +48,56 @@ internal fun StringBuilder.removeLast(): StringBuilder = * ``` */ internal inline fun Field.getAnnotation() = getAnnotation(T::class.java) + +/** + * Collects bytes from this flow and doesn't emit them back until [minSize] is reached. + * + * For example: + * ``` + * val byteArr = flowOf(byteArrayOf(1), byteArrayOf(2, 3, 4), byteArrayOf(5, 6, 7, 8)) + * val expectedResult = listOf(byteArrayOf(1, 2, 3, 4), byteArrayOf(5, 6, 7, 8)) + * + * byteArr.accumulateUntil(4).toList() shouldContainExactly expectedResult + * ``` + * + * @param minSize The minimum number of bytes the array should have before being sent downstream + * @param emitLeftOvers If the flow completes and there are bytes left over that don't meet the + * [minSize], send them anyway. + */ +internal fun Flow.accumulateUntil( + minSize: Int, + emitLeftOvers: Boolean = false +): Flow = flow { + val remaining = + fold(ByteArrayOutputStream()) { buffer, it -> + buffer.apply { + write(it, 0, it.size) + if (size() >= minSize) { + emit(toByteArray()) + reset() + } + } + } + + if (emitLeftOvers && remaining.size() > 0) { + emit(remaining.toByteArray()) + } +} + +/** + * Create a [Job] that is a child of the [currentCoroutineContext], if any. + * + * This is useful when you want a coroutine scope to be canceled when its parent scope is canceled, + * and you don't have full control over the parent scope, but you don't want the cancellation of the + * child to impact the parent. + * + * If the parent coroutine context does not have a job, an empty one will be created. + */ +internal suspend inline fun childJob() = Job(currentCoroutineContext()[Job] ?: Job()) + +/** + * A constant value pointing to a cancelled [CoroutineScope]. + * + * Useful when you want to initialize a mutable [CoroutineScope] in a canceled state.
+ */ +internal val CancelledCoroutineScope = CoroutineScope(EmptyCoroutineContext).apply { cancel() } diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/java/LiveSessionFutures.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/java/LiveSessionFutures.kt index f5b56758d19..169f9723ad8 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/java/LiveSessionFutures.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/java/LiveSessionFutures.kt @@ -16,6 +16,8 @@ package com.google.firebase.vertexai.java +import android.Manifest.permission.RECORD_AUDIO +import androidx.annotation.RequiresPermission import androidx.concurrent.futures.SuspendToFutureAdapter import com.google.common.util.concurrent.ListenableFuture import com.google.firebase.vertexai.type.Content @@ -26,6 +28,7 @@ import com.google.firebase.vertexai.type.LiveSession import com.google.firebase.vertexai.type.MediaData import com.google.firebase.vertexai.type.PublicPreviewAPI import com.google.firebase.vertexai.type.SessionAlreadyReceivingException +import io.ktor.websocket.close import kotlinx.coroutines.reactive.asPublisher import org.reactivestreams.Publisher @@ -38,35 +41,50 @@ import org.reactivestreams.Publisher public abstract class LiveSessionFutures internal constructor() { /** - * Starts an audio conversation with the Gemini server, which can only be stopped using - * [stopAudioConversation]. + * Starts an audio conversation with the model, which can only be stopped using + * [stopAudioConversation] or [close]. * - * @param functionCallHandler A callback function to map function calls from the server to their - * response parts. + * @param functionCallHandler A callback function that is invoked whenever the model receives a + * function call. */ public abstract fun startAudioConversation( functionCallHandler: ((FunctionCallPart) -> FunctionResponsePart)? ): ListenableFuture /** - * Starts an audio conversation with the Gemini server, which can only be stopped using + * Starts an audio conversation with the model, which can only be stopped using * [stopAudioConversation]. */ + @RequiresPermission(RECORD_AUDIO) public abstract fun startAudioConversation(): ListenableFuture /** * Stops the audio conversation with the Gemini Server. * - * @see [startAudioConversation] - * @see [stopReceiving] + * This only needs to be called after a previous call to [startAudioConversation]. + * + * If there is no audio conversation currently active, this function does nothing. */ + @RequiresPermission(RECORD_AUDIO) public abstract fun stopAudioConversation(): ListenableFuture - /** Stop receiving from the server. */ + /** + * Stops receiving from the model. + * + * If this function is called during an ongoing audio conversation, the model's response will not + * be received, and no audio will be played; the live session object will no longer receive data + * from the server. + * + * To resume receiving data, you must either handle it directly using [receive], or indirectly by + * using [startAudioConversation]. + * + * @see close + */ + // TODO(b/410059569): Remove when fixed public abstract fun stopReceiving() /** - * Sends the function response from the client to the server. + * Sends function calling responses to the model. * * @param functionList The list of [FunctionResponsePart] instances indicating the function * response from the client. 
@@ -76,35 +94,51 @@ public abstract class LiveSessionFutures internal constructor() { ): ListenableFuture /** - * Streams client data to the server. + * Streams client data to the model. + * + * Calling this after [startAudioConversation] will play the response audio immediately. * * @param mediaChunks The list of [MediaData] instances representing the media data to be sent. */ public abstract fun sendMediaStream(mediaChunks: List): ListenableFuture /** - * Sends [data][Content] to the server. + * Sends [data][Content] to the model. * - * @param content Client [Content] to be sent to the server. + * Calling this after [startAudioConversation] will play the response audio immediately. + * + * @param content Client [Content] to be sent to the model. */ public abstract fun send(content: Content): ListenableFuture /** - * Sends text to the server + * Sends text to the model. + * + * Calling this after [startAudioConversation] will play the response audio immediately. * - * @param text Text to be sent to the server. + * @param text Text to be sent to the model. */ public abstract fun send(text: String): ListenableFuture - /** Closes the client session. */ + /** + * Closes the client session. + * + * Once a [LiveSession] is closed, it cannot be reopened; you'll need to start a new + * [LiveSession]. + * + * @see stopReceiving + */ public abstract fun close(): ListenableFuture /** - * Receives responses from the server for both streaming and standard requests. + * Receives responses from the model for both streaming and standard requests. + * + * Call [close] to stop receiving responses from the model. * - * @return A [Publisher] which will emit [LiveContentResponse] as and when it receives it. + * @return A [Publisher] which will emit [LiveContentResponse] from the model. * - * @throws [SessionAlreadyReceivingException] When the session is already receiving. + * @throws [SessionAlreadyReceivingException] when the session is already receiving. + * @see stopReceiving */ public abstract fun receive(): Publisher @@ -126,10 +160,12 @@ public abstract class LiveSessionFutures internal constructor() { override fun sendMediaStream(mediaChunks: List) = SuspendToFutureAdapter.launchFuture { session.sendMediaStream(mediaChunks) } + @RequiresPermission(RECORD_AUDIO) override fun startAudioConversation( functionCallHandler: ((FunctionCallPart) -> FunctionResponsePart)?
) = SuspendToFutureAdapter.launchFuture { session.startAudioConversation(functionCallHandler) } + @RequiresPermission(RECORD_AUDIO) override fun startAudioConversation() = SuspendToFutureAdapter.launchFuture { session.startAudioConversation() } diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/AudioHelper.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/AudioHelper.kt index 07219617cee..e74b766d9ce 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/AudioHelper.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/AudioHelper.kt @@ -17,125 +17,197 @@ package com.google.firebase.vertexai.type import android.Manifest +import android.media.AudioAttributes import android.media.AudioFormat import android.media.AudioManager import android.media.AudioRecord import android.media.AudioTrack import android.media.MediaRecorder import android.media.audiofx.AcousticEchoCanceler +import android.util.Log import androidx.annotation.RequiresPermission +import com.google.firebase.vertexai.common.util.readAsFlow import kotlinx.coroutines.flow.Flow -import kotlinx.coroutines.flow.flow +import kotlinx.coroutines.flow.emptyFlow +/** + * Helper class for recording audio and playing back a separate audio track at the same time. + * + * @see AudioHelper.build + * @see LiveSession.startAudioConversation + */ @PublicPreviewAPI -internal class AudioHelper { +internal class AudioHelper( + /** Record for recording the System microphone. */ + private val recorder: AudioRecord, + /** Track for playing back what the model says. */ + private val playbackTrack: AudioTrack, +) { + private var released: Boolean = false - private lateinit var audioRecord: AudioRecord - private lateinit var audioTrack: AudioTrack - private var stopRecording: Boolean = false + /** + * Release the system resources on the recorder and playback track. + * + * Once an [AudioHelper] has been "released", it can _not_ be used again. + * + * This method can safely be called multiple times, as it won't do anything if this instance has + * already been released. + */ + fun release() { + if (released) return + released = true - internal fun release() { - stopRecording = true - if (::audioRecord.isInitialized) { - audioRecord.stop() - audioRecord.release() - } - if (::audioTrack.isInitialized) { - audioTrack.stop() - audioTrack.release() - } + recorder.release() + playbackTrack.release() } - internal fun setupAudioTrack() { - audioTrack = - AudioTrack( - AudioManager.STREAM_MUSIC, - 24000, - AudioFormat.CHANNEL_OUT_MONO, - AudioFormat.ENCODING_PCM_16BIT, - AudioTrack.getMinBufferSize( - 24000, - AudioFormat.CHANNEL_OUT_MONO, - AudioFormat.ENCODING_PCM_16BIT - ), - AudioTrack.MODE_STREAM + /** + * Play the provided audio data on the playback track. + * + * Does nothing if this [AudioHelper] has been [released][release]. + * + * @throws IllegalStateException If the playback track was not properly initialized. + * @throws IllegalArgumentException If the playback data is invalid. + * @throws RuntimeException If we fail to play the audio data for some unknown reason. + */ + fun playAudio(data: ByteArray) { + if (released) return + if (data.isEmpty()) return + + if (playbackTrack.playState == AudioTrack.PLAYSTATE_STOPPED) playbackTrack.play() + + val result = playbackTrack.write(data, 0, data.size) + if (result > 0) return + if (result == 0) { + Log.w( + TAG, + "Failed to write any audio bytes to the playback track. 
The audio track may have been stopped or paused." ) - audioTrack.play() - } + return + } - internal fun playAudio(data: ByteArray) { - if (!stopRecording) { - audioTrack.write(data, 0, data.size) + // ERROR_INVALID_OPERATION and ERROR_BAD_VALUE should never occur + when (result) { + AudioTrack.ERROR_INVALID_OPERATION -> + throw IllegalStateException("The playback track was not properly initialized.") + AudioTrack.ERROR_BAD_VALUE -> + throw IllegalArgumentException("Playback data is somehow invalid.") + AudioTrack.ERROR_DEAD_OBJECT -> { + Log.w(TAG, "Attempted to playback some audio, but the track has been released.") + release() // to ensure `released` is set and `record` is released too + } + AudioTrack.ERROR -> + throw RuntimeException("Failed to play the audio data for some unknown reason.") } } - fun stopRecording() { - if ( - ::audioRecord.isInitialized && audioRecord.recordingState == AudioRecord.RECORDSTATE_RECORDING - ) { - audioRecord.stop() + /** + * Pause the recording of the microphone, if it's recording. + * + * Does nothing if this [AudioHelper] has been [released][release]. + * + * @see resumeRecording + * + * @throws IllegalStateException If the recorder was not properly initialized. + */ + fun pauseRecording() { + if (released || recorder.recordingState == AudioRecord.RECORDSTATE_STOPPED) return + + try { + recorder.stop() + } catch (e: IllegalStateException) { + release() + throw IllegalStateException("The recorder was not properly initialized.") } } - fun start() { - if ( - ::audioRecord.isInitialized && audioRecord.recordingState != AudioRecord.RECORDSTATE_RECORDING - ) { - audioRecord.startRecording() - } + /** + * Resumes the recording of the microphone, if it's not already running. + * + * Does nothing if this [AudioHelper] has been [released][release]. + * + * @see pauseRecording + */ + fun resumeRecording() { + if (released || recorder.recordingState == AudioRecord.RECORDSTATE_RECORDING) return + + recorder.startRecording() + } + + /** + * Start perpetually recording the system microphone, and return the bytes read in a flow. + * + * Returns an empty flow if this [AudioHelper] has been [released][release]. + */ + fun listenToRecording(): Flow { + if (released) return emptyFlow() + + resumeRecording() + + return recorder.readAsFlow() } - @RequiresPermission(Manifest.permission.RECORD_AUDIO) - fun startRecording(): Flow { - - val bufferSize = - AudioRecord.getMinBufferSize( - 16000, - AudioFormat.CHANNEL_IN_MONO, - AudioFormat.ENCODING_PCM_16BIT - ) - if ( - bufferSize == AudioRecord.ERROR || - bufferSize == AudioRecord.ERROR_BAD_VALUE || - bufferSize <= 0 - ) { - throw AudioRecordInitializationFailedException( - "Audio Record buffer size is invalid (${bufferSize})" - ) - } - audioRecord = - AudioRecord( - MediaRecorder.AudioSource.VOICE_COMMUNICATION, - 16000, - AudioFormat.CHANNEL_IN_MONO, - AudioFormat.ENCODING_PCM_16BIT, - bufferSize - ) - if (audioRecord.state != AudioRecord.STATE_INITIALIZED) { - throw AudioRecordInitializationFailedException( - "Audio Record initialization has failed.
State: ${audioRecord.state}" - ) - } - if (AcousticEchoCanceler.isAvailable()) { - val echoCanceler = AcousticEchoCanceler.create(audioRecord.audioSessionId) - echoCanceler?.enabled = true - } - audioRecord.startRecording() - - return flow { - val buffer = ByteArray(bufferSize) - while (!stopRecording) { - if (audioRecord.recordingState != AudioRecord.RECORDSTATE_RECORDING) { - buffer.fill(0x00) - continue - } - try { - val bytesRead = audioRecord.read(buffer, 0, buffer.size) - if (bytesRead > 0) { - emit(buffer.copyOf(bytesRead)) - } - } catch (_: Exception) {} + companion object { + private val TAG = AudioHelper::class.simpleName + + /** + * Creates an instance of [AudioHelper] with the track and record initialized. + * + * A separate build method is necessary so that we can properly propagate the required manifest + * permission, and throw exceptions when needed. + * + * It also makes it easier to read, since the long initialization is separate from the + * constructor. + */ + @RequiresPermission(Manifest.permission.RECORD_AUDIO) + fun build(): AudioHelper { + val playbackTrack = + AudioTrack( + AudioAttributes.Builder().setUsage(AudioAttributes.USAGE_VOICE_COMMUNICATION).build(), + AudioFormat.Builder() + .setSampleRate(24000) + .setChannelMask(AudioFormat.CHANNEL_OUT_MONO) + .setEncoding(AudioFormat.ENCODING_PCM_16BIT) + .build(), + AudioTrack.getMinBufferSize( + 24000, + AudioFormat.CHANNEL_OUT_MONO, + AudioFormat.ENCODING_PCM_16BIT + ), + AudioTrack.MODE_STREAM, + AudioManager.AUDIO_SESSION_ID_GENERATE + ) + + val bufferSize = + AudioRecord.getMinBufferSize( + 16000, + AudioFormat.CHANNEL_IN_MONO, + AudioFormat.ENCODING_PCM_16BIT + ) + + if (bufferSize <= 0) + throw AudioRecordInitializationFailedException( + "Audio Record buffer size is invalid ($bufferSize)" + ) + + val recorder = + AudioRecord( + MediaRecorder.AudioSource.VOICE_COMMUNICATION, + 16000, + AudioFormat.CHANNEL_IN_MONO, + AudioFormat.ENCODING_PCM_16BIT, + bufferSize + ) + if (recorder.state != AudioRecord.STATE_INITIALIZED) + throw AudioRecordInitializationFailedException( + "Audio Record initialization has failed. State: ${recorder.state}" + ) + + if (AcousticEchoCanceler.isAvailable()) { + AcousticEchoCanceler.create(recorder.audioSessionId)?.enabled = true } + + return AudioHelper(recorder, playbackTrack) } } } diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Exceptions.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Exceptions.kt index f3256bf4c15..45e9ef027a6 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Exceptions.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Exceptions.kt @@ -67,6 +67,38 @@ internal constructor(message: String, cause: Throwable? = null) : RuntimeExcepti RequestTimeoutException("The request failed to complete in the allotted time.") else -> UnknownException("Something unexpected happened.", cause) } + + /** + * Catch any exception thrown in the [callback] block and rethrow it as a + * [FirebaseVertexAIException]. + * + * Will return whatever the [callback] returns as well. + * + * @see catch + */ + internal suspend fun catchAsync(callback: suspend () -> T): T { + try { + return callback() + } catch (e: Exception) { + throw from(e) + } + } + + /** + * Catch any exception thrown in the [callback] block and rethrow it as a + * [FirebaseVertexAIException]. + * + * Will return whatever the [callback] returns as well. 
+ * + * @see catchAsync + */ + internal fun catch(callback: () -> T): T { + try { + return callback() + } catch (e: Exception) { + throw from(e) + } + } } } diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/BidiGenerateContentClientMessage.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/LiveClientSetupMessage.kt similarity index 71% rename from firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/BidiGenerateContentClientMessage.kt rename to firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/LiveClientSetupMessage.kt index 5488cb240f5..6b751961ed2 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/BidiGenerateContentClientMessage.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/LiveClientSetupMessage.kt @@ -19,19 +19,25 @@ package com.google.firebase.vertexai.type import kotlinx.serialization.ExperimentalSerializationApi import kotlinx.serialization.Serializable +/** + * First message in a live session. + * + * Contains configuration that will be used for the duration of the session. + */ @OptIn(ExperimentalSerializationApi::class) @PublicPreviewAPI -internal class BidiGenerateContentClientMessage( +internal class LiveClientSetupMessage( val model: String, + // Some config options are supported in generateContent but not in bidi and vice versa; so bidi + // needs its own config class val generationConfig: LiveGenerationConfig.Internal?, val tools: List?, val systemInstruction: Content.Internal? ) { - @Serializable - internal class Internal(val setup: BidiGenerateContentSetup) { + internal class Internal(val setup: LiveClientSetup) { @Serializable - internal data class BidiGenerateContentSetup( + internal data class LiveClientSetup( val model: String, val generationConfig: LiveGenerationConfig.Internal?, val tools: List?, @@ -40,5 +46,5 @@ internal class BidiGenerateContentClientMessage( } fun toInternal() = - Internal(Internal.BidiGenerateContentSetup(model, generationConfig, tools, systemInstruction)) + Internal(Internal.LiveClientSetup(model, generationConfig, tools, systemInstruction)) } diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/LiveGenerationConfig.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/LiveGenerationConfig.kt index 2f8f39a3afc..36879ff7cfd 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/LiveGenerationConfig.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/LiveGenerationConfig.kt @@ -20,7 +20,7 @@ import kotlinx.serialization.SerialName import kotlinx.serialization.Serializable /** - * Configuration parameters to use for content generation. + * Configuration parameters to use for live content generation. * * @property temperature A parameter controlling the degree of randomness in token selection. A * temperature of 0 means that the highest probability tokens are always selected.
In this case, diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/LiveSession.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/LiveSession.kt index 7e14a80dfcc..30bd92c6043 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/LiveSession.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/LiveSession.kt @@ -16,359 +16,505 @@ package com.google.firebase.vertexai.type +import android.Manifest.permission.RECORD_AUDIO import android.media.AudioFormat import android.media.AudioTrack import android.util.Log -import com.google.firebase.annotations.concurrent.Background -import com.google.firebase.vertexai.LiveGenerativeModel +import androidx.annotation.RequiresPermission +import com.google.firebase.annotations.concurrent.Blocking +import com.google.firebase.vertexai.common.JSON +import com.google.firebase.vertexai.common.util.CancelledCoroutineScope +import com.google.firebase.vertexai.common.util.accumulateUntil +import com.google.firebase.vertexai.common.util.childJob import io.ktor.client.plugins.websocket.ClientWebSocketSession import io.ktor.websocket.Frame import io.ktor.websocket.close import io.ktor.websocket.readBytes -import java.io.ByteArrayOutputStream import java.util.concurrent.ConcurrentLinkedQueue +import java.util.concurrent.atomic.AtomicBoolean import kotlin.coroutines.CoroutineContext import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.cancel -import kotlinx.coroutines.channels.Channel +import kotlinx.coroutines.channels.Channel.Factory.UNLIMITED import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.buffer +import kotlinx.coroutines.flow.catch import kotlinx.coroutines.flow.flow -import kotlinx.coroutines.flow.receiveAsFlow +import kotlinx.coroutines.flow.launchIn +import kotlinx.coroutines.flow.onCompletion +import kotlinx.coroutines.flow.onEach +import kotlinx.coroutines.flow.transform +import kotlinx.coroutines.isActive import kotlinx.coroutines.launch +import kotlinx.coroutines.yield import kotlinx.serialization.ExperimentalSerializationApi -import kotlinx.serialization.SerialName import kotlinx.serialization.Serializable import kotlinx.serialization.encodeToString import kotlinx.serialization.json.Json import kotlinx.serialization.json.JsonNull +import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.decodeFromJsonElement /** Represents a live WebSocket session capable of streaming content to and from the server. */ @PublicPreviewAPI @OptIn(ExperimentalSerializationApi::class) public class LiveSession internal constructor( - private val session: ClientWebSocketSession?, - @Background private val backgroundDispatcher: CoroutineContext, + private val session: ClientWebSocketSession, + @Blocking private val blockingDispatcher: CoroutineContext, private var audioHelper: AudioHelper? = null ) { + /** + * Coroutine scope that we batch data on for [startAudioConversation]. + * + * Makes it easy to stop all the work with [stopAudioConversation] by just cancelling the scope. + */ + private var scope = CancelledCoroutineScope - private val audioQueue = ConcurrentLinkedQueue() + /** + * Playback audio data sent from the model. + * + * Effectively, this is what the model is saying. 
+ */ private val playBackQueue = ConcurrentLinkedQueue() - private var startedReceiving = false - private var receiveChannel: Channel = Channel() - private var isRecording: Boolean = false - private companion object { - val TAG = LiveSession::class.java.simpleName - val MIN_BUFFER_SIZE = - AudioTrack.getMinBufferSize( - 24000, - AudioFormat.CHANNEL_OUT_MONO, - AudioFormat.ENCODING_PCM_16BIT - ) - } - - internal class ClientContentSetup(val turns: List, val turnComplete: Boolean) { - @Serializable - internal class Internal(@SerialName("client_content") val clientContent: ClientContent) { - @Serializable - internal data class ClientContent( - val turns: List, - @SerialName("turn_complete") val turnComplete: Boolean - ) - } - - fun toInternal() = Internal(Internal.ClientContent(turns, turnComplete)) - } + /** + * Toggled whenever [receive] and [stopReceiving] are called. + * + * Used to ensure only one flow is consuming the playback at once. + */ + private val startedReceiving = AtomicBoolean(false) - @OptIn(ExperimentalSerializationApi::class) - internal class ToolResponseSetup( - val functionResponses: List + /** + * Starts an audio conversation with the model, which can only be stopped using + * [stopAudioConversation] or [close]. + * + * @param functionCallHandler A callback function that is invoked whenever the model receives a + * function call. The [FunctionResponsePart] that the callback function returns will be + * automatically sent to the model. + */ + @RequiresPermission(RECORD_AUDIO) + public suspend fun startAudioConversation( + functionCallHandler: ((FunctionCallPart) -> FunctionResponsePart)? = null ) { + FirebaseVertexAIException.catchAsync { + if (scope.isActive) { + Log.w( + TAG, + "startAudioConversation called after the recording has already started. " + + "Call stopAudioConversation to close the previous connection." + ) + return@catchAsync + } - @Serializable - internal data class Internal(val toolResponse: ToolResponse) { - @Serializable - internal data class ToolResponse( - val functionResponses: List - ) - } + scope = CoroutineScope(blockingDispatcher + childJob()) + audioHelper = AudioHelper.build() - fun toInternal() = Internal(Internal.ToolResponse(functionResponses)) + recordUserAudio() + processModelResponses(functionCallHandler) + listenForModelPlayback() + } } - internal class ServerContentSetup(val modelTurn: Content.Internal) { - @Serializable - internal class Internal(@SerialName("serverContent") val serverContent: ServerContent) { - @Serializable - internal data class ServerContent(@SerialName("modelTurn") val modelTurn: Content.Internal) - } + /** + * Stops the audio conversation with the model. + * + * This only needs to be called after a previous call to [startAudioConversation]. + * + * If there is no audio conversation currently active, this function does nothing. 
+ */ + public fun stopAudioConversation() { + FirebaseVertexAIException.catch { + if (!startedReceiving.getAndSet(false)) return@catch - fun toInternal() = Internal(Internal.ServerContent(modelTurn)) - } + scope.cancel() + playBackQueue.clear() - internal class MediaStreamingSetup(val mediaChunks: List) { - @Serializable - internal class Internal(val realtimeInput: MediaChunks) { - @Serializable internal data class MediaChunks(val mediaChunks: List) + audioHelper?.release() + audioHelper = null } - fun toInternal() = Internal(Internal.MediaChunks(mediaChunks)) } - internal data class ToolCallSetup( - val functionCalls: List - ) { - - @Serializable - internal class Internal(val toolCall: ToolCall) { + /** + * Receives responses from the model for both streaming and standard requests. + * + * Call [close] to stop receiving responses from the model. + * + * @return A [Flow] which will emit [LiveContentResponse] from the model. + * + * @throws [SessionAlreadyReceivingException] when the session is already receiving. + * @see stopReceiving + */ + public fun receive(): Flow { + return FirebaseVertexAIException.catch { + if (startedReceiving.getAndSet(true)) { + throw SessionAlreadyReceivingException() + } - @Serializable - internal data class ToolCall(val functionCalls: List) - } + // TODO(b/410059569): Remove when fixed + flow { + while (true) { + val response = session.incoming.tryReceive() + if (response.isClosed || !startedReceiving.get()) break - fun toInternal(): Internal { - return Internal(Internal.ToolCall(functionCalls)) - } - } + val frame = response.getOrNull() + frame?.let { frameToLiveContentResponse(it) }?.let { emit(it) } - private fun fillRecordedAudioQueue() { - CoroutineScope(backgroundDispatcher).launch { - audioHelper!!.startRecording().collect { - if (!isRecording) { - cancel() + yield() + } } - audioQueue.add(it) - } + .onCompletion { stopAudioConversation() } + .catch { throw FirebaseVertexAIException.from(it) } + + // TODO(b/410059569): Add back when fixed + // return session.incoming.receiveAsFlow().transform { frame -> + // val response = frameToLiveContentResponse(frame) + // response?.let { emit(it) } + // }.onCompletion { + // stopAudioConversation() + // }.catch { throw FirebaseVertexAIException.from(it) } } } - private suspend fun sendAudioDataToServer() { + /** + * Stops receiving from the model. + * + * If this function is called during an ongoing audio conversation, the model's response will not + * be received, and no audio will be played; the live session object will no longer receive data + * from the server. + * + * To resume receiving data, you must either handle it directly using [receive], or indirectly by + * using [startAudioConversation]. + * + * @see close + */ + // TODO(b/410059569): Remove when fixed + public fun stopReceiving() { + FirebaseVertexAIException.catch { + if (!startedReceiving.getAndSet(false)) return@catch - val audioBufferStream = ByteArrayOutputStream() - while (isRecording) { - val receivedAudio = audioQueue.poll() ?: continue - audioBufferStream.write(receivedAudio) - if (audioBufferStream.size() >= MIN_BUFFER_SIZE) { - sendMediaStream(listOf(MediaData(audioBufferStream.toByteArray(), "audio/pcm"))) - audioBufferStream.reset() - } - } - } + scope.cancel() + playBackQueue.clear() - private fun fillServerResponseAudioQueue( - functionCallsHandler: ((FunctionCallPart) -> FunctionResponsePart)? 
= null - ) { - CoroutineScope(backgroundDispatcher).launch { - receive().collect { - if (!isRecording) { - cancel() - } - when (it.status) { - LiveContentResponse.Status.INTERRUPTED -> - while (!playBackQueue.isEmpty()) playBackQueue.poll() - LiveContentResponse.Status.NORMAL -> - if (!it.functionCalls.isNullOrEmpty() && functionCallsHandler != null) { - sendFunctionResponse(it.functionCalls.map(functionCallsHandler).toList()) - } else { - val audioData = it.data?.parts?.get(0)?.asInlineDataPartOrNull()?.inlineData - if (audioData != null) { - playBackQueue.add(audioData) - } - } - } - } + audioHelper?.release() + audioHelper = null } } - private fun playServerResponseAudio() { - CoroutineScope(backgroundDispatcher).launch { - while (isRecording) { - val data = playBackQueue.poll() - if (data == null) { - audioHelper?.start() - continue - } - audioHelper?.stopRecording() - audioHelper?.playAudio(data) - } + /** + * Sends function calling responses to the model. + * + * **NOTE:** If you're using [startAudioConversation], the method will handle sending function + * responses to the model for you. You do _not_ need to call this method in that case. + * + * @param functionList The list of [FunctionResponsePart] instances indicating the function + * response from the client. + */ + public suspend fun sendFunctionResponse(functionList: List) { + FirebaseVertexAIException.catchAsync { + val jsonString = + Json.encodeToString( + BidiGenerateContentToolResponseSetup(functionList.map { it.toInternalFunctionCall() }) + .toInternal() + ) + session.send(Frame.Text(jsonString)) } } /** - * Starts an audio conversation with the Gemini server, which can only be stopped using - * [stopAudioConversation]. + * Streams client data to the model. * - * @param functionCallHandler A callback function that is invoked whenever the server receives a - * function call. + * Calling this after [startAudioConversation] will play the response audio immediately. + * + * @param mediaChunks The list of [MediaData] instances representing the media data to be sent. */ - public suspend fun startAudioConversation( - functionCallHandler: ((FunctionCallPart) -> FunctionResponsePart)? = null + public suspend fun sendMediaStream( + mediaChunks: List, ) { - if (isRecording) { - Log.w(TAG, "startAudioConversation called after the recording has already started.") - return + FirebaseVertexAIException.catchAsync { + val jsonString = + Json.encodeToString( + BidiGenerateContentRealtimeInputSetup(mediaChunks.map { (it.toInternal()) }).toInternal() + ) + session.send(Frame.Text(jsonString)) } - isRecording = true - audioHelper = AudioHelper() - audioHelper!!.setupAudioTrack() - fillRecordedAudioQueue() - CoroutineScope(backgroundDispatcher).launch { sendAudioDataToServer() } - fillServerResponseAudioQueue(functionCallHandler) - playServerResponseAudio() } /** - * Stops the audio conversation with the Gemini Server. This needs to be called only after calling - * [startAudioConversation] + * Sends [data][Content] to the model. + * + * Calling this after [startAudioConversation] will play the response audio immediately. + * + * @param content Client [Content] to be sent to the model. 
*/ - public fun stopAudioConversation() { - stopReceiving() - isRecording = false - audioHelper?.let { - while (playBackQueue.isNotEmpty()) playBackQueue.poll() - while (audioQueue.isNotEmpty()) audioQueue.poll() - it.release() + public suspend fun send(content: Content) { + FirebaseVertexAIException.catchAsync { + val jsonString = + Json.encodeToString( + BidiGenerateContentClientContentSetup(listOf(content.toInternal()), true).toInternal() + ) + session.send(Frame.Text(jsonString)) } - audioHelper = null } /** - * Stops receiving from the model. + * Sends text to the model. * - * If this function is called during an ongoing audio conversation, the model's response will not - * be received, and no audio will be played; the live session object will no longer receive data - * from the server. + * Calling this after [startAudioConversation] will play the response audio immediately. * - * To resume receiving data, you must either handle it directly using [receive], or indirectly by - * using [startAudioConversation]. + * @param text Text to be sent to the model. */ - public fun stopReceiving() { - if (!startedReceiving) { - return - } - receiveChannel.cancel() - receiveChannel = Channel() - startedReceiving = false + public suspend fun send(text: String) { + FirebaseVertexAIException.catchAsync { send(Content.Builder().text(text).build()) } } /** - * Receives responses from the server for both streaming and standard requests. Call - * [stopReceiving] to stop receiving responses from the server. + * Closes the client session. * - * @return A [Flow] which will emit [LiveContentResponse] as and when it receives it + * Once a [LiveSession] is closed, it cannot be reopened; you'll need to start a new + * [LiveSession]. * - * @throws [SessionAlreadyReceivingException] when the session is already receiving. + * @see stopReceiving */ - public fun receive(): Flow { - if (startedReceiving) { - throw SessionAlreadyReceivingException() + public suspend fun close() { + FirebaseVertexAIException.catchAsync { + session.close() + stopAudioConversation() } + } - val flowReceive = session!!.incoming.receiveAsFlow() - CoroutineScope(backgroundDispatcher).launch { flowReceive.collect { receiveChannel.send(it) } } - return flow { - startedReceiving = true - while (true) { - val message = receiveChannel.receive() - val receivedBytes = (message as Frame.Binary).readBytes() - val receivedJson = receivedBytes.toString(Charsets.UTF_8) - if (receivedJson.contains("interrupted")) { - emit(LiveContentResponse(null, LiveContentResponse.Status.INTERRUPTED, null)) - continue - } - if (receivedJson.contains("turnComplete")) { - emit(LiveContentResponse(null, LiveContentResponse.Status.TURN_COMPLETE, null)) - continue + /** Listen to the user's microphone and send the data to the model. */ + private fun recordUserAudio() { + // Buffer the recording so we can keep recording while data is sent to the server + audioHelper + ?.listenToRecording() + ?.buffer(UNLIMITED) + ?.accumulateUntil(MIN_BUFFER_SIZE) + ?.onEach { sendMediaStream(listOf(MediaData(it, "audio/pcm"))) } + ?.catch { throw FirebaseVertexAIException.from(it) } + ?.launchIn(scope) + } + + /** + * Processes responses from the model during an audio conversation. + * + * Audio messages are added to [playBackQueue]. + * + * Launched asynchronously on [scope]. + * + * @param functionCallHandler A callback function that is invoked whenever the model receives a * function call.
+ */ + private fun processModelResponses( + functionCallHandler: ((FunctionCallPart) -> FunctionResponsePart)? + ) { + receive() + .transform { + if (it.status == LiveContentResponse.Status.INTERRUPTED) { + playBackQueue.clear() + } else { + emit(it) } - } - try { - val serverContent = Json.decodeFromString(receivedJson) - val data = serverContent.serverContent.modelTurn.toPublic() - if (data.parts[0].asInlineDataPartOrNull()?.mimeType?.equals("audio/pcm") == true) { - emit(LiveContentResponse(data, LiveContentResponse.Status.NORMAL, null)) - } - if (data.parts[0] is TextPart) { - emit(LiveContentResponse(data, LiveContentResponse.Status.NORMAL, null)) + } + .onEach { + if (!it.functionCalls.isNullOrEmpty()) { + if (functionCallHandler != null) { + // It's fine to suspend here since you can't have a function call running concurrently + // with an audio response + sendFunctionResponse(it.functionCalls.map(functionCallHandler).toList()) + } else { + Log.w( + TAG, + "Function calls were present in the response, but a functionCallHandler was not provided." ) } - continue - } catch (e: Exception) { - Log.i(TAG, "Failed to decode server content: ${e.message}") } - try { - val functionContent = Json.decodeFromString(receivedJson) - emit( - LiveContentResponse( - null, - LiveContentResponse.Status.NORMAL, - functionContent.toolCall.functionCalls.map { - FunctionCallPart(it.name, it.args.orEmpty().mapValues { x -> x.value ?: JsonNull }) - } - ) - ) - continue - } catch (e: Exception) { - Log.w(TAG, "Failed to decode function calling: ${e.message}") + + val audioParts = it.data?.parts?.filterIsInstance().orEmpty() + for (part in audioParts) { + playBackQueue.add(part.inlineData) + } + } + .launchIn(scope) + } + + /** + * Listens for playback data from the model and plays the audio. + * + * Polls [playBackQueue] for data, and calls [AudioHelper.playAudio] when data is received. + * + * Launched asynchronously on [scope]. + */ + private fun listenForModelPlayback() { + scope.launch { + while (isActive) { + val playbackData = playBackQueue.poll() + if (playbackData == null) { + // The model playback queue is complete, so we can continue recording + // TODO(b/408223520): Conditionally resume when param is added + audioHelper?.resumeRecording() + yield() + } else { + /** + * We pause the recording while the model is speaking, since there is no echo + * cancellation; otherwise, the microphone would pick up the model's audio and + * interrupt it + */ + // TODO(b/408223520): Conditionally pause when param is added + audioHelper?.pauseRecording() + + audioHelper?.playAudio(playbackData) } } } } /** - * Sends the function calling responses to the server. + * Converts a [Frame] from the model to a valid [LiveContentResponse], if possible. * - * @param functionList The list of [FunctionResponsePart] instances indicating the function - * response from the client. + * @return The corresponding [LiveContentResponse] or null if it couldn't be converted. */ - public suspend fun sendFunctionResponse(functionList: List) { - val jsonString = - Json.encodeToString( - ToolResponseSetup(functionList.map { it.toInternalFunctionCall() }).toInternal() - ) - session?.send(Frame.Text(jsonString)) + private fun frameToLiveContentResponse(frame: Frame): LiveContentResponse?
{ + val jsonMessage = Json.parseToJsonElement(frame.readBytes().toString(Charsets.UTF_8)) + + if (jsonMessage !is JsonObject) { + Log.w(TAG, "Server response was not a JsonObject: $jsonMessage") + return null + } + + return when { + "toolCall" in jsonMessage -> { + val functionContent = + JSON.decodeFromJsonElement(jsonMessage) + LiveContentResponse( + null, + LiveContentResponse.Status.NORMAL, + functionContent.toolCall.functionCalls.map { + FunctionCallPart(it.name, it.args.orEmpty().mapValues { x -> x.value ?: JsonNull }) + } + ) + } + "serverContent" in jsonMessage -> { + val serverContent = + JSON.decodeFromJsonElement(jsonMessage) + .serverContent + val status = + when { + serverContent.turnComplete == true -> LiveContentResponse.Status.TURN_COMPLETE + serverContent.interrupted == true -> LiveContentResponse.Status.INTERRUPTED + else -> LiveContentResponse.Status.NORMAL + } + LiveContentResponse(serverContent.modelTurn?.toPublic(), status, null) + } + else -> { + Log.w(TAG, "Failed to decode the server response: $jsonMessage") + null + } + } } /** - * Streams client data to the server. Calling this after [startAudioConversation] will play the - * response audio immediately. + * Incremental update of the current conversation delivered from the client. * - * @param mediaChunks The list of [MediaData] instances representing the media data to be sent. + * Effectively, a message from the client to the model. */ - public suspend fun sendMediaStream( - mediaChunks: List, + internal class BidiGenerateContentClientContentSetup( + val turns: List, + val turnComplete: Boolean ) { - val jsonString = - Json.encodeToString(MediaStreamingSetup(mediaChunks.map { it.toInternal() }).toInternal()) - session?.send(Frame.Text(jsonString)) + @Serializable + internal class Internal(val clientContent: BidiGenerateContentClientContent) { + @Serializable + internal data class BidiGenerateContentClientContent( + val turns: List, + val turnComplete: Boolean + ) + } + + fun toInternal() = Internal(Internal.BidiGenerateContentClientContent(turns, turnComplete)) } /** - * Sends data to the server. Calling this after [startAudioConversation] will play the response - * audio immediately. + * Incremental server update generated by the model in response to client messages. * - * @param content Client [Content] to be sent to the server. + * Effectively, a message from the model to the client. */ - public suspend fun send(content: Content) { - val jsonString = - Json.encodeToString(ClientContentSetup(listOf(content.toInternal()), true).toInternal()) - session?.send(Frame.Text(jsonString)) + internal class BidiGenerateContentServerContentSetup( + val modelTurn: Content.Internal?, + val turnComplete: Boolean?, + val interrupted: Boolean? + ) { + @Serializable + internal class Internal(val serverContent: BidiGenerateContentServerContent) { + @Serializable + internal data class BidiGenerateContentServerContent( + val modelTurn: Content.Internal?, + val turnComplete: Boolean?, + val interrupted: Boolean? + ) + } + + fun toInternal() = + Internal(Internal.BidiGenerateContentServerContent(modelTurn, turnComplete, interrupted)) } /** - * Sends text to the server. Calling this after [startAudioConversation] will play the response - * audio immediately. - * - * @param text Text to be sent to the server. + * Request for the client to execute the provided function calls and return the responses with the + * matched `id`s. 
*/ - public suspend fun send(text: String) { - send(Content.Builder().text(text).build()) + internal data class BidiGenerateContentToolCallSetup( + val functionCalls: List + ) { + @Serializable + internal class Internal(val toolCall: BidiGenerateContentToolCall) { + @Serializable + internal data class BidiGenerateContentToolCall( + val functionCalls: List + ) + } + + fun toInternal(): Internal { + return Internal(Internal.BidiGenerateContentToolCall(functionCalls)) + } + } + + /** Client-generated responses to a [BidiGenerateContentToolCallSetup]. */ + internal class BidiGenerateContentToolResponseSetup( + val functionResponses: List + ) { + @Serializable + internal data class Internal(val toolResponse: BidiGenerateContentToolResponse) { + @Serializable + internal data class BidiGenerateContentToolResponse( + val functionResponses: List + ) + } + + fun toInternal() = Internal(Internal.BidiGenerateContentToolResponse(functionResponses)) } /** - * Closes the client session. + * User input that is sent to the model in real time. * - * After this is called, the session object becomes unusable. To interact with the server again, - * you must create a new session using [LiveGenerativeModel]. + * End of turn is derived from user activity (e.g., end of speech). */ - public suspend fun close() { - session?.close() + internal class BidiGenerateContentRealtimeInputSetup(val mediaChunks: List) { + @Serializable + internal class Internal(val realtimeInput: BidiGenerateContentRealtimeInput) { + @Serializable + internal data class BidiGenerateContentRealtimeInput( + val mediaChunks: List + ) + } + fun toInternal() = Internal(Internal.BidiGenerateContentRealtimeInput(mediaChunks)) + } + + private companion object { + val TAG = LiveSession::class.java.simpleName + val MIN_BUFFER_SIZE = + AudioTrack.getMinBufferSize( + 24000, + AudioFormat.CHANNEL_OUT_MONO, + AudioFormat.ENCODING_PCM_16BIT + ) } } diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Part.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Part.kt index 21d3c0edc6c..b2538a8d6a0 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Part.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Part.kt @@ -71,8 +71,11 @@ public class InlineDataPart(public val inlineData: ByteArray, public val mimeTyp * @param name the name of the function to call * @param args the function parameters and values as a [Map] */ -public class FunctionCallPart(public val name: String, public val args: Map) : - Part { +// TODO(b/410040441): Support id property +public class FunctionCallPart( + public val name: String, + public val args: Map, +) : Part { @Serializable internal data class Internal(val functionCall: FunctionCall) : InternalPart { @@ -88,6 +91,7 @@ public class FunctionCallPart(public val name: String, public val args: Map Unit * @param block The test contents themselves, with the [CommonTestScope] implicitly provided * @see CommonTestScope */ +@OptIn(PublicPreviewAPI::class) internal fun commonTest( status: HttpStatusCode = HttpStatusCode.OK, requestOptions: RequestOptions = RequestOptions(),
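For reviewers, a minimal end-to-end sketch of the reworked live-audio path exercised by this change. The `liveModel` factory, the model name, and the empty function-call response below are illustrative assumptions, not part of this diff:

```kotlin
import com.google.firebase.Firebase
import com.google.firebase.vertexai.type.FunctionResponsePart
import com.google.firebase.vertexai.type.PublicPreviewAPI
import com.google.firebase.vertexai.vertexAI
import kotlinx.serialization.json.JsonObject

@OptIn(PublicPreviewAPI::class)
suspend fun demoAudioConversation() {
  // Assumed factory for LiveGenerativeModel; the model name is a placeholder.
  val model = Firebase.vertexAI.liveModel("my-live-model")

  // connect() now parses the handshake JSON and checks for a "setupComplete" key
  // instead of substring matching; a failure surfaces as
  // ServiceConnectionHandshakeFailedException.
  val session = model.connect()

  // Requires android.permission.RECORD_AUDIO, now enforced via @RequiresPermission.
  // The FunctionResponsePart returned by the handler is sent back to the model
  // automatically.
  session.startAudioConversation { functionCall ->
    FunctionResponsePart(functionCall.name, JsonObject(emptyMap()))
  }

  // ...later: cancels the recording scope, clears queued playback, and releases
  // the AudioRecord/AudioTrack pair via AudioHelper.release().
  session.stopAudioConversation()
  session.close()
}
```

Any exception thrown along this path is rethrown as a `FirebaseVertexAIException` via the new `catch`/`catchAsync` helpers added in Exceptions.kt.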