Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions packages/types/src/provider-settings.ts
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ const lmStudioSchema = baseProviderSettingsSchema.extend({
// Provider settings specific to Google Gemini: API key, an optional base-URL
// override, and a flag that suppresses streaming of intermediate reasoning.
const geminiProviderFields = {
	geminiApiKey: z.string().optional(),
	googleGeminiBaseUrl: z.string().optional(),
	geminiDisableIntermediateReasoning: z.boolean().optional(),
}

const geminiSchema = apiModelIdProviderModelSchema.extend(geminiProviderFields)

const geminiCliSchema = apiModelIdProviderModelSchema.extend({
Expand Down
113 changes: 113 additions & 0 deletions src/api/providers/__tests__/gemini.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,119 @@ describe("GeminiHandler", () => {
)
})

it("should handle reasoning chunks correctly when intermediate reasoning is enabled", async () => {
// Setup the mock implementation to return an async generator with reasoning chunks
;(handler["client"].models.generateContentStream as any).mockResolvedValue({
[Symbol.asyncIterator]: async function* () {
yield {
candidates: [
{
content: {
parts: [{ thought: true, text: "Let me think about this..." }, { text: "Hello" }],
},
},
],
}
yield {
candidates: [
{
content: {
parts: [{ thought: true, text: "I need to consider..." }, { text: " world!" }],
},
},
],
}
yield { usageMetadata: { promptTokenCount: 10, candidatesTokenCount: 5, thoughtsTokenCount: 20 } }
},
})

const stream = handler.createMessage(systemPrompt, mockMessages)
const chunks = []

for await (const chunk of stream) {
chunks.push(chunk)
}

// Should have 6 chunks: 2 reasoning + 2 text + 2 reasoning + 2 text + usage
expect(chunks.length).toBe(5)
expect(chunks[0]).toEqual({ type: "reasoning", text: "Let me think about this..." })
expect(chunks[1]).toEqual({ type: "text", text: "Hello" })
expect(chunks[2]).toEqual({ type: "reasoning", text: "I need to consider..." })
expect(chunks[3]).toEqual({ type: "text", text: " world!" })
expect(chunks[4]).toEqual({
type: "usage",
inputTokens: 10,
outputTokens: 5,
reasoningTokens: 20,
})
})

it("should suppress reasoning chunks when geminiDisableIntermediateReasoning is enabled", async () => {
// Create a new handler with the setting enabled
const handlerWithDisabledReasoning = new GeminiHandler({
apiKey: "test-key",
apiModelId: GEMINI_20_FLASH_THINKING_NAME,
geminiApiKey: "test-key",
geminiDisableIntermediateReasoning: true,
})

// Replace the client with our mock
handlerWithDisabledReasoning["client"] = {
models: {
generateContentStream: vitest.fn(),
generateContent: vitest.fn(),
getGenerativeModel: vitest.fn(),
},
} as any

// Setup the mock implementation to return an async generator with reasoning chunks
;(handlerWithDisabledReasoning["client"].models.generateContentStream as any).mockResolvedValue({
[Symbol.asyncIterator]: async function* () {
yield {
candidates: [
{
content: {
parts: [{ thought: true, text: "Let me think about this..." }, { text: "Hello" }],
},
},
],
}
yield {
candidates: [
{
content: {
parts: [{ thought: true, text: "I need to consider..." }, { text: " world!" }],
},
},
],
}
yield { usageMetadata: { promptTokenCount: 10, candidatesTokenCount: 5, thoughtsTokenCount: 20 } }
},
})

const stream = handlerWithDisabledReasoning.createMessage(systemPrompt, mockMessages)
const chunks = []

for await (const chunk of stream) {
chunks.push(chunk)
}

// Should have only 3 chunks: 2 text + usage (reasoning chunks should be suppressed)
expect(chunks.length).toBe(3)
expect(chunks[0]).toEqual({ type: "text", text: "Hello" })
expect(chunks[1]).toEqual({ type: "text", text: " world!" })
expect(chunks[2]).toEqual({
type: "usage",
inputTokens: 10,
outputTokens: 5,
reasoningTokens: 20,
})

// Verify no reasoning chunks were yielded
const reasoningChunks = chunks.filter((chunk) => chunk.type === "reasoning")
expect(reasoningChunks.length).toBe(0)
})

it("should handle API errors", async () => {
const mockError = new Error("Gemini API error")
;(handler["client"].models.generateContentStream as any).mockRejectedValue(mockError)
Expand Down
3 changes: 2 additions & 1 deletion src/api/providers/gemini.ts
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl
for (const part of candidate.content.parts) {
if (part.thought) {
// This is a thinking/reasoning part
if (part.text) {
// Only yield reasoning chunks if intermediate reasoning is not disabled
if (part.text && !this.options.geminiDisableIntermediateReasoning) {
yield { type: "reasoning", text: part.text }
}
} else {
Expand Down