diff --git a/src/components/ChatWindow/ChatContainer/index.jsx b/src/components/ChatWindow/ChatContainer/index.jsx index b74ac15..3fb3862 100644 --- a/src/components/ChatWindow/ChatContainer/index.jsx +++ b/src/components/ChatWindow/ChatContainer/index.jsx @@ -1,8 +1,14 @@ import React, { useState, useEffect } from "react"; import ChatHistory from "./ChatHistory"; import PromptInput from "./PromptInput"; -import handleChat from "@/utils/chat"; +import handleChat, { ABORT_STREAM_EVENT } from "@/utils/chat"; import ChatService from "@/models/chatService"; +import handleSocketResponse, { + websocketURI, + AGENT_SESSION_END, + AGENT_SESSION_START, +} from "@/utils/agent"; +import { v4 } from "uuid"; export const SEND_TEXT_EVENT = "anythingllm-embed-send-prompt"; export default function ChatContainer({ @@ -13,6 +19,8 @@ export default function ChatContainer({ const [message, setMessage] = useState(""); const [loadingResponse, setLoadingResponse] = useState(false); const [chatHistory, setChatHistory] = useState(knownHistory); + const [socketId, setSocketId] = useState(null); + const [websocket, setWebsocket] = useState(null); // Resync history if the ref to known history changes // eg: cleared. @@ -93,6 +101,18 @@ export default function ChatContainer({ const remHistory = chatHistory.length > 0 ? chatHistory.slice(0, -1) : []; var _chatHistory = [...remHistory]; + // Override hook for new messages to now go to agents until the connection closes + if (!!websocket) { + if (!promptMessage || !promptMessage?.userMessage) return false; + websocket.send( + JSON.stringify({ + type: "awaitingFeedback", + feedback: promptMessage?.userMessage, + }) + ); + return; + } + if (!promptMessage || !promptMessage?.userMessage) { setLoadingResponse(false); return false; @@ -108,14 +128,15 @@ export default function ChatContainer({ setLoadingResponse, setChatHistory, remHistory, - _chatHistory + _chatHistory, + setSocketId ) ); return; } loadingResponse === true && fetchReply(); - }, [loadingResponse, chatHistory]); + }, [loadingResponse, chatHistory, websocket]); const handleAutofillEvent = (event) => { if (!event.detail.command) return; @@ -129,6 +150,65 @@ export default function ChatContainer({ }; }, []); + // Websocket connection management for agent sessions + useEffect(() => { + function handleWSS() { + try { + if (!socketId || !!websocket) return; + const socket = new WebSocket( + `${websocketURI(settings)}/api/agent-invocation/${socketId}` + ); + + window.addEventListener(ABORT_STREAM_EVENT, () => { + window.dispatchEvent(new CustomEvent(AGENT_SESSION_END)); + if (websocket) websocket.close(); + }); + + socket.addEventListener("message", (event) => { + setLoadingResponse(true); + try { + handleSocketResponse(event, setChatHistory); + } catch (e) { + console.error("Failed to parse agent data:", e); + window.dispatchEvent(new CustomEvent(AGENT_SESSION_END)); + socket.close(); + } + setLoadingResponse(false); + }); + + socket.addEventListener("close", (_event) => { + window.dispatchEvent(new CustomEvent(AGENT_SESSION_END)); + setLoadingResponse(false); + setWebsocket(null); + setSocketId(null); + }); + + setWebsocket(socket); + window.dispatchEvent(new CustomEvent(AGENT_SESSION_START)); + } catch (e) { + setChatHistory((prev) => [ + ...prev.filter((msg) => !!msg.content), + { + uuid: v4(), + type: "abort", + content: e.message, + role: "assistant", + sources: [], + closed: true, + error: e.message, + animate: false, + pending: false, + sentAt: Math.floor(Date.now() / 1000), + }, + ]); + setLoadingResponse(false); + setWebsocket(null); + setSocketId(null); + } + } + handleWSS(); + }, [socketId]); + return (