diff --git a/.prettierrc b/.prettierrc new file mode 100755 index 000000000..301d76857 --- /dev/null +++ b/.prettierrc @@ -0,0 +1,7 @@ +{ + "singleQuote": false, + "tabWidth": 2, + "trailingComma": "all", + "semi": true, + "printWidth": 80 +} diff --git a/src/app/(chat)/project/[id]/page.tsx b/src/app/(chat)/project/[id]/page.tsx index cad9d116b..786a0d7b6 100644 --- a/src/app/(chat)/project/[id]/page.tsx +++ b/src/app/(chat)/project/[id]/page.tsx @@ -16,6 +16,7 @@ import { FileUp, Pencil, MessagesSquare, + GitBranch, } from "lucide-react"; import { useTranslations } from "next-intl"; import Link from "next/link"; @@ -225,8 +226,10 @@ export default function ProjectPage() { >
-
- {thread.title} +
+ {thread.parentThreadId && ( + + )} {thread.title}
diff --git a/src/app/api/chat/actions.ts b/src/app/api/chat/actions.ts index ffcff71b0..72fdaf23c 100644 --- a/src/app/api/chat/actions.ts +++ b/src/app/api/chat/actions.ts @@ -13,7 +13,12 @@ import { generateExampleToolSchemaPrompt, } from "lib/ai/prompts"; -import type { ChatModel, ChatThread, Project } from "app-types/chat"; +import type { + ChatMessage, + ChatModel, + ChatThread, + Project, +} from "app-types/chat"; import { chatRepository, @@ -31,6 +36,7 @@ import logger from "logger"; import { JSONSchema7 } from "json-schema"; import { ObjectJsonSchema7 } from "app-types/util"; import { jsonSchemaToZod } from "lib/json-schema-to-zod"; +import { randomUUID } from "crypto"; export async function getUserId() { const session = await getSession(); @@ -44,7 +50,10 @@ export async function getUserId() { export async function generateTitleFromUserMessageAction({ message, model, -}: { message: Message; model: LanguageModel }) { +}: { + message: Message; + model: LanguageModel; +}) { await getSession(); const prompt = toAny(message.parts?.at(-1))?.text || "unknown"; @@ -87,6 +96,64 @@ export async function deleteMessagesByChatIdAfterTimestampAction( await chatRepository.deleteMessagesByChatIdAfterTimestamp(messageId); } +export async function branchOutAction( + threadId: string, + messageId: string, +): Promise<{ + id: string; +}> { + const userId = await getUserId(); + console.log("userId", userId); + + if (!userId) { + throw new Error("User not found"); + } + + const threadDetails = await chatRepository.selectThreadDetails(threadId); + if (!threadDetails) { + throw new Error("Thread not found"); + } + + const isMessageInThread = threadDetails.messages.some( + (message) => message.id === messageId, + ); + if (!isMessageInThread) { + throw new Error("Message not found in thread"); + } + + const messagesForNewThread: ChatMessage[] = []; + + for (const message of threadDetails.messages) { + if (message.id === messageId) { + messagesForNewThread.push(message); + break; + } + messagesForNewThread.push(message); + } + + const newThread = await chatRepository.insertThread({ + title: `Branch - ${threadDetails.title}`, + userId: threadDetails.userId, + projectId: threadDetails.projectId, + id: randomUUID(), + parentThreadId: threadDetails.id, + }); + + await chatRepository.insertMessages( + messagesForNewThread.map((message) => ({ + role: message.role, + parts: message.parts, + model: message.model, + attachments: message.attachments, + annotations: message.annotations, + threadId: newThread.id, + id: randomUUID(), + })), + ); + + return { id: newThread.id }; +} + export async function selectThreadListByUserIdAction() { const userId = await getUserId(); const threads = await chatRepository.selectThreadsByUserId(userId); diff --git a/src/app/api/chat/branch/route.ts b/src/app/api/chat/branch/route.ts new file mode 100644 index 000000000..57eac12f4 --- /dev/null +++ b/src/app/api/chat/branch/route.ts @@ -0,0 +1,14 @@ +import { NextRequest } from "next/server"; + +export async function POST(_request: NextRequest) { + try { + } catch (error: any) { + console.error("Error:", error); + return new Response( + JSON.stringify({ error: error.message || "Internal Server Error" }), + { + status: 500, + }, + ); + } +} diff --git a/src/components/layouts/app-header.tsx b/src/components/layouts/app-header.tsx index 12c3675e9..0b36780e7 100644 --- a/src/components/layouts/app-header.tsx +++ b/src/components/layouts/app-header.tsx @@ -9,6 +9,7 @@ import { ChevronRight, MessageCircleDashed, PanelLeft, + GitBranch, } from "lucide-react"; import { Button } from "ui/button"; import { Separator } from "ui/separator"; @@ -187,6 +188,9 @@ function ThreadDropdownComponent() { variant="ghost" className="data-[state=open]:bg-input! hover:text-foreground cursor-pointer flex gap-1 items-center px-2 py-1 rounded-md hover:bg-accent" > + {currentThread.parentThreadId && ( + + )}

{currentThread.title}

diff --git a/src/components/layouts/app-sidebar-threads.tsx b/src/components/layouts/app-sidebar-threads.tsx index db19c3bb7..7830c2de7 100644 --- a/src/components/layouts/app-sidebar-threads.tsx +++ b/src/components/layouts/app-sidebar-threads.tsx @@ -11,7 +11,13 @@ import { import { SidebarGroupContent, SidebarMenu, SidebarMenuItem } from "ui/sidebar"; import { SidebarGroup } from "ui/sidebar"; import { ThreadDropdown } from "../thread-dropdown"; -import { ChevronDown, ChevronUp, MoreHorizontal, Trash } from "lucide-react"; +import { + ChevronDown, + ChevronUp, + MoreHorizontal, + Trash, + GitBranch, +} from "lucide-react"; import { useMounted } from "@/hooks/use-mounted"; import { appStore } from "@/app/store"; import { Button } from "ui/button"; @@ -234,6 +240,9 @@ export function AppSidebarThreads() { href={`/chat/${thread.id}`} className="flex items-center" > + {thread.parentThreadId && ( + + )}

{thread.title}

diff --git a/src/components/message-parts.tsx b/src/components/message-parts.tsx index 63c2a1d51..38982096c 100644 --- a/src/components/message-parts.tsx +++ b/src/components/message-parts.tsx @@ -16,6 +16,7 @@ import { Loader2, AlertTriangleIcon, Percent, + GitBranch, HammerIcon, } from "lucide-react"; import { Tooltip, TooltipContent, TooltipTrigger } from "ui/tooltip"; @@ -39,6 +40,7 @@ import { useCopy } from "@/hooks/use-copy"; import { AnimatePresence, motion } from "framer-motion"; import { SelectModel } from "./select-model"; import { + branchOutAction, deleteMessageAction, deleteMessagesByChatIdAfterTimestampAction, } from "@/app/api/chat/actions"; @@ -75,6 +77,7 @@ import { TavilyResponse } from "lib/ai/tools/web/web-search"; import { CodeBlock } from "ui/CodeBlock"; import { SafeJsExecutionResult, safeJsRun } from "lib/safe-js-run"; +import { useRouter } from "next/navigation"; type MessagePart = UIMessage["parts"][number]; @@ -247,6 +250,9 @@ export const AssistMessagePart = memo(function AssistMessagePart({ const { copied, copy } = useCopy(); const [isLoading, setIsLoading] = useState(false); const [isDeleting, setIsDeleting] = useState(false); + const [isBranching, setIsBranching] = useState(false); + + const router = useRouter(); const deleteMessage = useCallback(() => { safe(() => setIsDeleting(true)) @@ -295,6 +301,20 @@ export const AssistMessagePart = memo(function AssistMessagePart({ .unwrap(); }; + const handleBranchOut = useCallback(async () => { + safe(() => setIsBranching(true)) + .ifOk(async () => { + if (!threadId) { + throw new Error("Thread ID is required"); + } + const newThread = await branchOutAction(threadId, message.id); + router.push(`/chat/${newThread.id}`); + }) + .ifFail((error) => toast.error(error.message)) + .watch(() => setIsBranching(false)) + .unwrap(); + }, [message.id]); + return (
Change Model + + + + + Branch Out + - )} - -
- - {mention.name} - - {mention.description ? ( - - {mention.description} - - ) : null} -
- -
- ); - })} -
- )}
`MAX(${ChatMessageSchema.createdAt})`.as( "last_message_at", ), @@ -179,6 +182,7 @@ export const pgChatRepository: ChatRepository = { title: row.title, userId: row.userId, projectId: row.projectId, + parentThreadId: row.parentThreadId, createdAt: row.createdAt, lastMessageAt: row.lastMessageAt ? new Date(row.lastMessageAt).getTime() diff --git a/src/lib/db/pg/schema.pg.ts b/src/lib/db/pg/schema.pg.ts index 8ca7a4ecb..f1409e542 100644 --- a/src/lib/db/pg/schema.pg.ts +++ b/src/lib/db/pg/schema.pg.ts @@ -23,6 +23,7 @@ export const ChatThreadSchema = pgTable("chat_thread", { .references(() => UserSchema.id), createdAt: timestamp("created_at").notNull().default(sql`CURRENT_TIMESTAMP`), projectId: uuid("project_id"), + parentThreadId: uuid("parent_thread_id"), }); export const ChatMessageSchema = pgTable("chat_message", { diff --git a/src/types/chat.ts b/src/types/chat.ts index f09549228..b7823dc7a 100644 --- a/src/types/chat.ts +++ b/src/types/chat.ts @@ -14,6 +14,7 @@ export type ChatThread = { userId: string; createdAt: Date; projectId: string | null; + parentThreadId?: string | null; }; export type Project = {