From 2390a2204fb674eda418d7e3aad8d75fd0bbff59 Mon Sep 17 00:00:00 2001 From: Shekhar Tyagi Date: Sun, 13 Jul 2025 20:34:05 +0530 Subject: [PATCH 1/4] feat: implement branching functionality for chat threads Added a new action to branch out from an existing chat thread, creating a new thread with the selected message and linking it to the parent thread. Updated UI components to support this feature, including the addition of a GitBranch icon for visual representation of parent threads. --- .prettierrc | 7 + src/app/(chat)/project/[id]/page.tsx | 7 +- src/app/api/chat/actions.ts | 71 +- src/app/api/chat/branch/route.ts | 14 + src/components/layouts/app-header.tsx | 4 + .../layouts/app-sidebar-threads.tsx | 11 +- src/components/message-parts.tsx | 38 + .../db/migrations/pg/0008_amazing_chat.sql | 1 + .../db/migrations/pg/meta/0008_snapshot.json | 1016 +++++++++++++++++ src/lib/db/migrations/pg/meta/_journal.json | 7 + .../db/pg/repositories/chat-repository.pg.ts | 4 + src/lib/db/pg/schema.pg.ts | 1 + src/types/chat.ts | 1 + 13 files changed, 1177 insertions(+), 5 deletions(-) create mode 100755 .prettierrc create mode 100644 src/app/api/chat/branch/route.ts create mode 100644 src/lib/db/migrations/pg/0008_amazing_chat.sql create mode 100644 src/lib/db/migrations/pg/meta/0008_snapshot.json diff --git a/.prettierrc b/.prettierrc new file mode 100755 index 000000000..301d76857 --- /dev/null +++ b/.prettierrc @@ -0,0 +1,7 @@ +{ + "singleQuote": false, + "tabWidth": 2, + "trailingComma": "all", + "semi": true, + "printWidth": 80 +} diff --git a/src/app/(chat)/project/[id]/page.tsx b/src/app/(chat)/project/[id]/page.tsx index 4d284af4f..c4de123b0 100644 --- a/src/app/(chat)/project/[id]/page.tsx +++ b/src/app/(chat)/project/[id]/page.tsx @@ -16,6 +16,7 @@ import { FileUp, Pencil, MessagesSquare, + GitBranch, } from "lucide-react"; import { useTranslations } from "next-intl"; import Link from "next/link"; @@ -224,8 +225,10 @@ export default function ProjectPage() { >
-
- {thread.title} +
+ {thread.parentThreadId && ( + + )} {thread.title}
diff --git a/src/app/api/chat/actions.ts b/src/app/api/chat/actions.ts index ffcff71b0..72fdaf23c 100644 --- a/src/app/api/chat/actions.ts +++ b/src/app/api/chat/actions.ts @@ -13,7 +13,12 @@ import { generateExampleToolSchemaPrompt, } from "lib/ai/prompts"; -import type { ChatModel, ChatThread, Project } from "app-types/chat"; +import type { + ChatMessage, + ChatModel, + ChatThread, + Project, +} from "app-types/chat"; import { chatRepository, @@ -31,6 +36,7 @@ import logger from "logger"; import { JSONSchema7 } from "json-schema"; import { ObjectJsonSchema7 } from "app-types/util"; import { jsonSchemaToZod } from "lib/json-schema-to-zod"; +import { randomUUID } from "crypto"; export async function getUserId() { const session = await getSession(); @@ -44,7 +50,10 @@ export async function getUserId() { export async function generateTitleFromUserMessageAction({ message, model, -}: { message: Message; model: LanguageModel }) { +}: { + message: Message; + model: LanguageModel; +}) { await getSession(); const prompt = toAny(message.parts?.at(-1))?.text || "unknown"; @@ -87,6 +96,64 @@ export async function deleteMessagesByChatIdAfterTimestampAction( await chatRepository.deleteMessagesByChatIdAfterTimestamp(messageId); } +export async function branchOutAction( + threadId: string, + messageId: string, +): Promise<{ + id: string; +}> { + const userId = await getUserId(); + console.log("userId", userId); + + if (!userId) { + throw new Error("User not found"); + } + + const threadDetails = await chatRepository.selectThreadDetails(threadId); + if (!threadDetails) { + throw new Error("Thread not found"); + } + + const isMessageInThread = threadDetails.messages.some( + (message) => message.id === messageId, + ); + if (!isMessageInThread) { + throw new Error("Message not found in thread"); + } + + const messagesForNewThread: ChatMessage[] = []; + + for (const message of threadDetails.messages) { + if (message.id === messageId) { + messagesForNewThread.push(message); + break; + } + messagesForNewThread.push(message); + } + + const newThread = await chatRepository.insertThread({ + title: `Branch - ${threadDetails.title}`, + userId: threadDetails.userId, + projectId: threadDetails.projectId, + id: randomUUID(), + parentThreadId: threadDetails.id, + }); + + await chatRepository.insertMessages( + messagesForNewThread.map((message) => ({ + role: message.role, + parts: message.parts, + model: message.model, + attachments: message.attachments, + annotations: message.annotations, + threadId: newThread.id, + id: randomUUID(), + })), + ); + + return { id: newThread.id }; +} + export async function selectThreadListByUserIdAction() { const userId = await getUserId(); const threads = await chatRepository.selectThreadsByUserId(userId); diff --git a/src/app/api/chat/branch/route.ts b/src/app/api/chat/branch/route.ts new file mode 100644 index 000000000..57eac12f4 --- /dev/null +++ b/src/app/api/chat/branch/route.ts @@ -0,0 +1,14 @@ +import { NextRequest } from "next/server"; + +export async function POST(_request: NextRequest) { + try { + } catch (error: any) { + console.error("Error:", error); + return new Response( + JSON.stringify({ error: error.message || "Internal Server Error" }), + { + status: 500, + }, + ); + } +} diff --git a/src/components/layouts/app-header.tsx b/src/components/layouts/app-header.tsx index 385a9195c..56e08d631 100644 --- a/src/components/layouts/app-header.tsx +++ b/src/components/layouts/app-header.tsx @@ -9,6 +9,7 @@ import { ChevronRight, MessageCircleDashed, PanelLeft, + GitBranch, } from "lucide-react"; import { Button } from "ui/button"; import { Separator } from "ui/separator"; @@ -187,6 +188,9 @@ function ThreadDropdownComponent() { variant="ghost" className="hover:text-foreground cursor-pointer flex gap-1 items-center px-2 py-1 rounded-md hover:bg-accent" > + {currentThread.parentThreadId && ( + + )}

{currentThread.title}

diff --git a/src/components/layouts/app-sidebar-threads.tsx b/src/components/layouts/app-sidebar-threads.tsx index db19c3bb7..7830c2de7 100644 --- a/src/components/layouts/app-sidebar-threads.tsx +++ b/src/components/layouts/app-sidebar-threads.tsx @@ -11,7 +11,13 @@ import { import { SidebarGroupContent, SidebarMenu, SidebarMenuItem } from "ui/sidebar"; import { SidebarGroup } from "ui/sidebar"; import { ThreadDropdown } from "../thread-dropdown"; -import { ChevronDown, ChevronUp, MoreHorizontal, Trash } from "lucide-react"; +import { + ChevronDown, + ChevronUp, + MoreHorizontal, + Trash, + GitBranch, +} from "lucide-react"; import { useMounted } from "@/hooks/use-mounted"; import { appStore } from "@/app/store"; import { Button } from "ui/button"; @@ -234,6 +240,9 @@ export function AppSidebarThreads() { href={`/chat/${thread.id}`} className="flex items-center" > + {thread.parentThreadId && ( + + )}

{thread.title}

diff --git a/src/components/message-parts.tsx b/src/components/message-parts.tsx index a8e091f19..a961c6900 100644 --- a/src/components/message-parts.tsx +++ b/src/components/message-parts.tsx @@ -17,6 +17,7 @@ import { Loader2, AlertTriangleIcon, Percent, + GitBranch, } from "lucide-react"; import { Tooltip, TooltipContent, TooltipTrigger } from "ui/tooltip"; import { Button } from "ui/button"; @@ -39,6 +40,7 @@ import { useCopy } from "@/hooks/use-copy"; import { AnimatePresence, motion } from "framer-motion"; import { SelectModel } from "./select-model"; import { + branchOutAction, deleteMessageAction, deleteMessagesByChatIdAfterTimestampAction, } from "@/app/api/chat/actions"; @@ -80,6 +82,7 @@ import { TavilyResponse } from "lib/ai/tools/web/web-search"; import { CodeBlock } from "ui/CodeBlock"; import { SafeJsExecutionResult, safeJsRun } from "lib/safe-js-run"; +import { useRouter } from "next/navigation"; type MessagePart = UIMessage["parts"][number]; @@ -274,6 +277,9 @@ export const AssistMessagePart = memo(function AssistMessagePart({ const { copied, copy } = useCopy(); const [isLoading, setIsLoading] = useState(false); const [isDeleting, setIsDeleting] = useState(false); + const [isBranching, setIsBranching] = useState(false); + + const router = useRouter(); const deleteMessage = useCallback(() => { safe(() => setIsDeleting(true)) @@ -322,6 +328,20 @@ export const AssistMessagePart = memo(function AssistMessagePart({ .unwrap(); }; + const handleBranchOut = useCallback(async () => { + safe(() => setIsBranching(true)) + .ifOk(async () => { + if (!threadId) { + throw new Error("Thread ID is required"); + } + const newThread = await branchOutAction(threadId, message.id); + router.push(`/chat/${newThread.id}`); + }) + .ifFail((error) => toast.error(error.message)) + .watch(() => setIsBranching(false)) + .unwrap(); + }, [message.id]); + return (
Change Model + + + + + Branch Out + - )} - -
- - {mention.name} - - {mention.description ? ( - - {mention.description} - - ) : null} -
- -
- ); - })} -
- )}