diff --git a/firebase-ai/src/main/kotlin/com/google/firebase/ai/common/util/android.kt b/firebase-ai/src/main/kotlin/com/google/firebase/ai/common/util/android.kt index c020e94f415..bb1d28e9746 100644 --- a/firebase-ai/src/main/kotlin/com/google/firebase/ai/common/util/android.kt +++ b/firebase-ai/src/main/kotlin/com/google/firebase/ai/common/util/android.kt @@ -17,12 +17,8 @@ package com.google.firebase.ai.common.util import android.media.AudioRecord -import kotlin.time.Duration.Companion.milliseconds import kotlinx.coroutines.delay -import kotlinx.coroutines.flow.callbackFlow import kotlinx.coroutines.flow.flow -import kotlinx.coroutines.isActive -import kotlinx.coroutines.yield /** * The minimum buffer size for this instance. diff --git a/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/AudioHelper.kt b/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/AudioHelper.kt index edeb7c332f7..06b4a3efe25 100644 --- a/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/AudioHelper.kt +++ b/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/AudioHelper.kt @@ -162,7 +162,10 @@ internal class AudioHelper( fun build(): AudioHelper { val playbackTrack = AudioTrack( - AudioAttributes.Builder().setUsage(AudioAttributes.USAGE_MEDIA).setContentType(AudioAttributes.CONTENT_TYPE_SPEECH).build(), + AudioAttributes.Builder() + .setUsage(AudioAttributes.USAGE_MEDIA) + .setContentType(AudioAttributes.CONTENT_TYPE_SPEECH) + .build(), AudioFormat.Builder() .setSampleRate(24000) .setChannelMask(AudioFormat.CHANNEL_OUT_MONO) diff --git a/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/LiveSession.kt b/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/LiveSession.kt index a5b169d12aa..6cc4ef2c4b4 100644 --- a/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/LiveSession.kt +++ b/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/LiveSession.kt @@ -17,12 +17,17 @@ package com.google.firebase.ai.type import android.Manifest.permission.RECORD_AUDIO +import android.annotation.SuppressLint import android.content.pm.PackageManager import android.media.AudioFormat import android.media.AudioTrack +import android.os.Process +import android.os.StrictMode +import android.os.StrictMode.ThreadPolicy import android.util.Log import androidx.annotation.RequiresPermission import androidx.core.content.ContextCompat +import com.google.firebase.BuildConfig import com.google.firebase.FirebaseApp import com.google.firebase.ai.common.JSON import com.google.firebase.ai.common.util.CancelledCoroutineScope @@ -33,12 +38,15 @@ import io.ktor.client.plugins.websocket.DefaultClientWebSocketSession import io.ktor.websocket.Frame import io.ktor.websocket.close import io.ktor.websocket.readBytes -import kotlinx.coroutines.CoroutineName import java.util.concurrent.ConcurrentLinkedQueue +import java.util.concurrent.Executors +import java.util.concurrent.ThreadFactory import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.atomic.AtomicLong import kotlin.coroutines.CoroutineContext +import kotlinx.coroutines.CoroutineName import kotlinx.coroutines.CoroutineScope -import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.asCoroutineDispatcher import kotlinx.coroutines.cancel import kotlinx.coroutines.channels.Channel.Factory.UNLIMITED import kotlinx.coroutines.delay @@ -46,6 +54,7 @@ import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.buffer import kotlinx.coroutines.flow.catch import kotlinx.coroutines.flow.flow +import kotlinx.coroutines.flow.flowOn import kotlinx.coroutines.flow.launchIn import kotlinx.coroutines.flow.onCompletion import kotlinx.coroutines.flow.onEach @@ -67,11 +76,21 @@ internal constructor( private val firebaseApp: FirebaseApp, ) { /** - * Coroutine scope that we batch data on for [startAudioConversation]. + * Coroutine scope that we batch data on for network related behavior. + * + * Makes it easy to stop all the work with [stopAudioConversation] by just cancelling the scope. + */ + private var networkScope = CancelledCoroutineScope + + /** + * Coroutine scope that we batch data on for audio recording and playback. + * + * Separate from [networkScope] to ensure interchanging of dispatchers doesn't cause any deadlocks + * or issues. * * Makes it easy to stop all the work with [stopAudioConversation] by just cancelling the scope. */ - private var scope = CancelledCoroutineScope + private var audioScope = CancelledCoroutineScope /** * Playback audio data sent from the model. @@ -129,7 +148,7 @@ internal constructor( } FirebaseAIException.catchAsync { - if (scope.isActive) { + if (networkScope.isActive || audioScope.isActive) { Log.w( TAG, "startAudioConversation called after the recording has already started. " + @@ -137,8 +156,9 @@ internal constructor( ) return@catchAsync } - // TODO: maybe it should be THREAD_PRIORITY_AUDIO anyways for playback and recording (not network though) - scope = CoroutineScope(blockingDispatcher + childJob() + CoroutineName("LiveSession Scope")) + networkScope = + CoroutineScope(blockingDispatcher + childJob() + CoroutineName("LiveSession Network")) + audioScope = CoroutineScope(audioDispatcher + childJob() + CoroutineName("LiveSession Audio")) audioHelper = AudioHelper.build() recordUserAudio() @@ -158,7 +178,8 @@ internal constructor( FirebaseAIException.catch { if (!startedReceiving.getAndSet(false)) return@catch - scope.cancel() + networkScope.cancel() + audioScope.cancel() playBackQueue.clear() audioHelper?.release() @@ -228,7 +249,8 @@ internal constructor( FirebaseAIException.catch { if (!startedReceiving.getAndSet(false)) return@catch - scope.cancel() + networkScope.cancel() + audioScope.cancel() playBackQueue.clear() audioHelper?.release() @@ -325,13 +347,14 @@ internal constructor( audioHelper ?.listenToRecording() ?.buffer(UNLIMITED) + ?.flowOn(audioDispatcher) ?.accumulateUntil(MIN_BUFFER_SIZE) ?.onEach { sendMediaStream(listOf(MediaData(it, "audio/pcm"))) delay(0) } ?.catch { throw FirebaseAIException.from(it) } - ?.launchIn(scope) + ?.launchIn(networkScope) } /** @@ -339,7 +362,7 @@ internal constructor( * * Audio messages are added to [playBackQueue]. * - * Launched asynchronously on [scope]. + * Launched asynchronously on [networkScope]. * * @param functionCallHandler A callback function that is invoked whenever the server receives a * function call. @@ -393,7 +416,7 @@ internal constructor( } } } - .launchIn(scope) + .launchIn(networkScope) } /** @@ -401,10 +424,10 @@ internal constructor( * * Polls [playBackQueue] for data, and calls [AudioHelper.playAudio] when data is received. * - * Launched asynchronously on [scope]. + * Launched asynchronously on [networkScope]. */ private fun listenForModelPlayback(enableInterruptions: Boolean = false) { - scope.launch { + audioScope.launch { while (isActive) { val playbackData = playBackQueue.poll() if (playbackData == null) { @@ -490,5 +513,38 @@ internal constructor( AudioFormat.CHANNEL_OUT_MONO, AudioFormat.ENCODING_PCM_16BIT ) + @SuppressLint("ThreadPoolCreation") + val audioDispatcher = + Executors.newCachedThreadPool(AudioThreadFactory()).asCoroutineDispatcher() + } +} + +internal class AudioThreadFactory : ThreadFactory { + private val threadCount = AtomicLong() + private val policy: ThreadPolicy = audioPolicy() + + override fun newThread(task: Runnable?): Thread? { + val thread = + DEFAULT.newThread { + Process.setThreadPriority(Process.THREAD_PRIORITY_AUDIO) + StrictMode.setThreadPolicy(policy) + task?.run() + } + thread.name = "Firebase Audio Thread #${threadCount.andIncrement}" + return thread + } + + companion object { + val DEFAULT: ThreadFactory = Executors.defaultThreadFactory() + + private fun audioPolicy(): ThreadPolicy { + val builder = ThreadPolicy.Builder().detectNetwork() + + if (BuildConfig.DEBUG) { + builder.penaltyDeath() + } + + return builder.penaltyLog().build() + } } }