190 changes: 190 additions & 0 deletions src/api/providers/__tests__/gemini.spec.ts
@@ -7,6 +7,15 @@ import { type ModelInfo, geminiDefaultModelId } from "@roo-code/types"
import { t } from "i18next"
import { GeminiHandler } from "../gemini"

// Mock the translation function
vitest.mock("i18next", () => ({
t: vitest.fn((key: string) => {
if (key === "common:errors.gemini.sources") return "Sources:"
if (key === "common:errors.gemini.generate_complete_prompt") return "Gemini completion error: {{error}}"
return key
}),
}))

const GEMINI_20_FLASH_THINKING_NAME = "gemini-2.0-flash-thinking-exp-1219"

describe("GeminiHandler", () => {
@@ -102,6 +111,155 @@ describe("GeminiHandler", () => {
}
}).rejects.toThrow()
})

it("should integrate grounding sources into the assistant message", async () => {
// Setup the mock implementation to return an async generator with grounding metadata
;(handler["client"].models.generateContentStream as any).mockResolvedValue({
[Symbol.asyncIterator]: async function* () {
yield {
candidates: [
{
content: {
parts: [{ text: "Here is some information about AI." }],
},
groundingMetadata: {
groundingChunks: [
{ web: { uri: "https://example.com/ai-info" } },
{ web: { uri: "https://example.com/ai-research" } },
],
},
},
],
}
yield { usageMetadata: { promptTokenCount: 10, candidatesTokenCount: 15 } }
},
})

const stream = handler.createMessage(systemPrompt, mockMessages)
const chunks = []

for await (const chunk of stream) {
chunks.push(chunk)
}

// Should have 3 chunks: main content, sources, and usage info
expect(chunks.length).toBe(3)
expect(chunks[0]).toEqual({ type: "text", text: "Here is some information about AI." })
expect(chunks[1]).toEqual({
type: "text",
text: "\n\nSources: [1](https://example.com/ai-info), [2](https://example.com/ai-research)",
})
expect(chunks[2]).toEqual({ type: "usage", inputTokens: 10, outputTokens: 15 })
})

it("should handle grounding metadata without web sources", async () => {
// Setup the mock implementation with grounding metadata but no web sources
;(handler["client"].models.generateContentStream as any).mockResolvedValue({
[Symbol.asyncIterator]: async function* () {
yield {
candidates: [
{
content: {
parts: [{ text: "Response without web sources." }],
},
groundingMetadata: {
groundingChunks: [{ someOtherSource: { data: "non-web-source" } }],
},
},
],
}
yield { usageMetadata: { promptTokenCount: 5, candidatesTokenCount: 8 } }
},
})

const stream = handler.createMessage(systemPrompt, mockMessages)
const chunks = []

for await (const chunk of stream) {
chunks.push(chunk)
}

// Should have 2 chunks: main content and usage info (no sources since no web URIs)
expect(chunks.length).toBe(2)
expect(chunks[0]).toEqual({ type: "text", text: "Response without web sources." })
expect(chunks[1]).toEqual({ type: "usage", inputTokens: 5, outputTokens: 8 })
})

it("should not yield sources when no content is generated", async () => {
// Setup the mock implementation with grounding metadata but no content
;(handler["client"].models.generateContentStream as any).mockResolvedValue({
[Symbol.asyncIterator]: async function* () {
yield {
candidates: [
{
groundingMetadata: {
groundingChunks: [{ web: { uri: "https://example.com/source" } }],
},
},
],
}
yield { usageMetadata: { promptTokenCount: 5, candidatesTokenCount: 0 } }
},
})

const stream = handler.createMessage(systemPrompt, mockMessages)
const chunks = []

for await (const chunk of stream) {
chunks.push(chunk)
}

// Should only have usage info, no sources since no content was yielded
expect(chunks.length).toBe(1)
expect(chunks[0]).toEqual({ type: "usage", inputTokens: 5, outputTokens: 0 })
})

it("should handle multiple text chunks with grounding sources", async () => {
// Setup the mock implementation with multiple text chunks and grounding
;(handler["client"].models.generateContentStream as any).mockResolvedValue({
[Symbol.asyncIterator]: async function* () {
yield {
candidates: [
{
content: {
parts: [{ text: "First part of response" }],
},
},
],
}
yield {
candidates: [
{
content: {
parts: [{ text: " and second part." }],
},
groundingMetadata: {
groundingChunks: [{ web: { uri: "https://example.com/source1" } }],
},
},
],
}
yield { usageMetadata: { promptTokenCount: 12, candidatesTokenCount: 18 } }
},
})

const stream = handler.createMessage(systemPrompt, mockMessages)
const chunks = []

for await (const chunk of stream) {
chunks.push(chunk)
}

// Should have 4 chunks: two text chunks, sources, and usage info
expect(chunks.length).toBe(4)
expect(chunks[0]).toEqual({ type: "text", text: "First part of response" })
expect(chunks[1]).toEqual({ type: "text", text: " and second part." })
expect(chunks[2]).toEqual({
type: "text",
text: "\n\nSources: [1](https://example.com/source1)",
})
expect(chunks[3]).toEqual({ type: "usage", inputTokens: 12, outputTokens: 18 })
})
})

describe("completePrompt", () => {
@@ -143,6 +301,38 @@ describe("GeminiHandler", () => {
const result = await handler.completePrompt("Test prompt")
expect(result).toBe("")
})

it("should integrate grounding sources in completePrompt", async () => {
// Mock the response with grounding metadata
;(handler["client"].models.generateContent as any).mockResolvedValue({
text: "AI is a fascinating field of study.",
candidates: [
{
groundingMetadata: {
groundingChunks: [
{ web: { uri: "https://example.com/ai-study" } },
{ web: { uri: "https://example.com/ai-research" } },
],
},
},
],
})

const result = await handler.completePrompt("Tell me about AI")
expect(result).toBe(
"AI is a fascinating field of study.\n\nSources: [1](https://example.com/ai-study), [2](https://example.com/ai-research)",
)
})

it("should handle completePrompt without grounding sources", async () => {
// Mock the response without grounding metadata
;(handler["client"].models.generateContent as any).mockResolvedValue({
text: "Simple response without sources.",
})

const result = await handler.completePrompt("Simple question")
expect(result).toBe("Simple response without sources.")
})
})

describe("getModel", () => {
12 changes: 10 additions & 2 deletions src/api/providers/gemini.ts
@@ -94,6 +94,8 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandler

let lastUsageMetadata: GenerateContentResponseUsageMetadata | undefined
let pendingGroundingMetadata: GroundingMetadata | undefined
let lastTextChunk: string | null = null
Contributor comment: The variable lastTextChunk is assigned but never used. Consider removing it to clean up the code.
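A minimal, self-contained sketch of the pattern this comment points at: once lastTextChunk is dropped, hasYieldedContent is the only flag the trailing grounding check needs. The names streamWithSources, TextChunk, and the parts iterable are illustrative stand-ins, not the provider's actual API.

```ts
// Illustrative sketch only: mirrors the diff's logic with `lastTextChunk` removed.
type TextChunk = { type: "text"; text: string }

async function* streamWithSources(
	parts: AsyncIterable<{ text?: string; citations?: string }>,
): AsyncGenerator<TextChunk> {
	let pendingCitations: string | undefined
	let hasYieldedContent = false

	for await (const part of parts) {
		// Remember the most recent citations; later grounding blocks overwrite earlier ones.
		if (part.citations) pendingCitations = part.citations
		if (part.text) {
			hasYieldedContent = true
			yield { type: "text", text: part.text }
		}
	}

	// Append sources only when some text was actually streamed.
	if (pendingCitations && hasYieldedContent) {
		yield { type: "text", text: `\n\nSources: ${pendingCitations}` }
	}
}
```

The same guard is what the new tests above exercise: when no text chunks are yielded, no trailing sources chunk is emitted.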

let hasYieldedContent = false

for await (const chunk of result) {
// Process candidates and their parts to separate thoughts from content
@@ -114,6 +116,8 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandler
} else {
// This is regular content
if (part.text) {
lastTextChunk = part.text
hasYieldedContent = true
yield { type: "text", text: part.text }
}
}
@@ -123,6 +127,8 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandler

// Fallback to the original text property if no candidates structure
else if (chunk.text) {
lastTextChunk = chunk.text
hasYieldedContent = true
yield { type: "text", text: chunk.text }
}

@@ -131,10 +137,12 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandler
}
}

if (pendingGroundingMetadata) {
// If we have grounding metadata and content was yielded, append sources to the last text chunk
if (pendingGroundingMetadata && hasYieldedContent) {
const citations = this.extractCitationsOnly(pendingGroundingMetadata)
if (citations) {
yield { type: "text", text: `\n\n${t("common:errors.gemini.sources")} ${citations}` }
const sourcesText = `\n\n${t("common:errors.gemini.sources")} ${citations}`
yield { type: "text", text: sourcesText }
}
}
