Skip to content

Commit 849cd73

Browse files
authored
Refactor Assist audio recording to fix race condition using Channel (#6211)
1 parent bcf4632 commit 849cd73

File tree

4 files changed

+380
-37
lines changed

4 files changed

+380
-37
lines changed

app/src/main/kotlin/io/homeassistant/companion/android/assist/AssistViewModel.kt

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -280,7 +280,7 @@ class AssistViewModel @Inject constructor(
280280
}
281281

282282
if (recording) {
283-
if (!recorderProactive) setupRecorderQueue()
283+
if (!recorderProactive) setupRecorder()
284284
inputMode = AssistInputMode.VOICE_ACTIVE
285285
if (proactive == true) _conversation.add(AssistMessage("", isInput = true))
286286
if (proactive != true) runAssistPipeline(null)

common/src/main/kotlin/io/homeassistant/companion/android/common/assist/AssistViewModelBase.kt

Lines changed: 98 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,8 @@ package io.homeassistant.companion.android.common.assist
22

33
import android.app.Application
44
import android.content.pm.PackageManager
5+
import androidx.annotation.VisibleForTesting
6+
import androidx.annotation.VisibleForTesting.Companion.PROTECTED
57
import androidx.lifecycle.AndroidViewModel
68
import androidx.lifecycle.viewModelScope
79
import io.homeassistant.companion.android.common.R
@@ -17,16 +19,20 @@ import io.homeassistant.companion.android.common.data.websocket.impl.entities.As
1719
import io.homeassistant.companion.android.common.data.websocket.impl.entities.AssistPipelineTtsEnd
1820
import io.homeassistant.companion.android.common.util.AudioRecorder
1921
import io.homeassistant.companion.android.common.util.AudioUrlPlayer
22+
import io.homeassistant.companion.android.common.util.FailFast
2023
import io.homeassistant.companion.android.common.util.PlaybackState
2124
import io.homeassistant.companion.android.util.UrlUtil
2225
import java.util.concurrent.atomic.AtomicBoolean
26+
import kotlinx.coroutines.CompletableDeferred
2327
import kotlinx.coroutines.ExperimentalCoroutinesApi
2428
import kotlinx.coroutines.Job
29+
import kotlinx.coroutines.channels.Channel
2530
import kotlinx.coroutines.flow.Flow
2631
import kotlinx.coroutines.flow.emptyFlow
2732
import kotlinx.coroutines.flow.first
2833
import kotlinx.coroutines.flow.flatMapLatest
2934
import kotlinx.coroutines.launch
35+
import timber.log.Timber
3036

3137
sealed interface AssistEvent {
3238
sealed class Message(val message: String) : AssistEvent {
@@ -64,12 +70,31 @@ abstract class AssistViewModelBase(
6470
protected var selectedServerId = ServerManager.SERVER_ID_ACTIVE
6571

6672
protected var recorderProactive = false
73+
74+
/**
75+
* Parent job managing the audio recording pipeline. Contains both the producer
76+
* (collecting from [AudioRecorder.audioBytes]) and consumer (forwarding to server).
77+
* Joining this job ensures all buffered audio has been sent before cleanup.
78+
*/
6779
private var recorderJob: Job? = null
68-
private var recorderQueue: MutableList<ByteArray>? = null
80+
81+
/**
82+
* Child job that collects audio bytes from [AudioRecorder.audioBytes] SharedFlow
83+
* and buffers them in a channel. Cancelled separately from [recorderJob] to stop
84+
* collecting while still allowing the consumer to drain buffered data.
85+
*/
86+
private var producerJob: Job? = null
87+
88+
/**
89+
* Signals when the server is ready to receive audio data. Completed with the
90+
* binary handler ID when [handleSttStart] is called. The consumer coroutine
91+
* awaits this before forwarding buffered audio to the server.
92+
*/
93+
private var sttReady: CompletableDeferred<Int>? = null
6994
protected val hasMicrophone = app.packageManager.hasSystemFeature(PackageManager.FEATURE_MICROPHONE)
7095
protected var hasPermission = false
7196

72-
private var binaryHandlerId: Int? = null
97+
@VisibleForTesting var binaryHandlerId: Int? = null
7398
private var conversationId: String? = null
7499
private var continueConversation = AtomicBoolean(false)
75100

@@ -183,14 +208,8 @@ abstract class AssistViewModelBase(
183208
}
184209

185210
private fun handleSttStart() {
186-
viewModelScope.launch {
187-
binaryHandlerId?.let { id ->
188-
// Manually loop here to avoid the queue being reset too soon
189-
recorderQueue?.forEach { data ->
190-
serverManager.webSocketRepository(selectedServerId).sendVoiceData(id, data)
191-
}
192-
}
193-
recorderQueue = null
211+
binaryHandlerId?.let { id ->
212+
sttReady?.complete(id)
194213
}
195214
}
196215

@@ -243,20 +262,45 @@ abstract class AssistViewModelBase(
243262
return true
244263
}
245264

246-
protected fun setupRecorderQueue() {
247-
recorderQueue = mutableListOf()
265+
/**
266+
* Sets up audio recording and buffering for voice input.
267+
*
268+
* Must be called before [runAssistPipelineInternal] for voice pipelines.
269+
* Audio data is buffered until the server is ready to receive, then all
270+
* buffered and subsequent audio is forwarded.
271+
*
272+
* Note: [AudioRecorder.audioBytes] is a SharedFlow which never completes, so we use
273+
* explicit channel management to allow graceful shutdown with buffer draining.
274+
*/
275+
@VisibleForTesting(otherwise = PROTECTED)
276+
fun setupRecorder() {
277+
Timber.d("Setting up recorder")
278+
sttReady = CompletableDeferred()
279+
248280
recorderJob = viewModelScope.launch {
249-
audioRecorder.audioBytes.collect {
250-
recorderQueue?.add(it) ?: sendVoiceData(it)
281+
val audioChannel = Channel<ByteArray>(Channel.UNLIMITED)
282+
283+
// Producer: collect from SharedFlow and buffer in channel
284+
producerJob = launch {
285+
audioRecorder.audioBytes.collect { data ->
286+
audioChannel.send(data)
287+
}
288+
}.apply {
289+
invokeOnCompletion {
290+
audioChannel.close()
291+
}
251292
}
252-
}
253-
}
254293

255-
private fun sendVoiceData(data: ByteArray) {
256-
binaryHandlerId?.let {
257-
viewModelScope.launch {
258-
// Launch to prevent blocking the output flow if the network is slow
259-
serverManager.webSocketRepository(selectedServerId).sendVoiceData(it, data)
294+
// Consumer: wait for STT to be ready, then forward all buffered and new data
295+
val handlerId = sttReady?.await() ?: run {
296+
FailFast.fail { "sttReady not set" }
297+
producerJob?.cancel()
298+
audioChannel.close()
299+
return@launch
300+
}
301+
302+
for (data in audioChannel) {
303+
serverManager.webSocketRepository(selectedServerId).sendVoiceData(handlerId, data)
260304
}
261305
}
262306
}
@@ -276,23 +320,42 @@ abstract class AssistViewModelBase(
276320
}
277321

278322
protected fun stopRecording(sendRecorded: Boolean = true) {
323+
stopAudioCapture()
324+
325+
binaryHandlerId?.let { handlerId ->
326+
finalizeRecording(handlerId, sendRecorded)
327+
} ?: clearRecorderState()
328+
329+
updateInputModeAfterRecording()
330+
}
331+
332+
private fun stopAudioCapture() {
279333
audioRecorder.stopRecording()
280-
recorderJob?.cancel()
281-
recorderJob = null
282-
if (binaryHandlerId != null) {
283-
viewModelScope.launch {
284-
if (sendRecorded) {
285-
recorderQueue?.forEach {
286-
sendVoiceData(it)
287-
}
288-
sendVoiceData(byteArrayOf()) // Empty message to indicate end of recording
289-
}
290-
recorderQueue = null
291-
binaryHandlerId = null
334+
producerJob?.cancel()
335+
producerJob = null
336+
}
337+
338+
private fun finalizeRecording(handlerId: Int, sendRecorded: Boolean) {
339+
viewModelScope.launch {
340+
if (sendRecorded) {
341+
recorderJob?.join()
342+
serverManager.webSocketRepository(selectedServerId).sendVoiceData(
343+
handlerId,
344+
byteArrayOf(),
345+
)
292346
}
293-
} else {
294-
recorderQueue = null
347+
clearRecorderState()
295348
}
349+
}
350+
351+
private fun clearRecorderState() {
352+
recorderJob?.cancel()
353+
recorderJob = null
354+
sttReady = null
355+
binaryHandlerId = null
356+
}
357+
358+
private fun updateInputModeAfterRecording() {
296359
if (getInput() == AssistInputMode.VOICE_ACTIVE) {
297360
setInput(if (recorderProactive) AssistInputMode.BLOCKED else AssistInputMode.VOICE_INACTIVE)
298361
}

0 commit comments

Comments
 (0)