Skip to content
Merged
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,8 @@
package com.google.firebase.ai.common.util

import android.media.AudioRecord
import kotlin.time.Duration.Companion.milliseconds
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.yield

/**
* The minimum buffer size for this instance.
Expand All @@ -40,15 +38,13 @@ internal fun AudioRecord.readAsFlow() = flow {

while (true) {
if (recordingState != AudioRecord.RECORDSTATE_RECORDING) {
// TODO(vguthal): Investigate if both yield and delay are required.
delay(10.milliseconds)
yield()
delay(0)
continue
}
val bytesRead = read(buffer, 0, buffer.size)
if (bytesRead > 0) {
emit(buffer.copyOf(bytesRead))
}
yield()
delay(0)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,10 @@ internal class AudioHelper(
fun build(): AudioHelper {
val playbackTrack =
AudioTrack(
AudioAttributes.Builder().setUsage(AudioAttributes.USAGE_VOICE_COMMUNICATION).build(),
AudioAttributes.Builder()
.setUsage(AudioAttributes.USAGE_MEDIA)
.setContentType(AudioAttributes.CONTENT_TYPE_SPEECH)
.build(),
AudioFormat.Builder()
.setSampleRate(24000)
.setChannelMask(AudioFormat.CHANNEL_OUT_MONO)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,17 @@
package com.google.firebase.ai.type

import android.Manifest.permission.RECORD_AUDIO
import android.annotation.SuppressLint
import android.content.pm.PackageManager
import android.media.AudioFormat
import android.media.AudioTrack
import android.os.Process
import android.os.StrictMode
import android.os.StrictMode.ThreadPolicy
import android.util.Log
import androidx.annotation.RequiresPermission
import androidx.core.content.ContextCompat
import com.google.firebase.BuildConfig
import com.google.firebase.FirebaseApp
import com.google.firebase.ai.common.JSON
import com.google.firebase.ai.common.util.CancelledCoroutineScope
Expand All @@ -34,21 +39,27 @@ import io.ktor.websocket.Frame
import io.ktor.websocket.close
import io.ktor.websocket.readBytes
import java.util.concurrent.ConcurrentLinkedQueue
import java.util.concurrent.Executors
import java.util.concurrent.ThreadFactory
import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.atomic.AtomicLong
import kotlin.coroutines.CoroutineContext
import kotlinx.coroutines.CoroutineName
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.asCoroutineDispatcher
import kotlinx.coroutines.cancel
import kotlinx.coroutines.channels.Channel.Factory.UNLIMITED
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.buffer
import kotlinx.coroutines.flow.catch
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.flow.flowOn
import kotlinx.coroutines.flow.launchIn
import kotlinx.coroutines.flow.onCompletion
import kotlinx.coroutines.flow.onEach
import kotlinx.coroutines.isActive
import kotlinx.coroutines.launch
import kotlinx.coroutines.yield
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.Serializable
import kotlinx.serialization.encodeToString
Expand All @@ -65,11 +76,21 @@ internal constructor(
private val firebaseApp: FirebaseApp,
) {
/**
* Coroutine scope that we batch data on for [startAudioConversation].
* Coroutine scope that we batch data on for network related behavior.
*
* Makes it easy to stop all the work with [stopAudioConversation] by just cancelling the scope.
*/
private var scope = CancelledCoroutineScope
private var networkScope = CancelledCoroutineScope

/**
* Coroutine scope that we batch data on for audio recording and playback.
*
* Separate from [networkScope] to ensure interchanging of dispatchers doesn't cause any deadlocks
* or issues.
*
* Makes it easy to stop all the work with [stopAudioConversation] by just cancelling the scope.
*/
private var audioScope = CancelledCoroutineScope

