Skip to content

Commit 6ad6e02

Browse files
committed
Add dispatcher
1 parent 082c510 commit 6ad6e02

File tree

1 file changed

+67
-13
lines changed

1 file changed

+67
-13
lines changed

firebase-ai/src/main/kotlin/com/google/firebase/ai/type/LiveSession.kt

Lines changed: 67 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,17 @@
1717
package com.google.firebase.ai.type
1818

1919
import android.Manifest.permission.RECORD_AUDIO
20+
import android.annotation.SuppressLint
2021
import android.content.pm.PackageManager
2122
import android.media.AudioFormat
2223
import android.media.AudioTrack
24+
import android.os.Process
25+
import android.os.StrictMode
26+
import android.os.StrictMode.ThreadPolicy
2327
import android.util.Log
2428
import androidx.annotation.RequiresPermission
2529
import androidx.core.content.ContextCompat
30+
import com.google.firebase.BuildConfig
2631
import com.google.firebase.FirebaseApp
2732
import com.google.firebase.ai.common.JSON
2833
import com.google.firebase.ai.common.util.CancelledCoroutineScope
@@ -38,14 +43,15 @@ import java.util.concurrent.ConcurrentLinkedQueue
3843
import java.util.concurrent.atomic.AtomicBoolean
3944
import kotlin.coroutines.CoroutineContext
4045
import kotlinx.coroutines.CoroutineScope
41-
import kotlinx.coroutines.Dispatchers
46+
import kotlinx.coroutines.asCoroutineDispatcher
4247
import kotlinx.coroutines.cancel
4348
import kotlinx.coroutines.channels.Channel.Factory.UNLIMITED
4449
import kotlinx.coroutines.delay
4550
import kotlinx.coroutines.flow.Flow
4651
import kotlinx.coroutines.flow.buffer
4752
import kotlinx.coroutines.flow.catch
4853
import kotlinx.coroutines.flow.flow
54+
import kotlinx.coroutines.flow.flowOn
4955
import kotlinx.coroutines.flow.launchIn
5056
import kotlinx.coroutines.flow.onCompletion
5157
import kotlinx.coroutines.flow.onEach
@@ -55,6 +61,9 @@ import kotlinx.serialization.ExperimentalSerializationApi
5561
import kotlinx.serialization.Serializable
5662
import kotlinx.serialization.encodeToString
5763
import kotlinx.serialization.json.Json
64+
import java.util.concurrent.Executors
65+
import java.util.concurrent.ThreadFactory
66+
import java.util.concurrent.atomic.AtomicLong
5867

