Skip to content

Commit a9174a8

Browse files
committed
fix: integrate Gemini grounding sources into assistant message
- Modified streaming logic to append grounding sources to the last text chunk instead of yielding as separate message - Added tracking of content yielding to ensure sources only appear when content exists - Added comprehensive test coverage for grounding functionality including edge cases - Fixes issue where grounding sources appeared as separate message bubbles Fixes #6372
1 parent b117c0f commit a9174a8

File tree

2 files changed

+200
-2
lines changed

2 files changed

+200
-2
lines changed

src/api/providers/__tests__/gemini.spec.ts

Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,15 @@ import { type ModelInfo, geminiDefaultModelId } from "@roo-code/types"
77
import { t } from "i18next"
88
import { GeminiHandler } from "../gemini"
99

10+
// Mock the translation function
11+
vitest.mock("i18next", () => ({
12+
t: vitest.fn((key: string) => {
13+
if (key === "common:errors.gemini.sources") return "Sources:"
14+
if (key === "common:errors.gemini.generate_complete_prompt") return "Gemini completion error: {{error}}"
15+
return key
16+
}),
17+
}))
18+
1019
const GEMINI_20_FLASH_THINKING_NAME = "gemini-2.0-flash-thinking-exp-1219"
1120

1221
describe("GeminiHandler", () => {
@@ -102,6 +111,155 @@ describe("GeminiHandler", () => {
102111
}
103112
}).rejects.toThrow()
104113
})
114+
115+
it("should integrate grounding sources into the assistant message", async () => {
116+
// Setup the mock implementation to return an async generator with grounding metadata
117+
;(handler["client"].models.generateContentStream as any).mockResolvedValue({
118+
[Symbol.asyncIterator]: async function* () {
119+
yield {
120+
candidates: [
121+
{
122+
content: {
123+
parts: [{ text: "Here is some information about AI." }],
124+
},
125+
groundingMetadata: {
126+
groundingChunks: [
127+
{ web: { uri: "https://example.com/ai-info" } },
128+
{ web: { uri: "https://example.com/ai-research" } },
129+
],
130+
},
131+
},
132+
],
133+
}
134+
yield { usageMetadata: { promptTokenCount: 10, candidatesTokenCount: 15 } }
135+
},
136+
})
137+
138+
const stream = handler.createMessage(systemPrompt, mockMessages)
139+
const chunks = []
140+
141+
for await (const chunk of stream) {
142+
chunks.push(chunk)
143+
}
144+
145+
// Should have 3 chunks: main content, sources, and usage info
146+
expect(chunks.length).toBe(3)
147+
expect(chunks[0]).toEqual({ type: "text", text: "Here is some information about AI." })
148+
expect(chunks[1]).toEqual({
149+
type: "text",
150+
text: "\n\nSources: [1](https://example.com/ai-info), [2](https://example.com/ai-research)",
151+
})
152+
expect(chunks[2]).toEqual({ type: "usage", inputTokens: 10, outputTokens: 15 })
153+
})
154+
155+
it("should handle grounding metadata without web sources", async () => {
156+
// Setup the mock implementation with grounding metadata but no web sources
157+
;(handler["client"].models.generateContentStream as any).mockResolvedValue({
158+
[Symbol.asyncIterator]: async function* () {
159+
yield {
160+
candidates: [
161+
{
162+
content: {
163+
parts: [{ text: "Response without web sources." }],
164+
},
165+
groundingMetadata: {
166+
groundingChunks: [{ someOtherSource: { data: "non-web-source" } }],
167+
},
168+
},
169+
],
170+
}
171+
yield { usageMetadata: { promptTokenCount: 5, candidatesTokenCount: 8 } }
172+
},
173+
})
174+
175+
const stream = handler.createMessage(systemPrompt, mockMessages)
176+
const chunks = []
177+
178+
for await (const chunk of stream) {
179+
chunks.push(chunk)
180+
}
181+
182+
// Should have 2 chunks: main content and usage info (no sources since no web URIs)
183+
expect(chunks.length).toBe(2)
184+
expect(chunks[0]).toEqual({ type: "text", text: "Response without web sources." })
185+
expect(chunks[1]).toEqual({ type: "usage", inputTokens: 5, outputTokens: 8 })
186+
})
187+
188+
it("should not yield sources when no content is generated", async () => {
189+
// Setup the mock implementation with grounding metadata but no content
190+
;(handler["client"].models.generateContentStream as any).mockResolvedValue({
191+
[Symbol.asyncIterator]: async function* () {
192+
yield {
193+
candidates: [
194+
{
195+
groundingMetadata: {
196+
groundingChunks: [{ web: { uri: "https://example.com/source" } }],
197+
},
198+
},
199+
],
200+
}
201+
yield { usageMetadata: { promptTokenCount: 5, candidatesTokenCount: 0 } }
202+
},
203+
})
204+
205+
const stream = handler.createMessage(systemPrompt, mockMessages)
206+
const chunks = []
207+
208+
for await (const chunk of stream) {
209+
chunks.push(chunk)
210+
}
211+
212+
// Should only have usage info, no sources since no content was yielded
213+
expect(chunks.length).toBe(1)
214+
expect(chunks[0]).toEqual({ type: "usage", inputTokens: 5, outputTokens: 0 })
215+
})
216+
217+
it("should handle multiple text chunks with grounding sources", async () => {
218+
// Setup the mock implementation with multiple text chunks and grounding
219+
;(handler["client"].models.generateContentStream as any).mockResolvedValue({
220+
[Symbol.asyncIterator]: async function* () {
221+
yield {
222+
candidates: [
223+
{
224+
content: {
225+
parts: [{ text: "First part of response" }],
226+
},
227+
},
228+
],
229+
}
230+
yield {
231+
candidates: [
232+
{
233+
content: {
234+
parts: [{ text: " and second part." }],
235+
},
236+
groundingMetadata: {
237+
groundingChunks: [{ web: { uri: "https://example.com/source1" } }],
238+
},
239+
},
240+
],
241+
}
242+
yield { usageMetadata: { promptTokenCount: 12, candidatesTokenCount: 18 } }
243+
},
244+
})
245+
246+
const stream = handler.createMessage(systemPrompt, mockMessages)
247+
const chunks = []
248+
249+
for await (const chunk of stream) {
250+
chunks.push(chunk)
251+
}
252+
253+
// Should have 4 chunks: two text chunks, sources, and usage info
254+
expect(chunks.length).toBe(4)
255+
expect(chunks[0]).toEqual({ type: "text", text: "First part of response" })
256+
expect(chunks[1]).toEqual({ type: "text", text: " and second part." })
257+
expect(chunks[2]).toEqual({
258+
type: "text",
259+
text: "\n\nSources: [1](https://example.com/source1)",
260+
})
261+
expect(chunks[3]).toEqual({ type: "usage", inputTokens: 12, outputTokens: 18 })
262+
})
105263
})
106264

