diff --git a/src/api/providers/__tests__/gemini.spec.ts b/src/api/providers/__tests__/gemini.spec.ts
index 812c1ae1a64d..45302eff370a 100644
--- a/src/api/providers/__tests__/gemini.spec.ts
+++ b/src/api/providers/__tests__/gemini.spec.ts
@@ -164,6 +164,148 @@ describe("GeminiHandler", () => {
 		})
 	})
 
+	describe("grounding bypass for code generation", () => {
+		it("should disable grounding when code generation context is detected", async () => {
+			const codeHandler = new GeminiHandler({
+				apiModelId: GEMINI_20_FLASH_THINKING_NAME,
+				geminiApiKey: "test-key",
+				enableGrounding: true, // Grounding is enabled
+			})
+
+			// Mock the client's generateContentStream method
+			const mockGenerateContentStream = vi.fn().mockResolvedValue({
+				[Symbol.asyncIterator]: async function* () {
+					yield {
+						text: "fruits = ['apple', 'banana']\nmyfruit = fruits[0]",
+						usageMetadata: { promptTokenCount: 100, candidatesTokenCount: 50 },
+					}
+				},
+			})
+			codeHandler["client"].models.generateContentStream = mockGenerateContentStream
+
+			// Test message that includes code generation request
+			const messages: Anthropic.Messages.MessageParam[] = [
+				{
+					role: "user",
+					content: [
+						{
+							type: "text",
+							text: "Create a python file that creates a list of 10 fruits and put the first one in the variable 'myfruit'",
+						},
+					],
+				},
+			]
+
+			const stream = codeHandler.createMessage("System prompt", messages)
+			const chunks = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			// Verify that googleSearch was NOT added to tools (grounding disabled)
+			const callArgs = mockGenerateContentStream.mock.calls[0][0]
+			const tools = callArgs.config.tools || []
+			const hasGrounding = tools.some((tool: object) => "googleSearch" in tool)
+			expect(hasGrounding).toBe(false)
+		})
+
+		it("should enable grounding for non-code contexts when enableGrounding is true", async () => {
+			const nonCodeHandler = new GeminiHandler({
+				apiModelId: GEMINI_20_FLASH_THINKING_NAME,
+				geminiApiKey: "test-key",
+				enableGrounding: true,
+			})
+
+			// Mock the client's generateContentStream method
+			const mockGenerateContentStream = vi.fn().mockResolvedValue({
+				[Symbol.asyncIterator]: async function* () {
+					yield {
+						text: "The weather today is sunny.",
+						usageMetadata: { promptTokenCount: 100, candidatesTokenCount: 50 },
+					}
+				},
+			})
+			nonCodeHandler["client"].models.generateContentStream = mockGenerateContentStream
+
+			// Test message without code generation context
+			const messages: Anthropic.Messages.MessageParam[] = [
+				{
+					role: "user",
+					content: [
+						{
+							type: "text",
+							text: "What's the weather like today?",
+						},
+					],
+				},
+			]
+
+			const stream = nonCodeHandler.createMessage("System prompt", messages)
+			const chunks = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			// Verify that googleSearch WAS added to tools (grounding enabled)
+			const callArgs = mockGenerateContentStream.mock.calls[0][0]
+			const tools = callArgs.config.tools || []
+			const hasGrounding = tools.some((tool: object) => "googleSearch" in tool)
+			expect(hasGrounding).toBe(true)
+		})
+
+		it("should correctly identify various code generation patterns", async () => {
+			const handler = new GeminiHandler({
+				apiModelId: GEMINI_20_FLASH_THINKING_NAME,
+				geminiApiKey: "test-key",
+			})
+
+			// Test various code generation patterns
+			const codePatterns = [
+				"Create a python file with a list",
+				"Write a javascript function",
+				"Generate code snippet",
+				"Implement a class method",
+				"def my_function():",
+				"function getData() {",
+				"fruits[0] = 'apple'",
+				"array[5]",
+			]
+
+			for (const pattern of codePatterns) {
+				const messages: Anthropic.Messages.MessageParam[] = [
+					{
+						role: "user",
+						content: [{ type: "text", text: pattern }],
+					},
+				]
+				const result = handler["isCodeGenerationContext"](messages)
+				expect(result).toBe(true)
+			}
+
+			// Test non-code patterns
+			const nonCodePatterns = [
+				"What's the weather?",
+				"Explain quantum physics",
+				"Tell me a story",
+				"How do I cook pasta?",
+				// Keywords embedded in longer words must not count as code requests
+				"Please recreate my profile picture",
+				"Write down the postcode for Bristol",
+			]
+
+			for (const pattern of nonCodePatterns) {
+				const messages: Anthropic.Messages.MessageParam[] = [
+					{
+						role: "user",
+						content: [{ type: "text", text: pattern }],
+					},
+				]
+				const result = handler["isCodeGenerationContext"](messages)
+				expect(result).toBe(false)
+			}
+		})
+	})
+
 	describe("calculateCost", () => {
 		// Mock ModelInfo based on gemini-1.5-flash-latest pricing (per 1M tokens)
 		// Removed 'id' and 'name' as they are not part of ModelInfo type directly
diff --git a/src/api/providers/gemini.ts b/src/api/providers/gemini.ts
index 573adda879ec..9fc8520e08be 100644
--- a/src/api/providers/gemini.ts
+++ b/src/api/providers/gemini.ts
@@ -60,6 +60,57 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl
 				: new GoogleGenAI({ apiKey })
 	}
 
+	/**
+	 * Case-insensitive heuristic for text that asks for code or already contains code.
+	 * Built once for the class instead of on every call; it has no global or sticky
+	 * flag, so test() keeps no state between uses.
+	 */
+	private static readonly CODE_CONTEXT_PATTERN = new RegExp(
+		[
+			// Natural-language requests. The \b anchors stop a keyword from matching
+			// inside a longer word such as "recreate", "profile" or "postcode".
+			"\\b(?:create|write|generate).*\\b(?:files?|scripts?|functions?|class(?:es)?|methods?|code|programs?)\\b",
+			"\\bimplement.*\\b(?:functions?|class(?:es)?|methods?|algorithms?)\\b",
+			"\\b(?:python|javascript|typescript) files?\\b",
+			"\\bcode (?:snippet|example)",
+			// Source-code syntax.
+			"\\bdef\\s+\\w+\\s*\\(", // Python function definition
+			"\\bfunction\\s+\\w+\\s*\\(", // JavaScript function
+			"\\bclass\\s+\\w+", // Class definition
+			"\\[\\d+\\]", // Array index such as fruits[0] from the reported issue
+			"array\\[",
+			"list\\[",
+		].join("|"),
+		"i",
+	)
+
+	/**
+	 * Detects if the conversation context suggests code generation.
+	 * This helps prevent Gemini's grounding feature from incorrectly
+	 * removing array indices like [0] which it may interpret as citations.
+	 *
+	 * NOTE(review): messages of every role are scanned, so bracketed numbers in an
+	 * earlier reply match as well; confirm that is intended.
+	 *
+	 * @param messages Conversation history; only the last five entries are inspected.
+	 * @returns true when any inspected text matches CODE_CONTEXT_PATTERN.
+	 */
+	private isCodeGenerationContext(messages: Anthropic.Messages.MessageParam[]): boolean {
+		const pattern = GeminiHandler.CODE_CONTEXT_PATTERN
+
+		return messages.slice(-5).some(({ content }) => {
+			if (typeof content === "string") {
+				return pattern.test(content)
+			}
+
+			// Blocks whose type is not "text" are ignored.
+			return (
+				Array.isArray(content) &&
+				content.some((block) => block.type === "text" && pattern.test(block.text))
+			)
+		})
+	}
+
 	async *createMessage(
 		systemInstruction: string,
 		messages: Anthropic.Messages.MessageParam[],
@@ -74,7 +125,10 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl
 			tools.push({ urlContext: {} })
 		}
 
-		if (this.options.enableGrounding) {
+		// Only enable grounding if it's not a code generation context
+		// This prevents Gemini from misinterpreting array indices like [0] as citation markers
+		const isCodeContext = this.isCodeGenerationContext(messages)
+		if (this.options.enableGrounding && !isCodeContext) {
 			tools.push({ googleSearch: {} })
 		}
 
@@ -217,7 +271,13 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl
 			if (this.options.enableUrlContext) {
 				tools.push({ urlContext: {} })
 			}
-			if (this.options.enableGrounding) {
+
+			// Check if the prompt suggests code generation
+			const isCodeContext = this.isCodeGenerationContext([
+				{ role: "user", content: [{ type: "text", text: prompt }] },
+			])
+
+			if (this.options.enableGrounding && !isCodeContext) {
 				tools.push({ googleSearch: {} })
 			}
 			const promptConfig: GenerateContentConfig = {