 package com.google.firebase.ai.type
 
 import android.Manifest.permission.RECORD_AUDIO
+import android.annotation.SuppressLint
 import android.content.pm.PackageManager
 import android.media.AudioFormat
 import android.media.AudioTrack
+import android.os.Process
+import android.os.StrictMode
+import android.os.StrictMode.ThreadPolicy
 import android.util.Log
 import androidx.annotation.RequiresPermission
 import androidx.core.content.ContextCompat
+import com.google.firebase.BuildConfig
 import com.google.firebase.FirebaseApp
 import com.google.firebase.ai.common.JSON
 import com.google.firebase.ai.common.util.CancelledCoroutineScope
@@ -38,14 +43,15 @@ import java.util.concurrent.ConcurrentLinkedQueue
 import java.util.concurrent.atomic.AtomicBoolean
 import kotlin.coroutines.CoroutineContext
 import kotlinx.coroutines.CoroutineScope
-import kotlinx.coroutines.Dispatchers
+import kotlinx.coroutines.asCoroutineDispatcher
 import kotlinx.coroutines.cancel
 import kotlinx.coroutines.channels.Channel.Factory.UNLIMITED
 import kotlinx.coroutines.delay
 import kotlinx.coroutines.flow.Flow
 import kotlinx.coroutines.flow.buffer
 import kotlinx.coroutines.flow.catch
 import kotlinx.coroutines.flow.flow
+import kotlinx.coroutines.flow.flowOn
 import kotlinx.coroutines.flow.launchIn
 import kotlinx.coroutines.flow.onCompletion
 import kotlinx.coroutines.flow.onEach
@@ -55,6 +61,9 @@ import kotlinx.serialization.ExperimentalSerializationApi
 import kotlinx.serialization.Serializable
 import kotlinx.serialization.encodeToString
 import kotlinx.serialization.json.Json
+import java.util.concurrent.Executors
+import java.util.concurrent.ThreadFactory
+import java.util.concurrent.atomic.AtomicLong
 
 /** Represents a live WebSocket session capable of streaming content to and from the server. */
 @PublicPreviewAPI
@@ -67,11 +76,21 @@ internal constructor(
   private val firebaseApp: FirebaseApp,
 ) {
   /**
-   * Coroutine scope that we batch data on for [startAudioConversation].
+   * Coroutine scope that we batch data on for network-related behavior.
    *
    * Makes it easy to stop all the work with [stopAudioConversation] by just cancelling the scope.
    */
-  private var scope = CancelledCoroutineScope
+  private var networkScope = CancelledCoroutineScope
+
+  /**
+   * Coroutine scope that we batch data on for audio recording and playback.
+   *
+   * Kept separate from [networkScope] so that switching between dispatchers can't cause deadlocks
+   * or other issues.
+   *
+   * Makes it easy to stop all the work with [stopAudioConversation] by just cancelling the scope.
+   */
+  private var audioScope = CancelledCoroutineScope
 
   /**
    * Playback audio data sent from the model.
@@ -129,16 +148,16 @@ internal constructor(
     }
 
     FirebaseAIException.catchAsync {
-      if (scope.isActive) {
+      if (networkScope.isActive) {
         Log.w(
           TAG,
           "startAudioConversation called after the recording has already started. " +
             "Call stopAudioConversation to close the previous connection."
         )
         return@catchAsync
       }
-      // TODO: maybe it should be THREAD_PRIORITY_AUDIO anyways for playback and recording (not network though)
-      scope = CoroutineScope(blockingDispatcher + childJob() + CoroutineName("LiveSession Scope"))
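+      // Audio capture/playback gets its own scope backed by audioDispatcher (defined below),
+      // keeping blocking audio I/O off the network dispatcher.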
+      networkScope = CoroutineScope(blockingDispatcher + childJob() + CoroutineName("LiveSession Network"))
+      audioScope = CoroutineScope(audioDispatcher + childJob() + CoroutineName("LiveSession Audio"))
       audioHelper = AudioHelper.build()
 
       recordUserAudio()
@@ -158,7 +177,8 @@ internal constructor(
     FirebaseAIException.catch {
       if (!startedReceiving.getAndSet(false)) return@catch
 
-      scope.cancel()
+      networkScope.cancel()
+      audioScope.cancel()
       playBackQueue.clear()
 
       audioHelper?.release()
@@ -228,7 +248,8 @@ internal constructor(
     FirebaseAIException.catch {
       if (!startedReceiving.getAndSet(false)) return@catch
 
-      scope.cancel()
+      networkScope.cancel()
+      audioScope.cancel()
       playBackQueue.clear()
 
       audioHelper?.release()
@@ -325,21 +346,22 @@ internal constructor(
     audioHelper
       ?.listenToRecording()
       ?.buffer(UNLIMITED)
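+      // flowOn shifts the upstream recording and buffering onto the audio dispatcher; the send
+      // below still runs where the flow is collected (networkScope).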
+      ?.flowOn(audioDispatcher)
       ?.accumulateUntil(MIN_BUFFER_SIZE)
       ?.onEach {
         sendMediaStream(listOf(MediaData(it, "audio/pcm")))
         delay(0)
       }
       ?.catch { throw FirebaseAIException.from(it) }
-      ?.launchIn(scope)
+      ?.launchIn(networkScope)
   }
 
   /**
    * Processes responses from the model during an audio conversation.
    *
    * Audio messages are added to [playBackQueue].
    *
-   * Launched asynchronously on [scope].
+   * Launched asynchronously on [networkScope].
    *
    * @param functionCallHandler A callback function that is invoked whenever the server receives a
    * function call.
@@ -393,18 +415,18 @@ internal constructor(
           }
         }
       }
-      .launchIn(scope)
+      .launchIn(networkScope)
   }
 
   /**
    * Listens for playback data from the model and plays the audio.
    *
    * Polls [playBackQueue] for data, and calls [AudioHelper.playAudio] when data is received.
    *
-   * Launched asynchronously on [scope].
+   * Launched asynchronously on [audioScope].
    */
   private fun listenForModelPlayback(enableInterruptions: Boolean = false) {
-    scope.launch {
+    audioScope.launch {
       while (isActive) {
         val playbackData = playBackQueue.poll()
         if (playbackData == null) {
@@ -490,5 +512,37 @@ internal constructor(
         AudioFormat.CHANNEL_OUT_MONO,
         AudioFormat.ENCODING_PCM_16BIT
       )
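+    // Shared dispatcher for audio work, backed by a cached thread pool whose threads are created
+    // by AudioThreadFactory below (audio priority, StrictMode network detection).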
+    @SuppressLint("ThreadPoolCreation")
+    val audioDispatcher = Executors.newCachedThreadPool(AudioThreadFactory()).asCoroutineDispatcher()
   }
 }
+
+internal class AudioThreadFactory : ThreadFactory {
+  private val threadCount = AtomicLong()
+  private val policy: ThreadPolicy = audioPolicy()
+
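+  // Delegates thread creation to the default factory, then bumps each thread to audio priority
+  // and installs a StrictMode policy that flags network access on audio threads.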
+  override fun newThread(task: Runnable?): Thread? {
+    val thread =
+      DEFAULT.newThread {
+        Process.setThreadPriority(Process.THREAD_PRIORITY_AUDIO)
+        StrictMode.setThreadPolicy(policy)
+        task?.run()
+      }
+    thread.name = "Firebase Audio Thread #${threadCount.andIncrement}"
+    return thread
+  }
+
+  companion object {
+    val DEFAULT: ThreadFactory = Executors.defaultThreadFactory()
+
+    private fun audioPolicy(): ThreadPolicy {
+      val builder = ThreadPolicy.Builder().detectNetwork()
+
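+      // Hitting the network from an audio thread crashes debug builds and is only logged otherwise.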
+      if (BuildConfig.DEBUG) {
+        builder.penaltyDeath()
+      }
+
+      return builder.penaltyLog().build()
+    }
+  }
+}