Skip to content

Commit c206da4

Browse files
HahaBilldaniel-lxsmrubens
authored
fix: Tackling Race/State condition issue by Changing the Code Design for Gemini Grounding Sources (#7434)
Co-authored-by: daniel-lxs <[email protected]> Co-authored-by: Matt Rubens <[email protected]>
1 parent 687b379 commit c206da4

File tree

5 files changed

+301
-24
lines changed

5 files changed

+301
-24
lines changed

src/api/providers/__tests__/gemini-handler.spec.ts

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ describe("GeminiHandler backend support", () => {
8585
groundingMetadata: {
8686
groundingChunks: [
8787
{ web: null }, // Missing URI
88-
{ web: { uri: "https://example.com" } }, // Valid
88+
{ web: { uri: "https://example.com", title: "Example Site" } }, // Valid
8989
{}, // Missing web property entirely
9090
],
9191
},
@@ -105,13 +105,20 @@ describe("GeminiHandler backend support", () => {
105105
messages.push(chunk)
106106
}
107107

108-
// Should only include valid citations
109-
const sourceMessage = messages.find((m) => m.type === "text" && m.text?.includes("[2]"))
110-
expect(sourceMessage).toBeDefined()
111-
if (sourceMessage && "text" in sourceMessage) {
112-
expect(sourceMessage.text).toContain("https://example.com")
113-
expect(sourceMessage.text).not.toContain("[1]")
114-
expect(sourceMessage.text).not.toContain("[3]")
108+
// Should have the text response
109+
const textMessage = messages.find((m) => m.type === "text")
110+
expect(textMessage).toBeDefined()
111+
if (textMessage && "text" in textMessage) {
112+
expect(textMessage.text).toBe("test response")
113+
}
114+
115+
// Should have grounding chunk with only valid sources
116+
const groundingMessage = messages.find((m) => m.type === "grounding")
117+
expect(groundingMessage).toBeDefined()
118+
if (groundingMessage && "sources" in groundingMessage) {
119+
expect(groundingMessage.sources).toHaveLength(1)
120+
expect(groundingMessage.sources[0].url).toBe("https://example.com")
121+
expect(groundingMessage.sources[0].title).toBe("Example Site")
115122
}
116123
})
117124

src/api/providers/gemini.ts

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ import { safeJsonParse } from "../../shared/safeJsonParse"
1515

1616
import { convertAnthropicContentToGemini, convertAnthropicMessageToGemini } from "../transform/gemini-format"
1717
import { t } from "i18next"
18-
import type { ApiStream } from "../transform/stream"
18+
import type { ApiStream, GroundingSource } from "../transform/stream"
1919
import { getModelParams } from "../transform/model-params"
2020

2121
import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
@@ -132,9 +132,9 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl
132132
}
133133

134134
if (pendingGroundingMetadata) {
135-
const citations = this.extractCitationsOnly(pendingGroundingMetadata)
136-
if (citations) {
137-
yield { type: "text", text: `\n\n${t("common:errors.gemini.sources")} ${citations}` }
135+
const sources = this.extractGroundingSources(pendingGroundingMetadata)
136+
if (sources.length > 0) {
137+
yield { type: "grounding", sources }
138138
}
139139
}
140140

@@ -175,28 +175,38 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl
175175
return { id: id.endsWith(":thinking") ? id.replace(":thinking", "") : id, info, ...params }
176176
}
177177

178-
private extractCitationsOnly(groundingMetadata?: GroundingMetadata): string | null {
178+
private extractGroundingSources(groundingMetadata?: GroundingMetadata): GroundingSource[] {
179179
const chunks = groundingMetadata?.groundingChunks
180180

181181
if (!chunks) {
182-
return null
182+
return []
183183
}
184184

185-
const citationLinks = chunks
186-
.map((chunk, i) => {
185+
return chunks
186+
.map((chunk): GroundingSource | null => {
187187
const uri = chunk.web?.uri
188+
const title = chunk.web?.title || uri || "Unknown Source"
189+
188190
if (uri) {
189-
return `[${i + 1}](${uri})`
191+
return {
192+
title,
193+
url: uri,
194+
}
190195
}
191196
return null
192197
})
193-
.filter((link): link is string => link !== null)
198+
.filter((source): source is GroundingSource => source !== null)
199+
}
200+
201+
private extractCitationsOnly(groundingMetadata?: GroundingMetadata): string | null {
202+
const sources = this.extractGroundingSources(groundingMetadata)
194203

195-
if (citationLinks.length > 0) {
196-
return citationLinks.join(", ")
204+
if (sources.length === 0) {
205+
return null
197206
}
198207

199-
return null
208+
const citationLinks = sources.map((source, i) => `[${i + 1}](${source.url})`)
209+
return citationLinks.join(", ")
200210
}
201211

202212
async completePrompt(prompt: string): Promise<string> {

src/api/transform/stream.ts

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
export type ApiStream = AsyncGenerator<ApiStreamChunk>
22

3-
export type ApiStreamChunk = ApiStreamTextChunk | ApiStreamUsageChunk | ApiStreamReasoningChunk | ApiStreamError
3+
export type ApiStreamChunk =
4+
| ApiStreamTextChunk
5+
| ApiStreamUsageChunk
6+
| ApiStreamReasoningChunk
7+
| ApiStreamGroundingChunk
8+
| ApiStreamError
49

510
export interface ApiStreamError {
611
type: "error"
@@ -27,3 +32,14 @@ export interface ApiStreamUsageChunk {
2732
reasoningTokens?: number
2833
totalCost?: number
2934
}
35+
36+
export interface ApiStreamGroundingChunk {
37+
type: "grounding"
38+
sources: GroundingSource[]
39+
}
40+
41+
export interface GroundingSource {
42+
title: string
43+
url: string
44+
snippet?: string
45+
}

src/core/task/Task.ts

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ import { CloudService, BridgeOrchestrator } from "@roo-code/cloud"
4141

4242
// api
4343
import { ApiHandler, ApiHandlerCreateMessageMetadata, buildApiHandler } from "../../api"
44-
import { ApiStream } from "../../api/transform/stream"
44+
import { ApiStream, GroundingSource } from "../../api/transform/stream"
4545
import { maybeRemoveImageBlocks } from "../../api/transform/image-cleaning"
4646

4747
// shared
@@ -1897,7 +1897,7 @@ export class Task extends EventEmitter<TaskEvents> implements TaskLike {
18971897
this.didFinishAbortingStream = true
18981898
}
18991899

1900-
// Reset streaming state.
1900+
// Reset streaming state for each new API request
19011901
this.currentStreamingContentIndex = 0
19021902
this.currentStreamingDidCheckpoint = false
19031903
this.assistantMessageContent = []
@@ -1918,6 +1918,7 @@ export class Task extends EventEmitter<TaskEvents> implements TaskLike {
19181918
const stream = this.attemptApiRequest()
19191919
let assistantMessage = ""
19201920
let reasoningMessage = ""
1921+
let pendingGroundingSources: GroundingSource[] = []
19211922
this.isStreaming = true
19221923

19231924
try {
@@ -1944,6 +1945,13 @@ export class Task extends EventEmitter<TaskEvents> implements TaskLike {
19441945
cacheReadTokens += chunk.cacheReadTokens ?? 0
19451946
totalCost = chunk.totalCost
19461947
break
1948+
case "grounding":
1949+
// Handle grounding sources separately from regular content
1950+
// to prevent state persistence issues - store them separately
1951+
if (chunk.sources && chunk.sources.length > 0) {
1952+
pendingGroundingSources.push(...chunk.sources)
1953+
}
1954+
break
19471955
case "text": {
19481956
assistantMessage += chunk.text
19491957

@@ -2237,6 +2245,16 @@ export class Task extends EventEmitter<TaskEvents> implements TaskLike {
22372245
let didEndLoop = false
22382246

22392247
if (assistantMessage.length > 0) {
2248+
// Display grounding sources to the user if they exist
2249+
if (pendingGroundingSources.length > 0) {
2250+
const citationLinks = pendingGroundingSources.map((source, i) => `[${i + 1}](${source.url})`)
2251+
const sourcesText = `${t("common:gemini.sources")} ${citationLinks.join(", ")}`
2252+
2253+
await this.say("text", sourcesText, undefined, false, undefined, undefined, {
2254+
isNonInteractive: true,
2255+
})
2256+
}
2257+
22402258
await this.addToApiConversationHistory({
22412259
role: "assistant",
22422260
content: [{ type: "text", text: assistantMessage }],

0 commit comments

Comments
 (0)