Skip to content

Commit c61d4f9

Browse files
feat: add reasoning model (vercel#750)
Co-authored-by: Matt Apperson <[email protected]>
1 parent 7680426 commit c61d4f9

File tree

19 files changed

+335
-202
lines changed

19 files changed

+335
-202
lines changed

.env.example

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1-
# Get your OpenAI API Key here: https://platform.openai.com/account/api-keys
1+
# Get your OpenAI API Key here for chat models: https://platform.openai.com/account/api-keys
22
OPENAI_API_KEY=****
33

4+
# Get your Fireworks AI API Key here for reasoning models: https://fireworks.ai/account/api-keys
5+
FIREWORKS_API_KEY=****
6+
47
# Generate a random secret: https://generate-secret.vercel.app/32 or `openssl rand -base64 32`
58
AUTH_SECRET=****
69

app/(chat)/actions.ts

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
11
'use server';
22

3-
import { type CoreUserMessage, generateText, Message } from 'ai';
3+
import { generateText, Message } from 'ai';
44
import { cookies } from 'next/headers';
55

6-
import { customModel } from '@/lib/ai';
76
import {
87
deleteMessagesByChatIdAfterTimestamp,
98
getMessageById,
109
updateChatVisiblityById,
1110
} from '@/lib/db/queries';
1211
import { VisibilityType } from '@/components/visibility-selector';
12+
import { myProvider } from '@/lib/ai/models';
1313

14-
export async function saveModelId(model: string) {
14+
export async function saveChatModelAsCookie(model: string) {
1515
const cookieStore = await cookies();
16-
cookieStore.set('model-id', model);
16+
cookieStore.set('chat-model', model);
1717
}
1818

1919
export async function generateTitleFromUserMessage({
@@ -22,7 +22,7 @@ export async function generateTitleFromUserMessage({
2222
message: Message;
2323
}) {
2424
const { text: title } = await generateText({
25-
model: customModel('gpt-4o-mini'),
25+
model: myProvider.languageModel('title-model'),
2626
system: `\n
2727
- you will generate a short title based on the first message a user begins a conversation with
2828
- ensure it is not more than 80 characters long

app/(chat)/api/chat/route.ts

Lines changed: 27 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@ import {
33
createDataStreamResponse,
44
smoothStream,
55
streamText,
6+
wrapLanguageModel,
67
} from 'ai';
78

89
import { auth } from '@/app/(auth)/auth';
9-
import { customModel } from '@/lib/ai';
10-
import { models } from '@/lib/ai/models';
10+
import { myProvider } from '@/lib/ai/models';
1111
import { systemPrompt } from '@/lib/ai/prompts';
1212
import {
1313
deleteChatById,
@@ -48,8 +48,8 @@ export async function POST(request: Request) {
4848
const {
4949
id,
5050
messages,
51-
modelId,
52-
}: { id: string; messages: Array<Message>; modelId: string } =
51+
selectedChatModel,
52+
}: { id: string; messages: Array<Message>; selectedChatModel: string } =
5353
await request.json();
5454

5555
const session = await auth();
@@ -58,12 +58,6 @@ export async function POST(request: Request) {
5858
return new Response('Unauthorized', { status: 401 });
5959
}
6060

61-
const model = models.find((model) => model.id === modelId);
62-
63-
if (!model) {
64-
return new Response('Model not found', { status: 404 });
65-
}
66-
6761
const userMessage = getMostRecentUserMessage(messages);
6862

6963
if (!userMessage) {
@@ -84,7 +78,7 @@ export async function POST(request: Request) {
8478
return createDataStreamResponse({
8579
execute: (dataStream) => {
8680
const result = streamText({
87-
model: customModel(model.apiIdentifier),
81+
model: myProvider.languageModel(selectedChatModel),
8882
system: systemPrompt,
8983
messages,
9084
maxSteps: 5,
@@ -93,32 +87,31 @@ export async function POST(request: Request) {
9387
experimental_generateMessageId: generateUUID,
9488
tools: {
9589
getWeather,
96-
createDocument: createDocument({ session, dataStream, model }),
97-
updateDocument: updateDocument({ session, dataStream, model }),
90+
createDocument: createDocument({ session, dataStream }),
91+
updateDocument: updateDocument({ session, dataStream }),
9892
requestSuggestions: requestSuggestions({
9993
session,
10094
dataStream,
101-
model,
10295
}),
10396
},
104-
onFinish: async ({ response }) => {
97+
onFinish: async ({ response, reasoning }) => {
10598
if (session.user?.id) {
10699
try {
107-
const responseMessagesWithoutIncompleteToolCalls =
108-
sanitizeResponseMessages(response.messages);
100+
const sanitizedResponseMessages = sanitizeResponseMessages({
101+
messages: response.messages,
102+
reasoning,
103+
});
109104

110105
await saveMessages({
111-
messages: responseMessagesWithoutIncompleteToolCalls.map(
112-
(message) => {
113-
return {
114-
id: message.id,
115-
chatId: id,
116-
role: message.role,
117-
content: message.content,
118-
createdAt: new Date(),
119-
};
120-
},
121-
),
106+
messages: sanitizedResponseMessages.map((message) => {
107+
return {
108+
id: message.id,
109+
chatId: id,
110+
role: message.role,
111+
content: message.content,
112+
createdAt: new Date(),
113+
};
114+
}),
122115
});
123116
} catch (error) {
124117
console.error('Failed to save chat');
@@ -131,7 +124,12 @@ export async function POST(request: Request) {
131124
},
132125
});
133126

134-
result.mergeIntoDataStream(dataStream);
127+
result.mergeIntoDataStream(dataStream, {
128+
sendReasoning: true,
129+
});
130+
},
131+
onError: (error) => {
132+
return 'Oops, an error occured!';
135133
},
136134
});
137135
}

app/(chat)/chat/[id]/page.tsx

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@ import { notFound } from 'next/navigation';
33

44
import { auth } from '@/app/(auth)/auth';
55
import { Chat } from '@/components/chat';
6-
import { DEFAULT_MODEL_NAME, models } from '@/lib/ai/models';
76
import { getChatById, getMessagesByChatId } from '@/lib/db/queries';
87
import { convertToUIMessages } from '@/lib/utils';
98
import { DataStreamHandler } from '@/components/data-stream-handler';
9+
import { DEFAULT_CHAT_MODEL } from '@/lib/ai/models';
1010

1111
export default async function Page(props: { params: Promise<{ id: string }> }) {
1212
const params = await props.params;
@@ -34,17 +34,29 @@ export default async function Page(props: { params: Promise<{ id: string }> }) {
3434
});
3535

3636
const cookieStore = await cookies();
37-
const modelIdFromCookie = cookieStore.get('model-id')?.value;
38-
const selectedModelId =
39-
models.find((model) => model.id === modelIdFromCookie)?.id ||
40-
DEFAULT_MODEL_NAME;
37+
const chatModelFromCookie = cookieStore.get('chat-model');
38+
39+
if (!chatModelFromCookie) {
40+
return (
41+
<>
42+
<Chat
43+
id={chat.id}
44+
initialMessages={convertToUIMessages(messagesFromDb)}
45+
selectedChatModel={DEFAULT_CHAT_MODEL}
46+
selectedVisibilityType={chat.visibility}
47+
isReadonly={session?.user?.id !== chat.userId}
48+
/>
49+
<DataStreamHandler id={id} />
50+
</>
51+
);
52+
}
4153

4254
return (
4355
<>
4456
<Chat
4557
id={chat.id}
4658
initialMessages={convertToUIMessages(messagesFromDb)}
47-
selectedModelId={selectedModelId}
59+
selectedChatModel={chatModelFromCookie.value}
4860
selectedVisibilityType={chat.visibility}
4961
isReadonly={session?.user?.id !== chat.userId}
5062
/>

app/(chat)/page.tsx

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,39 @@
11
import { cookies } from 'next/headers';
22

33
import { Chat } from '@/components/chat';
4-
import { DEFAULT_MODEL_NAME, models } from '@/lib/ai/models';
4+
import { DEFAULT_CHAT_MODEL } from '@/lib/ai/models';
55
import { generateUUID } from '@/lib/utils';
66
import { DataStreamHandler } from '@/components/data-stream-handler';
77

88
export default async function Page() {
99
const id = generateUUID();
1010

1111
const cookieStore = await cookies();
12-
const modelIdFromCookie = cookieStore.get('model-id')?.value;
12+
const modelIdFromCookie = cookieStore.get('chat-model');
1313

14-
const selectedModelId =
15-
models.find((model) => model.id === modelIdFromCookie)?.id ||
16-
DEFAULT_MODEL_NAME;
14+
if (!modelIdFromCookie) {
15+
return (
16+
<>
17+
<Chat
18+
key={id}
19+
id={id}
20+
initialMessages={[]}
21+
selectedChatModel={DEFAULT_CHAT_MODEL}
22+
selectedVisibilityType="private"
23+
isReadonly={false}
24+
/>
25+
<DataStreamHandler id={id} />
26+
</>
27+
);
28+
}
1729

1830
return (
1931
<>
2032
<Chat
2133
key={id}
2234
id={id}
2335
initialMessages={[]}
24-
selectedModelId={selectedModelId}
36+
selectedChatModel={modelIdFromCookie.value}
2537
selectedVisibilityType="private"
2638
isReadonly={false}
2739
/>

components/chat.tsx

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,18 @@ import { MultimodalInput } from './multimodal-input';
1414
import { Messages } from './messages';
1515
import { VisibilityType } from './visibility-selector';
1616
import { useBlockSelector } from '@/hooks/use-block';
17+
import { toast } from 'sonner';
1718

1819
export function Chat({
1920
id,
2021
initialMessages,
21-
selectedModelId,
22+
selectedChatModel,
2223
selectedVisibilityType,
2324
isReadonly,
2425
}: {
2526
id: string;
2627
initialMessages: Array<Message>;
27-
selectedModelId: string;
28+
selectedChatModel: string;
2829
selectedVisibilityType: VisibilityType;
2930
isReadonly: boolean;
3031
}) {
@@ -42,14 +43,18 @@ export function Chat({
4243
reload,
4344
} = useChat({
4445
id,
45-
body: { id, modelId: selectedModelId },
46+
body: { id, selectedChatModel: selectedChatModel },
4647
initialMessages,
4748
experimental_throttle: 100,
4849
sendExtraMessageFields: true,
4950
generateId: generateUUID,
5051
onFinish: () => {
5152
mutate('/api/history');
5253
},
54+
onError: (error) => {
55+
console.log(error);
56+
toast.error('An error occured, please try again!');
57+
},
5358
});
5459

5560
const { data: votes } = useSWR<Array<Vote>>(
@@ -65,7 +70,7 @@ export function Chat({
6570
<div className="flex flex-col min-w-0 h-dvh bg-background">
6671
<ChatHeader
6772
chatId={id}
68-
selectedModelId={selectedModelId}
73+
selectedModelId={selectedChatModel}
6974
selectedVisibilityType={selectedVisibilityType}
7075
isReadonly={isReadonly}
7176
/>

components/markdown.tsx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import Link from 'next/link';
2-
import React, { memo, useMemo, useState } from 'react';
2+
import React, { memo } from 'react';
33
import ReactMarkdown, { type Components } from 'react-markdown';
44
import remarkGfm from 'remark-gfm';
55
import { CodeBlock } from './code-block';

components/message-reasoning.tsx

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
'use client';
2+
3+
import { useState } from 'react';
4+
import { ChevronDownIcon, LoaderIcon } from './icons';
5+
import { motion, AnimatePresence } from 'framer-motion';
6+
import { Markdown } from './markdown';
7+
8+
interface MessageReasoningProps {
9+
isLoading: boolean;
10+
reasoning: string;
11+
}
12+
13+
export function MessageReasoning({
14+
isLoading,
15+
reasoning,
16+
}: MessageReasoningProps) {
17+
const [isExpanded, setIsExpanded] = useState(true);
18+
19+
const variants = {
20+
collapsed: {
21+
height: 0,
22+
opacity: 0,
23+
marginTop: 0,
24+
marginBottom: 0,
25+
},
26+
expanded: {
27+
height: 'auto',
28+
opacity: 1,
29+
marginTop: '1rem',
30+
marginBottom: '0.5rem',
31+
},
32+
};
33+
34+
return (
35+
<div className="flex flex-col">
36+
{isLoading ? (
37+
<div className="flex flex-row gap-2 items-center">
38+
<div className="font-medium">Reasoning</div>
39+
<div className="animate-spin">
40+
<LoaderIcon />
41+
</div>
42+
</div>
43+
) : (
44+
<div className="flex flex-row gap-2 items-center">
45+
<div className="font-medium">Reasoned for a few seconds</div>
46+
<div
47+
className="cursor-pointer"
48+
onClick={() => {
49+
setIsExpanded(!isExpanded);
50+
}}
51+
>
52+
<ChevronDownIcon />
53+
</div>
54+
</div>
55+
)}
56+
57+
<AnimatePresence initial={false}>
58+
{isExpanded && (
59+
<motion.div
60+
key="content"
61+
initial="collapsed"
62+
animate="expanded"
63+
exit="collapsed"
64+
variants={variants}
65+
transition={{ duration: 0.2, ease: 'easeInOut' }}
66+
style={{ overflow: 'hidden' }}
67+
className="pl-4 text-zinc-600 dark:text-zinc-400 border-l flex flex-col gap-4"
68+
>
69+
<Markdown>{reasoning}</Markdown>
70+
</motion.div>
71+
)}
72+
</AnimatePresence>
73+
</div>
74+
);
75+
}

0 commit comments

Comments
 (0)