
Commit 583e33b

add message adapter for Google Gemini, Flash (#1772)

Bump up the @google/generative-ai package version to 0.11.4; simplify Gemini Vision handling.

1 parent d60e1f3 commit 583e33b

File tree

8 files changed: +109 −158 lines

app/api/chat/google/route.ts

Lines changed: 20 additions & 37 deletions
@@ -19,49 +19,32 @@ export async function POST(request: Request) {
     const genAI = new GoogleGenerativeAI(profile.google_gemini_api_key || "")
     const googleModel = genAI.getGenerativeModel({ model: chatSettings.model })
 
-    if (chatSettings.model === "gemini-pro") {
-      const lastMessage = messages.pop()
+    const lastMessage = messages.pop()
 
-      const chat = googleModel.startChat({
-        history: messages,
-        generationConfig: {
-          temperature: chatSettings.temperature
-        }
-      })
+    const chat = googleModel.startChat({
+      history: messages,
+      generationConfig: {
+        temperature: chatSettings.temperature
+      }
+    })
 
-      const response = await chat.sendMessageStream(lastMessage.parts)
+    const response = await chat.sendMessageStream(lastMessage.parts)
 
-      const encoder = new TextEncoder()
-      const readableStream = new ReadableStream({
-        async start(controller) {
-          for await (const chunk of response.stream) {
-            const chunkText = chunk.text()
-            controller.enqueue(encoder.encode(chunkText))
-          }
-          controller.close()
-        }
-      })
-
-      return new Response(readableStream, {
-        headers: { "Content-Type": "text/plain" }
-      })
-    } else if (chatSettings.model === "gemini-pro-vision") {
-      // FIX: Hacky until chat messages are supported
-      const HACKY_MESSAGE = messages[messages.length - 1]
-
-      const result = await googleModel.generateContent([
-        HACKY_MESSAGE.prompt,
-        HACKY_MESSAGE.imageParts
-      ])
-
-      const response = result.response
-
-      const text = response.text()
-
-      return new Response(text, {
-        headers: { "Content-Type": "text/plain" }
-      })
-    }
+    const encoder = new TextEncoder()
+    const readableStream = new ReadableStream({
+      async start(controller) {
+        for await (const chunk of response.stream) {
+          const chunkText = chunk.text()
+          controller.enqueue(encoder.encode(chunkText))
+        }
+        controller.close()
+      }
+    })
+
+    return new Response(readableStream, {
+      headers: { "Content-Type": "text/plain" }
+    })
   } catch (error: any) {
     let errorMessage = error.message || "An unexpected error occurred"
     const errorCode = error.status || 500
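With the model-specific branching removed, every Gemini model now takes the same startChat/sendMessageStream path and returns a streamed plain-text body. For context, a minimal client-side sketch of consuming that stream; the /api/chat/google path follows from this file's location, but the request-body fields are assumptions for illustration:

// Hypothetical consumer of the streaming response returned by this route.
// The body fields (chatSettings, messages, etc.) are assumed, not confirmed here.
async function readGeminiStream(requestBody: unknown): Promise<string> {
  const res = await fetch("/api/chat/google", {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify(requestBody)
  })
  if (!res.ok || !res.body) throw new Error(`Request failed: ${res.status}`)

  // Decode each enqueued chunk as it arrives, mirroring the server-side
  // TextEncoder in the route above.
  const reader = res.body.getReader()
  const decoder = new TextDecoder()
  let fullText = ""
  while (true) {
    const { done, value } = await reader.read()
    if (done) break
    fullText += decoder.decode(value, { stream: true })
  }
  return fullText
}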

components/chat/chat-helpers/index.ts

Lines changed: 5 additions & 8 deletions
@@ -7,7 +7,7 @@ import { createMessages, updateMessage } from "@/db/messages"
 import { uploadMessageImage } from "@/db/storage/message-images"
 import {
   buildFinalMessages,
-  buildGoogleGeminiFinalMessages
+  adaptMessagesForGoogleGemini
 } from "@/lib/build-prompt"
 import { consumeReadableStream } from "@/lib/consume-stream"
 import { Tables, TablesInsert } from "@/supabase/types"
@@ -206,16 +206,13 @@ export const handleHostedChat = async (
       ? "azure"
       : modelData.provider
 
-  let formattedMessages = []
+  let draftMessages = await buildFinalMessages(payload, profile, chatImages)
 
+  let formattedMessages : any[] = []
   if (provider === "google") {
-    formattedMessages = await buildGoogleGeminiFinalMessages(
-      payload,
-      profile,
-      newMessageImages
-    )
+    formattedMessages = await adaptMessagesForGoogleGemini(payload, draftMessages)
   } else {
-    formattedMessages = await buildFinalMessages(payload, profile, chatImages)
+    formattedMessages = draftMessages
   }
 
   const apiEndpoint =
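To make the adapter's contract concrete, here is a hypothetical before/after for the Google branch, based on the adapter code in lib/build-prompt.ts below; all message values are invented for illustration:

// OpenAI-style input, as buildFinalMessages would produce it (hypothetical values):
const draft = [
  { role: "system", content: "You are a helpful assistant." },
  {
    role: "user",
    content: [
      { type: "text", text: "What is in this picture?" },
      { type: "image_url", image_url: { url: "data:image/png;base64,iVBORw0KGgo" } }
    ]
  }
]

// Expected Gemini-format output of adaptMessagesForGoogleGemini(payload, draft)
// for a non-vision model: system collapses to the "user" role, string content
// becomes a single text part, and data URLs become inlineData parts.
// [
//   { role: "user", parts: [{ text: "You are a helpful assistant." }] },
//   {
//     role: "user",
//     parts: [
//       { text: "What is in this picture?" },
//       { inlineData: { data: "iVBORw0KGgo", mimeType: "image/png" } }
//     ]
//   }
// ]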

lib/build-prompt.ts

Lines changed: 61 additions & 107 deletions
@@ -1,6 +1,7 @@
 import { Tables } from "@/supabase/types"
 import { ChatPayload, MessageImage } from "@/types"
 import { encode } from "gpt-tokenizer"
+import { getBase64FromDataURL, getMediaTypeFromDataURL } from "@/lib/utils"
 
 const buildBasePrompt = (
   prompt: string,
@@ -182,125 +183,78 @@ function buildRetrievalText(fileItems: Tables<"file_items">[]) {
   return `You may use the following sources if needed to answer the user's question. If you don't know the answer, say "I don't know."\n\n${retrievalText}`
 }
 
-export async function buildGoogleGeminiFinalMessages(
-  payload: ChatPayload,
-  profile: Tables<"profiles">,
-  messageImageFiles: MessageImage[]
-) {
-  const { chatSettings, workspaceInstructions, chatMessages, assistant } =
-    payload
-
-  const BUILT_PROMPT = buildBasePrompt(
-    chatSettings.prompt,
-    chatSettings.includeProfileContext ? profile.profile_context || "" : "",
-    chatSettings.includeWorkspaceInstructions ? workspaceInstructions : "",
-    assistant
-  )
-
-  let finalMessages = []
-
-  let usedTokens = 0
-  const CHUNK_SIZE = chatSettings.contextLength
-  const PROMPT_TOKENS = encode(chatSettings.prompt).length
-  let REMAINING_TOKENS = CHUNK_SIZE - PROMPT_TOKENS
-
-  usedTokens += PROMPT_TOKENS
-
-  for (let i = chatMessages.length - 1; i >= 0; i--) {
-    const message = chatMessages[i].message
-    const messageTokens = encode(message.content).length
-
-    if (messageTokens <= REMAINING_TOKENS) {
-      REMAINING_TOKENS -= messageTokens
-      usedTokens += messageTokens
-      finalMessages.unshift(message)
-    } else {
-      break
-    }
-  }
-
-  let tempSystemMessage: Tables<"messages"> = {
-    chat_id: "",
-    assistant_id: null,
-    content: BUILT_PROMPT,
-    created_at: "",
-    id: chatMessages.length + "",
-    image_paths: [],
-    model: payload.chatSettings.model,
-    role: "system",
-    sequence_number: chatMessages.length,
-    updated_at: "",
-    user_id: ""
-  }
-
-  finalMessages.unshift(tempSystemMessage)
-
-  let GOOGLE_FORMATTED_MESSAGES = []
-
-  if (chatSettings.model === "gemini-pro") {
-    GOOGLE_FORMATTED_MESSAGES = [
-      {
-        role: "user",
-        parts: finalMessages[0].content
-      },
-      {
-        role: "model",
-        parts: "I will follow your instructions."
-      }
-    ]
-
-    for (let i = 1; i < finalMessages.length; i++) {
-      GOOGLE_FORMATTED_MESSAGES.push({
-        role: finalMessages[i].role === "user" ? "user" : "model",
-        parts: finalMessages[i].content as string
-      })
-    }
-
-    return GOOGLE_FORMATTED_MESSAGES
-  } else if ((chatSettings.model = "gemini-pro-vision")) {
-    // Gemini Pro Vision doesn't currently support messages
-    async function fileToGenerativePart(file: File) {
-      const base64EncodedDataPromise = new Promise(resolve => {
-        const reader = new FileReader()
-
-        reader.onloadend = () => {
-          if (typeof reader.result === "string") {
-            resolve(reader.result.split(",")[1])
-          }
-        }
-
-        reader.readAsDataURL(file)
-      })
-
-      return {
-        inlineData: {
-          data: await base64EncodedDataPromise,
-          mimeType: file.type
-        }
-      }
-    }
-
-    let prompt = ""
-
-    for (let i = 0; i < finalMessages.length; i++) {
-      prompt += `${finalMessages[i].role}:\n${finalMessages[i].content}\n\n`
-    }
-
-    const files = messageImageFiles.map(file => file.file)
-    const imageParts = await Promise.all(
-      files.map(file =>
-        file ? fileToGenerativePart(file) : Promise.resolve(null)
-      )
-    )
-
-    // FIX: Hacky until chat messages are supported
-    return [
-      {
-        prompt,
-        imageParts
-      }
-    ]
-  }
-
-  return finalMessages
-}
+function adaptSingleMessageForGoogleGemini(message: any) {
+
+  let adaptedParts = []
+
+  let rawParts = []
+  if(!Array.isArray(message.content)) {
+    rawParts.push({type: 'text', text: message.content})
+  } else {
+    rawParts = message.content
+  }
+
+  for(let i = 0; i < rawParts.length; i++) {
+    let rawPart = rawParts[i]
+
+    if(rawPart.type == 'text') {
+      adaptedParts.push({text: rawPart.text})
+    } else if(rawPart.type === 'image_url') {
+      adaptedParts.push({
+        inlineData: {
+          data: getBase64FromDataURL(rawPart.image_url.url),
+          mimeType: getMediaTypeFromDataURL(rawPart.image_url.url),
+        }
+      })
+    }
+  }
+
+  let role = 'user'
+  if(["user", "system"].includes(message.role)) {
+    role = 'user'
+  } else if(message.role === 'assistant') {
+    role = 'model'
+  }
+
+  return {
+    role: role,
+    parts: adaptedParts
+  }
+}
+
+function adaptMessagesForGeminiVision(
+  messages: any[]
+) {
+  // Gemini Pro Vision cannot process multiple messages
+  // Reformat, using all texts and last visual only
+
+  const basePrompt = messages[0].parts[0].text
+  const baseRole = messages[0].role
+  const lastMessage = messages[messages.length-1]
+  const visualMessageParts = lastMessage.parts;
+  let visualQueryMessages = [{
+    role: "user",
+    parts: [
+      `${baseRole}:\n${basePrompt}\n\nuser:\n${visualMessageParts[0].text}\n\n`,
+      visualMessageParts.slice(1)
+    ]
+  }]
+  return visualQueryMessages
+}
+
+export async function adaptMessagesForGoogleGemini(
+  payload: ChatPayload,
+  messages: any[]
+) {
+  let geminiMessages = []
+  for (let i = 0; i < messages.length; i++) {
+    let adaptedMessage = adaptSingleMessageForGoogleGemini(messages[i])
+    geminiMessages.push(adaptedMessage)
+  }
+
+  if(payload.chatSettings.model === "gemini-pro-vision") {
+    geminiMessages = adaptMessagesForGeminiVision(geminiMessages)
  }
+  return geminiMessages
+}
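The new adapter leans on getBase64FromDataURL and getMediaTypeFromDataURL, imported from "@/lib/utils" but not shown in this commit. A minimal sketch of what they presumably do, assuming standard data:<mediatype>;base64,<data> URLs; the real implementations may differ:

// Assumed behavior only; the actual "@/lib/utils" helpers are not part of this diff.
export function getMediaTypeFromDataURL(dataURL: string): string {
  // "data:image/png;base64,AAAA" -> "image/png"
  const match = dataURL.match(/^data:([^;,]+)[;,]/)
  if (!match) throw new Error("Invalid data URL")
  return match[1]
}

export function getBase64FromDataURL(dataURL: string): string {
  // "data:image/png;base64,AAAA" -> "AAAA"
  const commaIndex = dataURL.indexOf(",")
  if (commaIndex === -1) throw new Error("Invalid data URL")
  return dataURL.slice(commaIndex + 1)
}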

lib/chat-setting-limits.ts

Lines changed: 6 additions & 0 deletions
@@ -41,6 +41,12 @@ export const CHAT_SETTING_LIMITS: Record<LLMID, ChatSettingLimits> = {
   },
 
   // GOOGLE MODELS
+  "gemini-1.5-pro-latest": {
+    MIN_TEMPERATURE: 0.0,
+    MAX_TEMPERATURE: 1.0,
+    MAX_TOKEN_OUTPUT_LENGTH: 8192,
+    MAX_CONTEXT_LENGTH: 1040384
+  },
   "gemini-pro": {
     MIN_TEMPERATURE: 0.0,
     MAX_TEMPERATURE: 1.0,
lib/models/llm/google-llm-list.ts

Lines changed: 11 additions & 1 deletion
@@ -4,6 +4,16 @@ const GOOGLE_PLATORM_LINK = "https://ai.google.dev/"
 
 // Google Models (UPDATED 12/22/23) -----------------------------
 
+// Gemini Flash (UPDATED 05/28/24)
+const GEMINI_FLASH: LLM = {
+  modelId: "gemini-1.5-pro-latest",
+  modelName: "Gemini Flash",
+  provider: "google",
+  hostedId: "gemini-1.5-pro-latest",
+  platformLink: GOOGLE_PLATORM_LINK,
+  imageInput: false
+}
+
 // Gemini Pro (UPDATED 12/22/23)
 const GEMINI_PRO: LLM = {
   modelId: "gemini-pro",
@@ -24,4 +34,4 @@ const GEMINI_PRO_VISION: LLM = {
   imageInput: true
 }
 
-export const GOOGLE_LLM_LIST: LLM[] = [GEMINI_PRO, GEMINI_PRO_VISION]
+export const GOOGLE_LLM_LIST: LLM[] = [GEMINI_PRO, GEMINI_PRO_VISION, GEMINI_FLASH]

package-lock.json

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default.

package.json

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@
     "@anthropic-ai/sdk": "^0.18.0",
     "@apidevtools/json-schema-ref-parser": "^11.1.0",
     "@azure/openai": "^1.0.0-beta.8",
-    "@google/generative-ai": "^0.1.3",
+    "@google/generative-ai": "^0.11.4",
     "@hookform/resolvers": "^3.3.2",
     "@mistralai/mistralai": "^0.0.8",
     "@radix-ui/react-accordion": "^1.1.2",

types/llms.ts

Lines changed: 1 addition & 0 deletions
@@ -20,6 +20,7 @@ export type OpenAILLMID =
 export type GoogleLLMID =
   | "gemini-pro" // Gemini Pro
   | "gemini-pro-vision" // Gemini Pro Vision
+  | "gemini-1.5-pro-latest"
 
 // Anthropic Models
 export type AnthropicLLMID =