5968
/** Represents a live WebSocket session capable of streaming content to and from the server. */
6069
@PublicPreviewAPI
@@ -67,11 +76,21 @@ internal constructor(
6776
private val firebaseApp: FirebaseApp,
6877
) {
6978
/**
70-
* Coroutine scope that we batch data on for [startAudioConversation].
79+
* Coroutine scope that we batch data on for network related behavior.
7180
*
7281
* Makes it easy to stop all the work with [stopAudioConversation] by just cancelling the scope.
7382
*/
74-
private var scope = CancelledCoroutineScope
83+
private var networkScope = CancelledCoroutineScope
84+
85+
/**
86+
* Coroutine scope that we batch data on for audio recording and playback.
87+
*
88+
* Separate from [networkScope] to ensure interchanging of dispatchers doesn't
89+
* cause any deadlocks or issues.
90+
*
91+
* Makes it easy to stop all the work with [stopAudioConversation] by just cancelling the scope.
92+
*/
93+
private var audioScope = CancelledCoroutineScope
7594

7695
/**
7796
* Playback audio data sent from the model.
@@ -129,16 +148,16 @@ internal constructor(
129148
}
130149

131150
FirebaseAIException.catchAsync {
132-
if (scope.isActive) {
151+
if (networkScope.isActive) {
133152
Log.w(
134153
TAG,
135154
"startAudioConversation called after the recording has already started. " +
136155
"Call stopAudioConversation to close the previous connection."
137156
)
138157
return@catchAsync
139158
}
140-
// TODO: maybe it should be THREAD_PRIORITY_AUDIO anyways for playback and recording (not network though)
141-
scope = CoroutineScope(blockingDispatcher + childJob() + CoroutineName("LiveSession Scope"))
159+
networkScope = CoroutineScope(blockingDispatcher + childJob() + CoroutineName("LiveSession Network"))
160+
audioScope = CoroutineScope(audioDispatcher + childJob() + CoroutineName("LiveSession Audio"))
142161
audioHelper = AudioHelper.build()
143162

144163
recordUserAudio()
@@ -158,7 +177,8 @@ internal constructor(
158177
FirebaseAIException.catch {
159178
if (!startedReceiving.getAndSet(false)) return@catch
160179

161-
scope.cancel()
180+
networkScope.cancel()
181+
audioScope.cancel()
162182
playBackQueue.clear()
163183

164184
audioHelper?.release()
@@ -228,7 +248,8 @@ internal constructor(
228248
FirebaseAIException.catch {
229249
if (!startedReceiving.getAndSet(false)) return@catch
230250

231-
scope.cancel()
251+
networkScope.cancel()
252+
audioScope.cancel()
232253
playBackQueue.clear()
233254

234255
audioHelper?.release()
@@ -325,21 +346,22 @@ internal constructor(
325346
audioHelper
326347
?.listenToRecording()
327348
?.buffer(UNLIMITED)
349+
?.flowOn(audioDispatcher)
328350
?.accumulateUntil(MIN_BUFFER_SIZE)
329351
?.onEach {
330352
sendMediaStream(listOf(MediaData(it, "audio/pcm")))
331353
delay(0)
332354
}
333355
?.catch { throw FirebaseAIException.from(it) }
334-
?.launchIn(scope)
356+
?.launchIn(networkScope)
335357
}
336358

337359
/**
338360
* Processes responses from the model during an audio conversation.
339361
*
340362
* Audio messages are added to [playBackQueue].
341363
*
342-
* Launched asynchronously on [scope].
364+
* Launched asynchronously on [networkScope].
343365
*
344366
* @param functionCallHandler A callback function that is invoked whenever the server receives a
345367
* function call.
@@ -393,18 +415,18 @@ internal constructor(
393415
}
394416
}
395417
}
396-
.launchIn(scope)
418+
.launchIn(networkScope)
397419
}
398420

399421
/**
400422
* Listens for playback data from the model and plays the audio.
401423
*
402424
* Polls [playBackQueue] for data, and calls [AudioHelper.playAudio] when data is received.
403425
*
404-
* Launched asynchronously on [scope].
426+
* Launched asynchronously on [networkScope].
405427
*/
406428
private fun listenForModelPlayback(enableInterruptions: Boolean = false) {
407-
scope.launch {
429+
audioScope.launch {
408430
while (isActive) {
409431
val playbackData = playBackQueue.poll()
410432
if (playbackData == null) {
@@ -490,5 +512,37 @@ internal constructor(
490512
AudioFormat.CHANNEL_OUT_MONO,
491513
AudioFormat.ENCODING_PCM_16BIT
492514
)
515+
@SuppressLint("ThreadPoolCreation")
516+
val audioDispatcher = Executors.newCachedThreadPool(AudioThreadFactory()).asCoroutineDispatcher()
493517
}
494518
}
519+
520+
internal class AudioThreadFactory : ThreadFactory {
521+
private val threadCount = AtomicLong()
522+
private val policy: ThreadPolicy = audioPolicy()
523+
524+
override fun newThread(task: Runnable?): Thread? {
525+
val thread =
526+
DEFAULT.newThread {
527+
Process.setThreadPriority(Process.THREAD_PRIORITY_AUDIO)
528+
StrictMode.setThreadPolicy(policy)
529+
task?.run()
530+
}
531+
thread.name = "Firebase Audio Thread #${threadCount.andIncrement}"
532+
return thread
533+
}
534+
535+
companion object {
536+
val DEFAULT: ThreadFactory = Executors.defaultThreadFactory()
537+
538+
private fun audioPolicy(): ThreadPolicy {
539+
val builder = ThreadPolicy.Builder().detectNetwork()
540+
541+
if (BuildConfig.DEBUG) {
542+
builder.penaltyDeath()
543+
}
544+
545+
return builder.penaltyLog().build()
546+
}
547+
}
548+
}

0 commit comments

Comments
 (0)