107265
describe("completePrompt", () => {
@@ -143,6 +301,38 @@ describe("GeminiHandler", () => {
143301
const result = await handler.completePrompt("Test prompt")
144302
expect(result).toBe("")
145303
})
304+
305+
it("should integrate grounding sources in completePrompt", async () => {
306+
// Mock the response with grounding metadata
307+
;(handler["client"].models.generateContent as any).mockResolvedValue({
308+
text: "AI is a fascinating field of study.",
309+
candidates: [
310+
{
311+
groundingMetadata: {
312+
groundingChunks: [
313+
{ web: { uri: "https://example.com/ai-study" } },
314+
{ web: { uri: "https://example.com/ai-research" } },
315+
],
316+
},
317+
},
318+
],
319+
})
320+
321+
const result = await handler.completePrompt("Tell me about AI")
322+
expect(result).toBe(
323+
"AI is a fascinating field of study.\n\nSources: [1](https://example.com/ai-study), [2](https://example.com/ai-research)",
324+
)
325+
})
326+
327+
it("should handle completePrompt without grounding sources", async () => {
328+
// Mock the response without grounding metadata
329+
;(handler["client"].models.generateContent as any).mockResolvedValue({
330+
text: "Simple response without sources.",
331+
})
332+
333+
const result = await handler.completePrompt("Simple question")
334+
expect(result).toBe("Simple response without sources.")
335+
})
146336
})
147337

148338
describe("getModel", () => {

src/api/providers/gemini.ts

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,8 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl
9494

9595
let lastUsageMetadata: GenerateContentResponseUsageMetadata | undefined
9696
let pendingGroundingMetadata: GroundingMetadata | undefined
97+
let lastTextChunk: string | null = null
98+
let hasYieldedContent = false
9799

98100
for await (const chunk of result) {
99101
// Process candidates and their parts to separate thoughts from content
@@ -114,6 +116,8 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl
114116
} else {
115117
// This is regular content
116118
if (part.text) {
119+
lastTextChunk = part.text
120+
hasYieldedContent = true
117121
yield { type: "text", text: part.text }
118122
}
119123
}
@@ -123,6 +127,8 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl
123127

124128
// Fallback to the original text property if no candidates structure
125129
else if (chunk.text) {
130+
lastTextChunk = chunk.text
131+
hasYieldedContent = true
126132
yield { type: "text", text: chunk.text }
127133
}
128134

@@ -131,10 +137,12 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl
131137
}
132138
}
133139

134-
if (pendingGroundingMetadata) {
140+
// If we have grounding metadata and content was yielded, append sources to the last text chunk
141+
if (pendingGroundingMetadata && hasYieldedContent) {
135142
const citations = this.extractCitationsOnly(pendingGroundingMetadata)
136143
if (citations) {
137-
yield { type: "text", text: `\n\n${t("common:errors.gemini.sources")} ${citations}` }
144+
const sourcesText = `\n\n${t("common:errors.gemini.sources")} ${citations}`
145+
yield { type: "text", text: sourcesText }
138146
}
139147
}
140148

0 commit comments

Comments
 (0)