/**
* Playback audio data sent from the model.
Expand Down Expand Up @@ -134,10 +155,6 @@ internal constructor(
* function call. The [FunctionResponsePart] that the callback function returns will be
* automatically sent to the model.
*
* @param transcriptHandler A callback function that is invoked whenever the model receives a
* transcript. The first [Transcription] object is the input transcription, and the second is the
* output transcription.
*
* @param enableInterruptions If enabled, allows the user to speak over or interrupt the model's
* ongoing reply.
*
Expand All @@ -159,16 +176,17 @@ internal constructor(
}

FirebaseAIException.catchAsync {
if (scope.isActive) {
if (networkScope.isActive || audioScope.isActive) {
Log.w(
TAG,
"startAudioConversation called after the recording has already started. " +
"Call stopAudioConversation to close the previous connection."
)
return@catchAsync
}

scope = CoroutineScope(blockingDispatcher + childJob())
networkScope =
CoroutineScope(blockingDispatcher + childJob() + CoroutineName("LiveSession Network"))
audioScope = CoroutineScope(audioDispatcher + childJob() + CoroutineName("LiveSession Audio"))
audioHelper = AudioHelper.build()

recordUserAudio()
Expand All @@ -188,7 +206,8 @@ internal constructor(
FirebaseAIException.catch {
if (!startedReceiving.getAndSet(false)) return@catch

scope.cancel()
networkScope.cancel()
audioScope.cancel()
playBackQueue.clear()

audioHelper?.release()
Expand Down Expand Up @@ -231,7 +250,7 @@ internal constructor(
)
}
?.let { emit(it.toPublic()) }
yield()
delay(0)
}
}
.onCompletion { stopAudioConversation() }
Expand All @@ -258,7 +277,8 @@ internal constructor(
FirebaseAIException.catch {
if (!startedReceiving.getAndSet(false)) return@catch

scope.cancel()
networkScope.cancel()
audioScope.cancel()
playBackQueue.clear()

audioHelper?.release()
Expand Down Expand Up @@ -403,18 +423,22 @@ internal constructor(
audioHelper
?.listenToRecording()
?.buffer(UNLIMITED)
?.flowOn(audioDispatcher)
?.accumulateUntil(MIN_BUFFER_SIZE)
?.onEach { sendAudioRealtime(InlineData(it, "audio/pcm")) }
?.onEach {
sendAudioRealtime(InlineData(it, "audio/pcm"))
delay(0)
}
?.catch { throw FirebaseAIException.from(it) }
?.launchIn(scope)
?.launchIn(networkScope)
}

/**
* Processes responses from the model during an audio conversation.
*
* Audio messages are added to [playBackQueue].
*
* Launched asynchronously on [scope].
* Launched asynchronously on [networkScope].
*
* @param functionCallHandler A callback function that is invoked whenever the server receives a
* function call.
Expand Down Expand Up @@ -471,18 +495,18 @@ internal constructor(
}
}
}
.launchIn(scope)
.launchIn(networkScope)
}

/**
* Listens for playback data from the model and plays the audio.
*
* Polls [playBackQueue] for data, and calls [AudioHelper.playAudio] when data is received.
*
* Launched asynchronously on [scope].
* Launched asynchronously on [networkScope].
*/
private fun listenForModelPlayback(enableInterruptions: Boolean = false) {
scope.launch {
audioScope.launch {
while (isActive) {
val playbackData = playBackQueue.poll()
if (playbackData == null) {
Expand All @@ -491,14 +515,14 @@ internal constructor(
if (!enableInterruptions) {
audioHelper?.resumeRecording()
}
yield()
delay(0)
} else {
/**
* We pause the recording while the model is speaking to avoid interrupting it because of
* no echo cancellation
*/
// TODO(b/408223520): Conditionally pause when param is added
if (enableInterruptions != true) {
if (!enableInterruptions) {
audioHelper?.pauseRecording()
}
audioHelper?.playAudio(playbackData)
Expand Down Expand Up @@ -583,5 +607,38 @@ internal constructor(
AudioFormat.CHANNEL_OUT_MONO,
AudioFormat.ENCODING_PCM_16BIT
)
@SuppressLint("ThreadPoolCreation")
val audioDispatcher =
Executors.newCachedThreadPool(AudioThreadFactory()).asCoroutineDispatcher()
}
}

/**
 * A [ThreadFactory] that produces threads suitable for real-time audio work.
 *
 * Every thread created by this factory:
 * - runs at [Process.THREAD_PRIORITY_AUDIO],
 * - installs a [StrictMode] policy that flags network access on the thread
 *   (fatal in debug builds, logged otherwise),
 * - carries a unique, debuggable name ("Firebase Audio Thread #N").
 */
internal class AudioThreadFactory : ThreadFactory {
  // Monotonic counter so each thread gets a distinct name for debugging.
  private val createdCount = AtomicLong()

  // Built once and shared by every thread this factory creates.
  private val strictPolicy: ThreadPolicy = audioPolicy()

  override fun newThread(task: Runnable?): Thread? {
    // Wrap the task so priority and StrictMode are configured on the new
    // thread itself, before any user work runs.
    val configuredTask = Runnable {
      Process.setThreadPriority(Process.THREAD_PRIORITY_AUDIO)
      StrictMode.setThreadPolicy(strictPolicy)
      task?.run()
    }
    return DEFAULT.newThread(configuredTask).apply {
      name = "Firebase Audio Thread #${createdCount.andIncrement}"
    }
  }

  companion object {
    /** Platform-default factory that does the actual thread creation. */
    val DEFAULT: ThreadFactory = Executors.defaultThreadFactory()

    /**
     * StrictMode policy for audio threads: detect network use; crash the
     * process in debug builds, otherwise just log the violation.
     */
    private fun audioPolicy(): ThreadPolicy =
      ThreadPolicy.Builder()
        .detectNetwork()
        .apply { if (BuildConfig.DEBUG) penaltyDeath() }
        .penaltyLog()
        .build()
  }
}
Loading