Skip to content

Commit edd6f3d

Browse files
committed
fix: update gemini grounding test to match new architecture
- Fixed failing test to expect grounding chunk instead of text with citations - Added inline comments to regex patterns for source stripping in Task.ts - Added test coverage for grounding source handling to prevent regression
1 parent a250a7e commit edd6f3d

File tree

3 files changed

+243
-10
lines changed

3 files changed

+243
-10
lines changed

src/api/providers/__tests__/gemini-handler.spec.ts

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ describe("GeminiHandler backend support", () => {
8585
groundingMetadata: {
8686
groundingChunks: [
8787
{ web: null }, // Missing URI
88-
{ web: { uri: "https://example.com" } }, // Valid
88+
{ web: { uri: "https://example.com", title: "Example Site" } }, // Valid
8989
{}, // Missing web property entirely
9090
],
9191
},
@@ -105,13 +105,20 @@ describe("GeminiHandler backend support", () => {
105105
messages.push(chunk)
106106
}
107107

108-
// Should only include valid citations
109-
const sourceMessage = messages.find((m) => m.type === "text" && m.text?.includes("[2]"))
110-
expect(sourceMessage).toBeDefined()
111-
if (sourceMessage && "text" in sourceMessage) {
112-
expect(sourceMessage.text).toContain("https://example.com")
113-
expect(sourceMessage.text).not.toContain("[1]")
114-
expect(sourceMessage.text).not.toContain("[3]")
108+
// Should have the text response
109+
const textMessage = messages.find((m) => m.type === "text")
110+
expect(textMessage).toBeDefined()
111+
if (textMessage && "text" in textMessage) {
112+
expect(textMessage.text).toBe("test response")
113+
}
114+
115+
// Should have grounding chunk with only valid sources
116+
const groundingMessage = messages.find((m) => m.type === "grounding")
117+
expect(groundingMessage).toBeDefined()
118+
if (groundingMessage && "sources" in groundingMessage) {
119+
expect(groundingMessage.sources).toHaveLength(1)
120+
expect(groundingMessage.sources[0].url).toBe("https://example.com")
121+
expect(groundingMessage.sources[0].title).toBe("Example Site")
115122
}
116123
})
117124

src/core/task/Task.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2110,8 +2110,8 @@ export class Task extends EventEmitter<TaskEvents> implements TaskLike {
21102110
if (pendingGroundingSources.length > 0) {
21112111
// Remove any grounding source references that might have been integrated into the message
21122112
cleanAssistantMessage = assistantMessage
2113-
.replace(/\[\d+\]\s+[^:\n]+:\s+https?:\/\/[^\s\n]+/g, "")
2114-
.replace(/Sources?:\s*[\s\S]*?(?=\n\n|\n$|$)/g, "")
2113+
.replace(/\[\d+\]\s+[^:\n]+:\s+https?:\/\/[^\s\n]+/g, "") // e.g., "[1] Example Source: https://example.com"
2114+
.replace(/Sources?:\s*[\s\S]*?(?=\n\n|\n$|$)/g, "") // e.g., "Sources: [1](url1), [2](url2)"
21152115
.trim()
21162116
}
21172117

Lines changed: 226 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
import { describe, it, expect, vi, beforeEach, beforeAll } from "vitest"
2+
import type { ClineProvider } from "../../webview/ClineProvider"
3+
import type { ProviderSettings } from "@roo-code/types"
4+
5+
// Mock vscode module before importing Task
6+
vi.mock("vscode", () => ({
7+
workspace: {
8+
createFileSystemWatcher: vi.fn(() => ({
9+
onDidCreate: vi.fn(),
10+
onDidChange: vi.fn(),
11+
onDidDelete: vi.fn(),
12+
dispose: vi.fn(),
13+
})),
14+
getConfiguration: vi.fn(() => ({
15+
get: vi.fn(() => true),
16+
})),
17+
openTextDocument: vi.fn(),
18+
applyEdit: vi.fn(),
19+
},
20+
RelativePattern: vi.fn((base, pattern) => ({ base, pattern })),
21+
window: {
22+
createOutputChannel: vi.fn(() => ({
23+
appendLine: vi.fn(),
24+
dispose: vi.fn(),
25+
})),
26+
createTextEditorDecorationType: vi.fn(() => ({
27+
dispose: vi.fn(),
28+
})),
29+
showTextDocument: vi.fn(),
30+
activeTextEditor: undefined,
31+
},
32+
Uri: {
33+
file: vi.fn((path) => ({ fsPath: path })),
34+
parse: vi.fn((str) => ({ toString: () => str })),
35+
},
36+
Range: vi.fn(),
37+
Position: vi.fn(),
38+
WorkspaceEdit: vi.fn(() => ({
39+
replace: vi.fn(),
40+
insert: vi.fn(),
41+
delete: vi.fn(),
42+
})),
43+
ViewColumn: {
44+
One: 1,
45+
Two: 2,
46+
Three: 3,
47+
},
48+
}))
49+
50+
// Mock other dependencies
51+
vi.mock("../../services/mcp/McpServerManager", () => ({
52+
McpServerManager: {
53+
getInstance: vi.fn().mockResolvedValue(null),
54+
},
55+
}))
56+
57+
vi.mock("../../integrations/terminal/TerminalRegistry", () => ({
58+
TerminalRegistry: {
59+
releaseTerminalsForTask: vi.fn(),
60+
},
61+
}))
62+
63+
vi.mock("@roo-code/telemetry", () => ({
64+
TelemetryService: {
65+
instance: {
66+
captureTaskCreated: vi.fn(),
67+
captureTaskRestarted: vi.fn(),
68+
captureConversationMessage: vi.fn(),
69+
captureLlmCompletion: vi.fn(),
70+
captureConsecutiveMistakeError: vi.fn(),
71+
},
72+
},
73+
}))
74+
75+
describe("Task grounding sources handling", () => {
76+
let mockProvider: Partial<ClineProvider>
77+
let mockApiConfiguration: ProviderSettings
78+
let Task: any
79+
80+
beforeAll(async () => {
81+
// Import Task after mocks are set up
82+
const taskModule = await import("../Task")
83+
Task = taskModule.Task
84+
})
85+
86+
beforeEach(() => {
87+
// Mock provider with necessary methods
88+
mockProvider = {
89+
postStateToWebview: vi.fn().mockResolvedValue(undefined),
90+
getState: vi.fn().mockResolvedValue({
91+
mode: "code",
92+
experiments: {},
93+
}),
94+
context: {
95+
globalStorageUri: { fsPath: "/test/storage" },
96+
extensionPath: "/test/extension",
97+
} as any,
98+
log: vi.fn(),
99+
updateTaskHistory: vi.fn().mockResolvedValue(undefined),
100+
postMessageToWebview: vi.fn().mockResolvedValue(undefined),
101+
}
102+
103+
mockApiConfiguration = {
104+
apiProvider: "gemini",
105+
geminiApiKey: "test-key",
106+
enableGrounding: true,
107+
} as ProviderSettings
108+
})
109+
110+
it("should strip grounding sources from assistant message before persisting to API history", async () => {
111+
// Create a task instance
112+
const task = new Task({
113+
provider: mockProvider as ClineProvider,
114+
apiConfiguration: mockApiConfiguration,
115+
task: "Test task",
116+
startTask: false,
117+
})
118+
119+
// Mock the API conversation history
120+
task.apiConversationHistory = []
121+
122+
// Simulate an assistant message with grounding sources
123+
const assistantMessageWithSources = `
124+
This is the main response content.
125+
126+
[1] Example Source: https://example.com
127+
[2] Another Source: https://another.com
128+
129+
Sources: [1](https://example.com), [2](https://another.com)
130+
`.trim()
131+
132+
// Mock grounding sources
133+
const mockGroundingSources = [
134+
{ title: "Example Source", url: "https://example.com" },
135+
{ title: "Another Source", url: "https://another.com" },
136+
]
137+
138+
// Spy on addToApiConversationHistory to check what gets persisted
139+
const addToApiHistorySpy = vi.spyOn(task as any, "addToApiConversationHistory")
140+
141+
// Simulate the logic from Task.ts that strips grounding sources
142+
let cleanAssistantMessage = assistantMessageWithSources
143+
if (mockGroundingSources.length > 0) {
144+
cleanAssistantMessage = assistantMessageWithSources
145+
.replace(/\[\d+\]\s+[^:\n]+:\s+https?:\/\/[^\s\n]+/g, "") // e.g., "[1] Example Source: https://example.com"
146+
.replace(/Sources?:\s*[\s\S]*?(?=\n\n|\n$|$)/g, "") // e.g., "Sources: [1](url1), [2](url2)"
147+
.trim()
148+
}
149+
150+
// Add the cleaned message to API history
151+
await (task as any).addToApiConversationHistory({
152+
role: "assistant",
153+
content: [{ type: "text", text: cleanAssistantMessage }],
154+
})
155+
156+
// Verify that the cleaned message was added without grounding sources
157+
expect(addToApiHistorySpy).toHaveBeenCalledWith({
158+
role: "assistant",
159+
content: [{ type: "text", text: "This is the main response content." }],
160+
})
161+
162+
// Verify the API conversation history contains the cleaned message
163+
expect(task.apiConversationHistory).toHaveLength(1)
164+
expect(task.apiConversationHistory[0].content).toEqual([
165+
{ type: "text", text: "This is the main response content." },
166+
])
167+
})
168+
169+
it("should not modify assistant message when no grounding sources are present", async () => {
170+
const task = new Task({
171+
provider: mockProvider as ClineProvider,
172+
apiConfiguration: mockApiConfiguration,
173+
task: "Test task",
174+
startTask: false,
175+
})
176+
177+
task.apiConversationHistory = []
178+
179+
const assistantMessage = "This is a regular response without any sources."
180+
const mockGroundingSources: any[] = [] // No grounding sources
181+
182+
// Apply the same logic
183+
let cleanAssistantMessage = assistantMessage
184+
if (mockGroundingSources.length > 0) {
185+
cleanAssistantMessage = assistantMessage
186+
.replace(/\[\d+\]\s+[^:\n]+:\s+https?:\/\/[^\s\n]+/g, "")
187+
.replace(/Sources?:\s*[\s\S]*?(?=\n\n|\n$|$)/g, "")
188+
.trim()
189+
}
190+
191+
await (task as any).addToApiConversationHistory({
192+
role: "assistant",
193+
content: [{ type: "text", text: cleanAssistantMessage }],
194+
})
195+
196+
// Message should remain unchanged
197+
expect(task.apiConversationHistory[0].content).toEqual([
198+
{ type: "text", text: "This is a regular response without any sources." },
199+
])
200+
})
201+
202+
it("should handle various grounding source formats", () => {
203+
const testCases = [
204+
{
205+
input: "[1] Source Title: https://example.com\n[2] Another: https://test.com\nMain content here",
206+
expected: "Main content here",
207+
},
208+
{
209+
input: "Content first\n\nSources: [1](https://example.com), [2](https://test.com)",
210+
expected: "Content first",
211+
},
212+
{
213+
input: "Mixed content\n[1] Inline Source: https://inline.com\nMore content\nSource: [1](https://inline.com)",
214+
expected: "Mixed content\n\nMore content",
215+
},
216+
]
217+
218+
testCases.forEach(({ input, expected }) => {
219+
const cleaned = input
220+
.replace(/\[\d+\]\s+[^:\n]+:\s+https?:\/\/[^\s\n]+/g, "")
221+
.replace(/Sources?:\s*[\s\S]*?(?=\n\n|\n$|$)/g, "")
222+
.trim()
223+
expect(cleaned).toBe(expected)
224+
})
225+
})
226+
})

0 commit comments

Comments
 (0)