diff --git a/packages/compass-assistant/src/components/assistant-chat.spec.tsx b/packages/compass-assistant/src/components/assistant-chat.spec.tsx index deb3c5cfb77..251b94cb69e 100644 --- a/packages/compass-assistant/src/components/assistant-chat.spec.tsx +++ b/packages/compass-assistant/src/components/assistant-chat.spec.tsx @@ -248,9 +248,8 @@ describe('AssistantChat', function () { userEvent.type(inputField, 'What is aggregation?'); userEvent.click(sendButton); - expect(ensureOptInAndSendStub.called).to.be.true; - await waitFor(() => { + expect(ensureOptInAndSendStub.called).to.be.true; expect(track).to.have.been.calledWith('Assistant Prompt Submitted', { user_input_length: 'What is aggregation?'.length, }); @@ -281,9 +280,8 @@ describe('AssistantChat', function () { userEvent.type(inputField, ' What is sharding? '); userEvent.click(screen.getByLabelText('Send message')); - expect(ensureOptInAndSendStub.called).to.be.true; - await waitFor(() => { + expect(ensureOptInAndSendStub.called).to.be.true; expect(track).to.have.been.calledWith('Assistant Prompt Submitted', { user_input_length: 'What is sharding?'.length, }); diff --git a/packages/compass-assistant/src/components/assistant-chat.tsx b/packages/compass-assistant/src/components/assistant-chat.tsx index b3fa99f467e..7a5175e54cb 100644 --- a/packages/compass-assistant/src/components/assistant-chat.tsx +++ b/packages/compass-assistant/src/components/assistant-chat.tsx @@ -1,4 +1,4 @@ -import React, { useCallback, useEffect, useContext } from 'react'; +import React, { useCallback, useEffect, useContext, useRef } from 'react'; import type { AssistantMessage } from '../compass-assistant-provider'; import { AssistantActionsContext } from '../compass-assistant-provider'; import type { Chat } from '../@ai-sdk/react/chat-react'; @@ -203,6 +203,10 @@ export const AssistantChat: React.FunctionComponent = ({ }) => { const track = useTelemetry(); const darkMode = useDarkMode(); + const messagesContainerRef = useRef(null); + const previousLastMessageId = useRef(undefined); + const { id: lastMessageId, role: lastMessageRole } = + chat.messages[chat.messages.length - 1] ?? {}; const { ensureOptInAndSend } = useContext(AssistantActionsContext); const { messages, status, error, clearError, setMessages } = useChat({ @@ -214,6 +218,26 @@ export const AssistantChat: React.FunctionComponent = ({ }, }); + const scrollToBottom = useCallback(() => { + if (messagesContainerRef.current) { + // Since the container uses flexDirection: 'column-reverse', + // scrolling to the bottom means setting scrollTop to 0 + messagesContainerRef.current.scrollTop = 0; + } + }, []); + + useEffect(() => { + if ( + lastMessageId && + previousLastMessageId.current !== undefined && + lastMessageId !== previousLastMessageId.current && + lastMessageRole === 'user' + ) { + scrollToBottom(); + } + previousLastMessageId.current = lastMessageId; + }, [lastMessageId, lastMessageRole, scrollToBottom]); + useEffect(() => { const hasExistingNonGenuineWarning = chat.messages.some( (message) => message.id === 'non-genuine-warning' @@ -232,9 +256,10 @@ export const AssistantChat: React.FunctionComponent = ({ }, [hasNonGenuineConnections, chat, setMessages]); const handleMessageSend = useCallback( - (messageBody: string) => { + async (messageBody: string) => { const trimmedMessageBody = messageBody.trim(); if (trimmedMessageBody) { + await chat.stop(); void ensureOptInAndSend?.({ text: trimmedMessageBody }, {}, () => { track('Assistant Prompt Submitted', { user_input_length: trimmedMessageBody.length, @@ -242,7 +267,7 @@ export const AssistantChat: React.FunctionComponent = ({ }); } }, - [track, ensureOptInAndSend] + [track, ensureOptInAndSend, chat] ); const handleFeedback = useCallback( @@ -343,6 +368,7 @@ export const AssistantChat: React.FunctionComponent = ({
{messages.map((message, index) => { @@ -436,7 +462,9 @@ export const AssistantChat: React.FunctionComponent = ({
+ void handleMessageSend(messageBody) + } state={status === 'submitted' ? 'loading' : undefined} textareaProps={{ placeholder: 'Ask a question',