@@ -2,6 +2,8 @@ package io.homeassistant.companion.android.common.assist
22
33import android.app.Application
44import android.content.pm.PackageManager
5+ import androidx.annotation.VisibleForTesting
6+ import androidx.annotation.VisibleForTesting.Companion.PROTECTED
57import androidx.lifecycle.AndroidViewModel
68import androidx.lifecycle.viewModelScope
79import io.homeassistant.companion.android.common.R
@@ -17,16 +19,20 @@ import io.homeassistant.companion.android.common.data.websocket.impl.entities.As
1719import io.homeassistant.companion.android.common.data.websocket.impl.entities.AssistPipelineTtsEnd
1820import io.homeassistant.companion.android.common.util.AudioRecorder
1921import io.homeassistant.companion.android.common.util.AudioUrlPlayer
22+ import io.homeassistant.companion.android.common.util.FailFast
2023import io.homeassistant.companion.android.common.util.PlaybackState
2124import io.homeassistant.companion.android.util.UrlUtil
2225import java.util.concurrent.atomic.AtomicBoolean
26+ import kotlinx.coroutines.CompletableDeferred
2327import kotlinx.coroutines.ExperimentalCoroutinesApi
2428import kotlinx.coroutines.Job
29+ import kotlinx.coroutines.channels.Channel
2530import kotlinx.coroutines.flow.Flow
2631import kotlinx.coroutines.flow.emptyFlow
2732import kotlinx.coroutines.flow.first
2833import kotlinx.coroutines.flow.flatMapLatest
2934import kotlinx.coroutines.launch
35+ import timber.log.Timber
3036
3137sealed interface AssistEvent {
3238 sealed class Message (val message : String ) : AssistEvent {
@@ -64,12 +70,31 @@ abstract class AssistViewModelBase(
6470 protected var selectedServerId = ServerManager .SERVER_ID_ACTIVE
6571
6672 protected var recorderProactive = false
73+
74+ /* *
75+ * Parent job managing the audio recording pipeline. Contains both the producer
76+ * (collecting from [AudioRecorder.audioBytes]) and consumer (forwarding to server).
77+ * Joining this job ensures all buffered audio has been sent before cleanup.
78+ */
6779 private var recorderJob: Job ? = null
68- private var recorderQueue: MutableList <ByteArray >? = null
80+
81+ /* *
82+ * Child job that collects audio bytes from [AudioRecorder.audioBytes] SharedFlow
83+ * and buffers them in a channel. Cancelled separately from [recorderJob] to stop
84+ * collecting while still allowing the consumer to drain buffered data.
85+ */
86+ private var producerJob: Job ? = null
87+
88+ /* *
89+ * Signals when the server is ready to receive audio data. Completed with the
90+ * binary handler ID when [handleSttStart] is called. The consumer coroutine
91+ * awaits this before forwarding buffered audio to the server.
92+ */
93+ private var sttReady: CompletableDeferred <Int >? = null
6994 protected val hasMicrophone = app.packageManager.hasSystemFeature(PackageManager .FEATURE_MICROPHONE )
7095 protected var hasPermission = false
7196
72- private var binaryHandlerId: Int? = null
97+ @VisibleForTesting var binaryHandlerId: Int? = null
7398 private var conversationId: String? = null
7499 private var continueConversation = AtomicBoolean (false )
75100
@@ -183,14 +208,8 @@ abstract class AssistViewModelBase(
183208 }
184209
185210 private fun handleSttStart () {
186- viewModelScope.launch {
187- binaryHandlerId?.let { id ->
188- // Manually loop here to avoid the queue being reset too soon
189- recorderQueue?.forEach { data ->
190- serverManager.webSocketRepository(selectedServerId).sendVoiceData(id, data)
191- }
192- }
193- recorderQueue = null
211+ binaryHandlerId?.let { id ->
212+ sttReady?.complete(id)
194213 }
195214 }
196215
@@ -243,20 +262,45 @@ abstract class AssistViewModelBase(
243262 return true
244263 }
245264
246- protected fun setupRecorderQueue () {
247- recorderQueue = mutableListOf ()
265+ /* *
266+ * Sets up audio recording and buffering for voice input.
267+ *
268+ * Must be called before [runAssistPipelineInternal] for voice pipelines.
269+ * Audio data is buffered until the server is ready to receive, then all
270+ * buffered and subsequent audio is forwarded.
271+ *
272+ * Note: [AudioRecorder.audioBytes] is a SharedFlow which never completes, so we use
273+ * explicit channel management to allow graceful shutdown with buffer draining.
274+ */
275+ @VisibleForTesting(otherwise = PROTECTED )
276+ fun setupRecorder () {
277+ Timber .d(" Setting up recorder" )
278+ sttReady = CompletableDeferred ()
279+
248280 recorderJob = viewModelScope.launch {
249- audioRecorder.audioBytes.collect {
250- recorderQueue?.add(it) ? : sendVoiceData(it)
281+ val audioChannel = Channel <ByteArray >(Channel .UNLIMITED )
282+
283+ // Producer: collect from SharedFlow and buffer in channel
284+ producerJob = launch {
285+ audioRecorder.audioBytes.collect { data ->
286+ audioChannel.send(data)
287+ }
288+ }.apply {
289+ invokeOnCompletion {
290+ audioChannel.close()
291+ }
251292 }
252- }
253- }
254293
255- private fun sendVoiceData (data : ByteArray ) {
256- binaryHandlerId?.let {
257- viewModelScope.launch {
258- // Launch to prevent blocking the output flow if the network is slow
259- serverManager.webSocketRepository(selectedServerId).sendVoiceData(it, data)
294+ // Consumer: wait for STT to be ready, then forward all buffered and new data
295+ val handlerId = sttReady?.await() ? : run {
296+ FailFast .fail { " sttReady not set" }
297+ producerJob?.cancel()
298+ audioChannel.close()
299+ return @launch
300+ }
301+
302+ for (data in audioChannel) {
303+ serverManager.webSocketRepository(selectedServerId).sendVoiceData(handlerId, data)
260304 }
261305 }
262306 }
@@ -276,23 +320,42 @@ abstract class AssistViewModelBase(
276320 }
277321
278322 protected fun stopRecording (sendRecorded : Boolean = true) {
323+ stopAudioCapture()
324+
325+ binaryHandlerId?.let { handlerId ->
326+ finalizeRecording(handlerId, sendRecorded)
327+ } ? : clearRecorderState()
328+
329+ updateInputModeAfterRecording()
330+ }
331+
332+ private fun stopAudioCapture () {
279333 audioRecorder.stopRecording()
280- recorderJob ?.cancel()
281- recorderJob = null
282- if (binaryHandlerId != null ) {
283- viewModelScope.launch {
284- if ( sendRecorded) {
285- recorderQueue?.forEach {
286- sendVoiceData(it)
287- }
288- sendVoiceData(byteArrayOf()) // Empty message to indicate end of recording
289- }
290- recorderQueue = null
291- binaryHandlerId = null
334+ producerJob ?.cancel()
335+ producerJob = null
336+ }
337+
338+ private fun finalizeRecording ( handlerId : Int , sendRecorded : Boolean ) {
339+ viewModelScope.launch {
340+ if (sendRecorded) {
341+ recorderJob?.join()
342+ serverManager.webSocketRepository(selectedServerId).sendVoiceData(
343+ handlerId,
344+ byteArrayOf(),
345+ )
292346 }
293- } else {
294- recorderQueue = null
347+ clearRecorderState()
295348 }
349+ }
350+
351+ private fun clearRecorderState () {
352+ recorderJob?.cancel()
353+ recorderJob = null
354+ sttReady = null
355+ binaryHandlerId = null
356+ }
357+
358+ private fun updateInputModeAfterRecording () {
296359 if (getInput() == AssistInputMode .VOICE_ACTIVE ) {
297360 setInput(if (recorderProactive) AssistInputMode .BLOCKED else AssistInputMode .VOICE_INACTIVE )
298361 }
0 commit comments