diff --git a/packages/types/src/provider-settings.ts b/packages/types/src/provider-settings.ts
index e940ececd1d..b1402dab963 100644
--- a/packages/types/src/provider-settings.ts
+++ b/packages/types/src/provider-settings.ts
@@ -157,6 +157,7 @@ const lmStudioSchema = baseProviderSettingsSchema.extend({
 const geminiSchema = apiModelIdProviderModelSchema.extend({
 	geminiApiKey: z.string().optional(),
 	googleGeminiBaseUrl: z.string().optional(),
+	geminiDisableIntermediateReasoning: z.boolean().optional(),
 })
 
 const geminiCliSchema = apiModelIdProviderModelSchema.extend({
diff --git a/src/api/providers/__tests__/gemini.spec.ts b/src/api/providers/__tests__/gemini.spec.ts
index 8a7fd24fe36..a271fc68565 100644
--- a/src/api/providers/__tests__/gemini.spec.ts
+++ b/src/api/providers/__tests__/gemini.spec.ts
@@ -89,6 +89,119 @@ describe("GeminiHandler", () => {
 			)
 	})
 
+	it("should handle reasoning chunks correctly when intermediate reasoning is enabled", async () => {
+		// Set up the mock to return an async generator that interleaves reasoning and text parts
+		;(handler["client"].models.generateContentStream as any).mockResolvedValue({
+			[Symbol.asyncIterator]: async function* () {
+				yield {
+					candidates: [
+						{
+							content: {
+								parts: [{ thought: true, text: "Let me think about this..." }, { text: "Hello" }],
+							},
+						},
+					],
+				}
+				yield {
+					candidates: [
+						{
+							content: {
+								parts: [{ thought: true, text: "I need to consider..." }, { text: " world!" }],
+							},
+						},
+					],
+				}
+				yield { usageMetadata: { promptTokenCount: 10, candidatesTokenCount: 5, thoughtsTokenCount: 20 } }
+			},
+		})
+
+		const stream = handler.createMessage(systemPrompt, mockMessages)
+		const chunks = []
+
+		for await (const chunk of stream) {
+			chunks.push(chunk)
+		}
+
+		// Should have 5 chunks: 2 reasoning + 2 text + 1 usage
+		expect(chunks.length).toBe(5)
+		expect(chunks[0]).toEqual({ type: "reasoning", text: "Let me think about this..." })
+		expect(chunks[1]).toEqual({ type: "text", text: "Hello" })
+		expect(chunks[2]).toEqual({ type: "reasoning", text: "I need to consider..." })
+		expect(chunks[3]).toEqual({ type: "text", text: " world!" })
+		expect(chunks[4]).toEqual({
+			type: "usage",
+			inputTokens: 10,
+			outputTokens: 5,
+			reasoningTokens: 20,
+		})
+	})
+
+	it("should suppress reasoning chunks when geminiDisableIntermediateReasoning is enabled", async () => {
+		// Create a new handler with the setting enabled
+		const handlerWithDisabledReasoning = new GeminiHandler({
+			apiKey: "test-key",
+			apiModelId: GEMINI_20_FLASH_THINKING_NAME,
+			geminiApiKey: "test-key",
+			geminiDisableIntermediateReasoning: true,
+		})
+
+		// Replace the client with our mock
+		handlerWithDisabledReasoning["client"] = {
+			models: {
+				generateContentStream: vitest.fn(),
+				generateContent: vitest.fn(),
+				getGenerativeModel: vitest.fn(),
+			},
+		} as any
+
+		// Set up the mock to return the same interleaved reasoning/text stream as above
+		;(handlerWithDisabledReasoning["client"].models.generateContentStream as any).mockResolvedValue({
+			[Symbol.asyncIterator]: async function* () {
+				yield {
+					candidates: [
+						{
+							content: {
+								parts: [{ thought: true, text: "Let me think about this..." }, { text: "Hello" }],
+							},
+						},
+					],
+				}
+				yield {
+					candidates: [
+						{
+							content: {
+								parts: [{ thought: true, text: "I need to consider..." }, { text: " world!" }],
+							},
+						},
+					],
+				}
+				yield { usageMetadata: { promptTokenCount: 10, candidatesTokenCount: 5, thoughtsTokenCount: 20 } }
+			},
+		})
+
+		const stream = handlerWithDisabledReasoning.createMessage(systemPrompt, mockMessages)
+		const chunks = []
+
+		for await (const chunk of stream) {
+			chunks.push(chunk)
+		}
+
+		// Should have only 3 chunks: 2 text + 1 usage (reasoning chunks are suppressed)
+		expect(chunks.length).toBe(3)
+		expect(chunks[0]).toEqual({ type: "text", text: "Hello" })
+		expect(chunks[1]).toEqual({ type: "text", text: " world!" })
+		expect(chunks[2]).toEqual({
+			type: "usage",
+			inputTokens: 10,
+			outputTokens: 5,
+			reasoningTokens: 20,
+		})
+
+		// Verify no reasoning chunks were yielded
+		const reasoningChunks = chunks.filter((chunk) => chunk.type === "reasoning")
+		expect(reasoningChunks.length).toBe(0)
+	})
+
 	it("should handle API errors", async () => {
 		const mockError = new Error("Gemini API error")
 		;(handler["client"].models.generateContentStream as any).mockRejectedValue(mockError)
diff --git a/src/api/providers/gemini.ts b/src/api/providers/gemini.ts
index 6765c8676d8..8a0e5afe09b 100644
--- a/src/api/providers/gemini.ts
+++ b/src/api/providers/gemini.ts
@@ -89,7 +89,8 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl
 			for (const part of candidate.content.parts) {
 				if (part.thought) {
 					// This is a thinking/reasoning part
-					if (part.text) {
+					// Only yield reasoning chunks if intermediate reasoning is not disabled
+					if (part.text && !this.options.geminiDisableIntermediateReasoning) {
 						yield { type: "reasoning", text: part.text }
 					}
 				} else {
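
Usage note: a minimal sketch (not part of the patch) of how a caller might opt in to the new flag. It reuses the `GeminiHandler` constructor and `createMessage(systemPrompt, messages)` shapes exercised by the tests above; the import path, model id, and message contents are placeholders, not values confirmed by this diff.

	import { GeminiHandler } from "./src/api/providers/gemini" // placeholder path

	const handler = new GeminiHandler({
		apiModelId: "gemini-2.5-flash", // placeholder model id
		geminiApiKey: process.env.GEMINI_API_KEY ?? "",
		// Drop intermediate "reasoning" chunks; "text" and "usage" chunks still stream through.
		geminiDisableIntermediateReasoning: true,
	})

	const messages = [{ role: "user" as const, content: "Hello" }] // placeholder message shape

	for await (const chunk of handler.createMessage("You are a helpful assistant.", messages)) {
		if (chunk.type === "text") process.stdout.write(chunk.text)
	}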