|
1 | 1 | import { Tables } from "@/supabase/types" |
2 | 2 | import { ChatPayload, MessageImage } from "@/types" |
3 | 3 | import { encode } from "gpt-tokenizer" |
| 4 | +import { getBase64FromDataURL, getMediaTypeFromDataURL } from "@/lib/utils" |
4 | 5 |
|
5 | 6 | const buildBasePrompt = ( |
6 | 7 | prompt: string, |
@@ -182,125 +183,78 @@ function buildRetrievalText(fileItems: Tables<"file_items">[]) { |
182 | 183 | return `You may use the following sources if needed to answer the user's question. If you don't know the answer, say "I don't know."\n\n${retrievalText}` |
183 | 184 | } |
184 | 185 |
|
185 | | -export async function buildGoogleGeminiFinalMessages( |
186 | | - payload: ChatPayload, |
187 | | - profile: Tables<"profiles">, |
188 | | - messageImageFiles: MessageImage[] |
189 | | -) { |
190 | | - const { chatSettings, workspaceInstructions, chatMessages, assistant } = |
191 | | - payload |
| 186 | +function adaptSingleMessageForGoogleGemini(message: any) { |
192 | 187 |
|
193 | | - const BUILT_PROMPT = buildBasePrompt( |
194 | | - chatSettings.prompt, |
195 | | - chatSettings.includeProfileContext ? profile.profile_context || "" : "", |
196 | | - chatSettings.includeWorkspaceInstructions ? workspaceInstructions : "", |
197 | | - assistant |
198 | | - ) |
199 | | - |
200 | | - let finalMessages = [] |
201 | | - |
202 | | - let usedTokens = 0 |
203 | | - const CHUNK_SIZE = chatSettings.contextLength |
204 | | - const PROMPT_TOKENS = encode(chatSettings.prompt).length |
205 | | - let REMAINING_TOKENS = CHUNK_SIZE - PROMPT_TOKENS |
206 | | - |
207 | | - usedTokens += PROMPT_TOKENS |
| 188 | + let adaptedParts = [] |
208 | 189 |
|
209 | | - for (let i = chatMessages.length - 1; i >= 0; i--) { |
210 | | - const message = chatMessages[i].message |
211 | | - const messageTokens = encode(message.content).length |
212 | | - |
213 | | - if (messageTokens <= REMAINING_TOKENS) { |
214 | | - REMAINING_TOKENS -= messageTokens |
215 | | - usedTokens += messageTokens |
216 | | - finalMessages.unshift(message) |
217 | | - } else { |
218 | | - break |
219 | | - } |
220 | | - } |
221 | | - |
222 | | - let tempSystemMessage: Tables<"messages"> = { |
223 | | - chat_id: "", |
224 | | - assistant_id: null, |
225 | | - content: BUILT_PROMPT, |
226 | | - created_at: "", |
227 | | - id: chatMessages.length + "", |
228 | | - image_paths: [], |
229 | | - model: payload.chatSettings.model, |
230 | | - role: "system", |
231 | | - sequence_number: chatMessages.length, |
232 | | - updated_at: "", |
233 | | - user_id: "" |
| 190 | + let rawParts = [] |
| 191 | + if(!Array.isArray(message.content)) { |
| 192 | + rawParts.push({type: 'text', text: message.content}) |
| 193 | + } else { |
| 194 | + rawParts = message.content |
234 | 195 | } |
235 | 196 |
|
236 | | - finalMessages.unshift(tempSystemMessage) |
| 197 | + for(let i = 0; i < rawParts.length; i++) { |
| 198 | + let rawPart = rawParts[i] |
237 | 199 |
|
238 | | - let GOOGLE_FORMATTED_MESSAGES = [] |
239 | | - |
240 | | - if (chatSettings.model === "gemini-pro") { |
241 | | - GOOGLE_FORMATTED_MESSAGES = [ |
242 | | - { |
243 | | - role: "user", |
244 | | - parts: finalMessages[0].content |
245 | | - }, |
246 | | - { |
247 | | - role: "model", |
248 | | - parts: "I will follow your instructions." |
249 | | - } |
250 | | - ] |
251 | | - |
252 | | - for (let i = 1; i < finalMessages.length; i++) { |
253 | | - GOOGLE_FORMATTED_MESSAGES.push({ |
254 | | - role: finalMessages[i].role === "user" ? "user" : "model", |
255 | | - parts: finalMessages[i].content as string |
256 | | - }) |
257 | | - } |
258 | | - |
259 | | - return GOOGLE_FORMATTED_MESSAGES |
260 | | - } else if ((chatSettings.model = "gemini-pro-vision")) { |
261 | | - // Gemini Pro Vision doesn't currently support messages |
262 | | - async function fileToGenerativePart(file: File) { |
263 | | - const base64EncodedDataPromise = new Promise(resolve => { |
264 | | - const reader = new FileReader() |
265 | | - |
266 | | - reader.onloadend = () => { |
267 | | - if (typeof reader.result === "string") { |
268 | | - resolve(reader.result.split(",")[1]) |
269 | | - } |
270 | | - } |
271 | | - |
272 | | - reader.readAsDataURL(file) |
273 | | - }) |
274 | | - |
275 | | - return { |
| 200 | + if(rawPart.type == 'text') { |
| 201 | + adaptedParts.push({text: rawPart.text}) |
| 202 | + } else if(rawPart.type === 'image_url') { |
| 203 | + adaptedParts.push({ |
276 | 204 | inlineData: { |
277 | | - data: await base64EncodedDataPromise, |
278 | | - mimeType: file.type |
| 205 | + data: getBase64FromDataURL(rawPart.image_url.url), |
| 206 | + mimeType: getMediaTypeFromDataURL(rawPart.image_url.url), |
279 | 207 | } |
280 | | - } |
| 208 | + }) |
281 | 209 | } |
| 210 | + } |
282 | 211 |
|
283 | | - let prompt = "" |
| 212 | + let role = 'user' |
| 213 | + if(["user", "system"].includes(message.role)) { |
| 214 | + role = 'user' |
| 215 | + } else if(message.role === 'assistant') { |
| 216 | + role = 'model' |
| 217 | + } |
284 | 218 |
|
285 | | - for (let i = 0; i < finalMessages.length; i++) { |
286 | | - prompt += `${finalMessages[i].role}:\n${finalMessages[i].content}\n\n` |
287 | | - } |
| 219 | + return { |
| 220 | + role: role, |
| 221 | + parts: adaptedParts |
| 222 | + } |
| 223 | +} |
288 | 224 |
|
289 | | - const files = messageImageFiles.map(file => file.file) |
290 | | - const imageParts = await Promise.all( |
291 | | - files.map(file => |
292 | | - file ? fileToGenerativePart(file) : Promise.resolve(null) |
293 | | - ) |
294 | | - ) |
295 | | - |
296 | | - // FIX: Hacky until chat messages are supported |
297 | | - return [ |
298 | | - { |
299 | | - prompt, |
300 | | - imageParts |
301 | | - } |
| 225 | +function adaptMessagesForGeminiVision( |
| 226 | + messages: any[] |
| 227 | +) { |
| 228 | + // Gemini Pro Vision cannot process multiple messages |
| 229 | + // Reformat, using all texts and last visual only |
| 230 | + |
| 231 | + const basePrompt = messages[0].parts[0].text |
| 232 | + const baseRole = messages[0].role |
| 233 | + const lastMessage = messages[messages.length-1] |
| 234 | + const visualMessageParts = lastMessage.parts; |
| 235 | + let visualQueryMessages = [{ |
| 236 | + role: "user", |
| 237 | + parts: [ |
| 238 | + `${baseRole}:\n${basePrompt}\n\nuser:\n${visualMessageParts[0].text}\n\n`, |
| 239 | + visualMessageParts.slice(1) |
302 | 240 | ] |
| 241 | + }] |
| 242 | + return visualQueryMessages |
| 243 | +} |
| 244 | + |
| 245 | +export async function adaptMessagesForGoogleGemini( |
| 246 | + payload: ChatPayload, |
| 247 | + messages: any[] |
| 248 | +) { |
| 249 | + let geminiMessages = [] |
| 250 | + for (let i = 0; i < messages.length; i++) { |
| 251 | + let adaptedMessage = adaptSingleMessageForGoogleGemini(messages[i]) |
| 252 | + geminiMessages.push(adaptedMessage) |
303 | 253 | } |
304 | 254 |
|
305 | | - return finalMessages |
| 255 | + if(payload.chatSettings.model === "gemini-pro-vision") { |
| 256 | + geminiMessages = adaptMessagesForGeminiVision(geminiMessages) |
| 257 | + } |
| 258 | + return geminiMessages |
306 | 259 | } |
| 260 | + |
0 commit comments