diff --git a/generativeai-android-sample/app/src/main/kotlin/com/google/ai/sample/feature/chat/ChatUiState.kt b/generativeai-android-sample/app/src/main/kotlin/com/google/ai/sample/feature/chat/ChatUiState.kt index 48dea90e..e8617d6e 100644 --- a/generativeai-android-sample/app/src/main/kotlin/com/google/ai/sample/feature/chat/ChatUiState.kt +++ b/generativeai-android-sample/app/src/main/kotlin/com/google/ai/sample/feature/chat/ChatUiState.kt @@ -36,4 +36,17 @@ class ChatUiState( _messages.add(newMessage) } } + + fun addOrUpdate(msg: ChatMessage) { + if(!updateMessage(msg.id, msg)) addMessage(msg) + } + + fun updateMessage(id: String, newMessage: ChatMessage): Boolean { + val index = _messages.indexOfFirst { it.id == id } + if (index != -1) { + _messages[index] = newMessage + } + + return index != -1 + } } diff --git a/generativeai-android-sample/app/src/main/kotlin/com/google/ai/sample/feature/chat/ChatViewModel.kt b/generativeai-android-sample/app/src/main/kotlin/com/google/ai/sample/feature/chat/ChatViewModel.kt index f102678d..b2830eab 100644 --- a/generativeai-android-sample/app/src/main/kotlin/com/google/ai/sample/feature/chat/ChatViewModel.kt +++ b/generativeai-android-sample/app/src/main/kotlin/com/google/ai/sample/feature/chat/ChatViewModel.kt @@ -24,7 +24,13 @@ import com.google.ai.client.generativeai.type.content import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.StateFlow import kotlinx.coroutines.flow.asStateFlow +import kotlinx.coroutines.flow.collectIndexed +import kotlinx.coroutines.flow.filter +import kotlinx.coroutines.flow.onCompletion +import kotlinx.coroutines.flow.runningFold +import kotlinx.coroutines.flow.skip import kotlinx.coroutines.launch +import java.util.UUID class ChatViewModel( generativeModel: GenerativeModel @@ -61,16 +67,18 @@ class ChatViewModel( viewModelScope.launch { try { - val response = chat.sendMessage(userMessage) - - _uiState.value.replaceLastPendingMessage() - - response.text?.let { modelResponse -> - _uiState.value.addMessage( + val uuid = UUID.randomUUID().toString() + chat.sendMessageStream(userMessage).runningFold("") { message, response -> + message + response.text + }.filter { it.isNotEmpty() }.collectIndexed { index, value -> + if (index == 0) { + _uiState.value.replaceLastPendingMessage() + } + _uiState.value.addOrUpdate( ChatMessage( - text = modelResponse, - participant = Participant.MODEL, - isPending = false + id = uuid, + text = value, + participant = Participant.MODEL ) ) }