diff --git a/firebase-ai/CHANGELOG.md b/firebase-ai/CHANGELOG.md
index 053febea25a..abf0bf55c68 100644
--- a/firebase-ai/CHANGELOG.md
+++ b/firebase-ai/CHANGELOG.md
@@ -1,5 +1,7 @@
 # Unreleased
 
+- [changed] Improved audio scheduling and increased playback volume for the Live API.
+- [changed] Added support for input and output transcription. (#7482)
 - [feature] Added support for sending realtime audio and video in a `LiveSession`.
 - [changed] Removed redundant internal exception types. (#7475)
 
diff --git a/firebase-ai/src/main/kotlin/com/google/firebase/ai/common/util/android.kt b/firebase-ai/src/main/kotlin/com/google/firebase/ai/common/util/android.kt
index 6179c8b52e9..9f1bbd37260 100644
--- a/firebase-ai/src/main/kotlin/com/google/firebase/ai/common/util/android.kt
+++ b/firebase-ai/src/main/kotlin/com/google/firebase/ai/common/util/android.kt
@@ -17,10 +17,8 @@
 package com.google.firebase.ai.common.util
 
 import android.media.AudioRecord
-import kotlin.time.Duration.Companion.milliseconds
 import kotlinx.coroutines.delay
 import kotlinx.coroutines.flow.flow
-import kotlinx.coroutines.yield
 
 /**
  * The minimum buffer size for this instance.
@@ -40,15 +38,17 @@ internal fun AudioRecord.readAsFlow() = flow {
 
   while (true) {
     if (recordingState != AudioRecord.RECORDSTATE_RECORDING) {
-      // TODO(vguthal): Investigate if both yield and delay are required.
-      delay(10.milliseconds)
-      yield()
+      // delay uses a different scheduler in the backend, so it's "stickier" in its enforcement when
+      // compared to yield.
+      delay(0)
       continue
     }
     val bytesRead = read(buffer, 0, buffer.size)
     if (bytesRead > 0) {
       emit(buffer.copyOf(bytesRead))
     }
-    yield()
+    // delay uses a different scheduler in the backend, so it's "stickier" in its enforcement when
+    // compared to yield.
+    delay(0)
   }
 }
diff --git a/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/AudioHelper.kt b/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/AudioHelper.kt
index 08e90fc8538..06b4a3efe25 100644
--- a/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/AudioHelper.kt
+++ b/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/AudioHelper.kt
@@ -162,7 +162,10 @@
     fun build(): AudioHelper {
       val playbackTrack =
         AudioTrack(
-          AudioAttributes.Builder().setUsage(AudioAttributes.USAGE_VOICE_COMMUNICATION).build(),
+          AudioAttributes.Builder()
+            .setUsage(AudioAttributes.USAGE_MEDIA)
+            .setContentType(AudioAttributes.CONTENT_TYPE_SPEECH)
+            .build(),
           AudioFormat.Builder()
             .setSampleRate(24000)
             .setChannelMask(AudioFormat.CHANNEL_OUT_MONO)
diff --git a/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/LiveSession.kt b/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/LiveSession.kt
index 9e8b7d7f683..37d6f5011cb 100644
--- a/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/LiveSession.kt
+++ b/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/LiveSession.kt
@@ -17,12 +17,17 @@
 package com.google.firebase.ai.type
 
 import android.Manifest.permission.RECORD_AUDIO
+import android.annotation.SuppressLint
 import android.content.pm.PackageManager
 import android.media.AudioFormat
 import android.media.AudioTrack
+import android.os.Process
+import android.os.StrictMode
+import android.os.StrictMode.ThreadPolicy
 import android.util.Log
 import androidx.annotation.RequiresPermission
 import androidx.core.content.ContextCompat
+import com.google.firebase.BuildConfig
 import com.google.firebase.FirebaseApp
 import com.google.firebase.ai.common.JSON
 import com.google.firebase.ai.common.util.CancelledCoroutineScope
@@ -34,21 +39,27 @@ import io.ktor.websocket.Frame
 import io.ktor.websocket.close
 import io.ktor.websocket.readBytes
 import java.util.concurrent.ConcurrentLinkedQueue
+import java.util.concurrent.Executors
+import java.util.concurrent.ThreadFactory
 import java.util.concurrent.atomic.AtomicBoolean
+import java.util.concurrent.atomic.AtomicLong
 import kotlin.coroutines.CoroutineContext
+import kotlinx.coroutines.CoroutineName
 import kotlinx.coroutines.CoroutineScope
+import kotlinx.coroutines.asCoroutineDispatcher
 import kotlinx.coroutines.cancel
 import kotlinx.coroutines.channels.Channel.Factory.UNLIMITED
+import kotlinx.coroutines.delay
 import kotlinx.coroutines.flow.Flow
 import kotlinx.coroutines.flow.buffer
 import kotlinx.coroutines.flow.catch
 import kotlinx.coroutines.flow.flow
+import kotlinx.coroutines.flow.flowOn
 import kotlinx.coroutines.flow.launchIn
 import kotlinx.coroutines.flow.onCompletion
 import kotlinx.coroutines.flow.onEach
 import kotlinx.coroutines.isActive
 import kotlinx.coroutines.launch
-import kotlinx.coroutines.yield
 import kotlinx.serialization.ExperimentalSerializationApi
 import kotlinx.serialization.Serializable
 import kotlinx.serialization.encodeToString
@@ -65,11 +76,21 @@ internal constructor(
   private val firebaseApp: FirebaseApp,
 ) {
   /**
-   * Coroutine scope that we batch data on for [startAudioConversation].
+   * Coroutine scope that we batch data on for network related behavior.
    *
    * Makes it easy to stop all the work with [stopAudioConversation] by just cancelling the scope.
    */
-  private var scope = CancelledCoroutineScope
+  private var networkScope = CancelledCoroutineScope
+
+  /**
+   * Coroutine scope that we batch data on for audio recording and playback.
+   *
+   * Separate from [networkScope] so that switching between dispatchers doesn't cause deadlocks or
+   * other threading issues.
+   *
+   * Makes it easy to stop all the work with [stopAudioConversation] by just cancelling the scope.
+   */
+  private var audioScope = CancelledCoroutineScope
 
   /**
    * Playback audio data sent from the model.
@@ -159,7 +180,7 @@
     }
 
     FirebaseAIException.catchAsync {
-      if (scope.isActive) {
+      if (networkScope.isActive || audioScope.isActive) {
         Log.w(
           TAG,
           "startAudioConversation called after the recording has already started. " +
@@ -167,8 +188,9 @@
         )
         return@catchAsync
       }
-
-      scope = CoroutineScope(blockingDispatcher + childJob())
+      networkScope =
+        CoroutineScope(blockingDispatcher + childJob() + CoroutineName("LiveSession Network"))
+      audioScope = CoroutineScope(audioDispatcher + childJob() + CoroutineName("LiveSession Audio"))
 
       audioHelper = AudioHelper.build()
       recordUserAudio()
@@ -188,7 +210,8 @@
     FirebaseAIException.catch {
       if (!startedReceiving.getAndSet(false)) return@catch
 
-      scope.cancel()
+      networkScope.cancel()
+      audioScope.cancel()
       playBackQueue.clear()
 
       audioHelper?.release()
@@ -231,7 +254,9 @@
             )
           }
           ?.let { emit(it.toPublic()) }
-        yield()
+        // delay uses a different scheduler in the backend, so it's "stickier" in its
+        // enforcement when compared to yield.
+        delay(0)
       }
     }
     .onCompletion { stopAudioConversation() }
@@ -258,7 +283,8 @@
     FirebaseAIException.catch {
       if (!startedReceiving.getAndSet(false)) return@catch
 
-      scope.cancel()
+      networkScope.cancel()
+      audioScope.cancel()
       playBackQueue.clear()
 
       audioHelper?.release()
@@ -403,10 +429,16 @@
     audioHelper
       ?.listenToRecording()
       ?.buffer(UNLIMITED)
+      ?.flowOn(audioDispatcher)
       ?.accumulateUntil(MIN_BUFFER_SIZE)
-      ?.onEach { sendAudioRealtime(InlineData(it, "audio/pcm")) }
+      ?.onEach {
+        sendAudioRealtime(InlineData(it, "audio/pcm"))
+        // delay uses a different scheduler in the backend, so it's "stickier" in its enforcement
+        // when compared to yield.
+        delay(0)
+      }
       ?.catch { throw FirebaseAIException.from(it) }
-      ?.launchIn(scope)
+      ?.launchIn(networkScope)
   }
 
   /**
@@ -414,7 +446,7 @@
    *
    * Audio messages are added to [playBackQueue].
    *
-   * Launched asynchronously on [scope].
+   * Launched asynchronously on [networkScope].
    *
    * @param functionCallHandler A callback function that is invoked whenever the server receives a
    * function call.
@@ -471,7 +503,7 @@
           }
         }
       }
-      .launchIn(scope)
+      .launchIn(networkScope)
   }
 
   /**
@@ -479,10 +511,10 @@
    *
    * Polls [playBackQueue] for data, and calls [AudioHelper.playAudio] when data is received.
    *
-   * Launched asynchronously on [scope].
+   * Launched asynchronously on [audioScope].
    */
   private fun listenForModelPlayback(enableInterruptions: Boolean = false) {
-    scope.launch {
+    audioScope.launch {
       while (isActive) {
         val playbackData = playBackQueue.poll()
         if (playbackData == null) {
@@ -491,14 +523,16 @@
           if (!enableInterruptions) {
             audioHelper?.resumeRecording()
           }
-          yield()
+          // delay uses a different scheduler in the backend, so it's "stickier" in its enforcement
+          // when compared to yield.
+          delay(0)
         } else {
           /**
            * We pause the recording while the model is speaking to avoid interrupting it because of
            * no echo cancellation
            */
           // TODO(b/408223520): Conditionally pause when param is added
-          if (enableInterruptions != true) {
+          if (!enableInterruptions) {
             audioHelper?.pauseRecording()
           }
           audioHelper?.playAudio(playbackData)
@@ -583,5 +617,38 @@
         AudioFormat.CHANNEL_OUT_MONO,
         AudioFormat.ENCODING_PCM_16BIT
       )
+    @SuppressLint("ThreadPoolCreation")
+    val audioDispatcher =
+      Executors.newCachedThreadPool(AudioThreadFactory()).asCoroutineDispatcher()
+  }
+}
+
+internal class AudioThreadFactory : ThreadFactory {
+  private val threadCount = AtomicLong()
+  private val policy: ThreadPolicy = audioPolicy()
+
+  override fun newThread(task: Runnable?): Thread? {
+    val thread =
+      DEFAULT.newThread {
+        Process.setThreadPriority(Process.THREAD_PRIORITY_AUDIO)
+        StrictMode.setThreadPolicy(policy)
+        task?.run()
+      }
+    thread.name = "Firebase Audio Thread #${threadCount.andIncrement}"
+    return thread
+  }
+
+  companion object {
+    val DEFAULT: ThreadFactory = Executors.defaultThreadFactory()
+
+    private fun audioPolicy(): ThreadPolicy {
+      val builder = ThreadPolicy.Builder().detectNetwork()
+
+      if (BuildConfig.DEBUG) {
+        builder.penaltyDeath()
+      }
+
+      return builder.penaltyLog().build()
+    }
   }
 }
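
A note on the recurring `yield()` → `delay(0)` swap (not part of the diff): the loops in `readAsFlow` and `listenForModelPlayback` are polling loops, and the call at the end of each empty iteration is what lets the rest of the dispatcher make progress. The sketch below is a minimal illustration of that pattern; `poll` and `play` are hypothetical stand-ins for `playBackQueue.poll()` and `AudioHelper.playAudio`, and the comment restates the PR's own rationale rather than a documented kotlinx.coroutines guarantee.

```kotlin
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.delay
import kotlinx.coroutines.isActive
import kotlinx.coroutines.launch

// Minimal sketch of the playback polling pattern used in listenForModelPlayback.
// `isActive` observes cancellation of the enclosing scope; the call on the empty
// branch gives other coroutines on the dispatcher a chance to run. The PR prefers
// delay(0) over yield() on the theory that delay goes through a different,
// "stickier" scheduler in the coroutines backend.
fun CoroutineScope.pollLoop(poll: () -> ByteArray?, play: (ByteArray) -> Unit) = launch {
  while (isActive) {
    val data = poll() // hypothetical stand-in for playBackQueue.poll()
    if (data == null) {
      delay(0) // the swap made throughout this diff; yield() was used before
      continue
    }
    play(data) // hypothetical stand-in for audioHelper?.playAudio(data)
  }
}
```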
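
On the `AudioHelper` change: `USAGE_VOICE_COMMUNICATION` routes playback through the voice-call volume stream, which is typically quieter than the media stream and may trigger call-oriented processing, while `USAGE_MEDIA` routes through the media (`STREAM_MUSIC`) volume curve; `CONTENT_TYPE_SPEECH` still tells the platform the payload is speech. A minimal standalone sketch of the new configuration, using the same 24 kHz mono PCM format as the diff (`buildPlaybackTrack` is a hypothetical helper, not SDK API):

```kotlin
import android.media.AudioAttributes
import android.media.AudioFormat
import android.media.AudioTrack

// Builds a 24 kHz mono PCM track that plays at media volume rather than the
// quieter voice-call volume used by USAGE_VOICE_COMMUNICATION.
fun buildPlaybackTrack(): AudioTrack {
  val attributes =
    AudioAttributes.Builder()
      .setUsage(AudioAttributes.USAGE_MEDIA) // media volume curve, i.e. louder playback
      .setContentType(AudioAttributes.CONTENT_TYPE_SPEECH) // the payload is still speech
      .build()
  val format =
    AudioFormat.Builder()
      .setSampleRate(24000)
      .setChannelMask(AudioFormat.CHANNEL_OUT_MONO)
      .setEncoding(AudioFormat.ENCODING_PCM_16BIT)
      .build()
  val minBufferSize =
    AudioTrack.getMinBufferSize(
      24000,
      AudioFormat.CHANNEL_OUT_MONO,
      AudioFormat.ENCODING_PCM_16BIT
    )
  return AudioTrack.Builder()
    .setAudioAttributes(attributes)
    .setAudioFormat(format)
    .setBufferSizeInBytes(minBufferSize * 2) // headroom above the platform minimum
    .setTransferMode(AudioTrack.MODE_STREAM)
    .build()
}
```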
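
The split of the single `scope` into `networkScope` and `audioScope` means start and stop must treat the pair as a unit: `startAudioConversation` bails out if either scope is still active, and both teardown paths cancel both scopes. A condensed sketch of that lifecycle, with `Job()` standing in for the SDK's `childJob()` and a pre-cancelled scope standing in for `CancelledCoroutineScope`:

```kotlin
import kotlin.coroutines.CoroutineContext
import kotlinx.coroutines.CoroutineName
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Job
import kotlinx.coroutines.cancel
import kotlinx.coroutines.isActive

// Condensed sketch of the two-scope lifecycle in LiveSession: network I/O and
// audio record/playback run in sibling scopes on separate dispatchers, a session
// counts as running if either scope is active, and stopping cancels both.
class ConversationScopes(
  private val networkDispatcher: CoroutineContext,
  private val audioDispatcher: CoroutineContext,
) {
  // Starts out cancelled, standing in for the SDK's CancelledCoroutineScope.
  private fun cancelledScope() = CoroutineScope(Job().apply { cancel() })

  var networkScope: CoroutineScope = cancelledScope()
    private set
  var audioScope: CoroutineScope = cancelledScope()
    private set

  fun start(): Boolean {
    // Mirrors the guard in startAudioConversation: refuse to restart while running.
    if (networkScope.isActive || audioScope.isActive) return false
    networkScope = CoroutineScope(networkDispatcher + Job() + CoroutineName("LiveSession Network"))
    audioScope = CoroutineScope(audioDispatcher + Job() + CoroutineName("LiveSession Audio"))
    return true
  }

  fun stop() {
    // Both teardown paths in the diff cancel both scopes before releasing audio.
    networkScope.cancel()
    audioScope.cancel()
  }
}
```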
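
Finally, `AudioThreadFactory` decorates the default factory so that every pool thread raises itself to `THREAD_PRIORITY_AUDIO` and installs a `StrictMode` policy that flags network calls on audio threads (fatal in debug builds via `penaltyDeath()`, logged otherwise). A hedged usage sketch of the resulting dispatcher, assuming the factory from this diff is on the classpath and the code runs on a device (the thread-priority call is Android-only):

```kotlin
import java.util.concurrent.Executors
import kotlinx.coroutines.asCoroutineDispatcher
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.withContext

// Every coroutine resumed on this dispatcher runs on a thread that has already
// raised its priority to THREAD_PRIORITY_AUDIO and installed the StrictMode
// policy, so audio work is scheduled ahead of ordinary background threads and
// accidental network I/O on those threads is caught early.
val audioDispatcher = Executors.newCachedThreadPool(AudioThreadFactory()).asCoroutineDispatcher()

fun main() = runBlocking {
  withContext(audioDispatcher) {
    println(Thread.currentThread().name) // e.g. "Firebase Audio Thread #0"
  }
}
```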