Skip to content

Commit e8be1ab

Browse files
committed
fix: update task API configuration when Gemini API key changes
- Update task.apiConfiguration alongside task.api when provider settings change - Ensures Gemini API key is properly refreshed when changed mid-task - Add comprehensive test coverage for API key update scenarios Fixes #7090
1 parent dcbb7a6 commit e8be1ab

File tree

2 files changed

+335
-0
lines changed

2 files changed

+335
-0
lines changed

src/core/webview/ClineProvider.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1144,6 +1144,7 @@ export class ClineProvider
11441144

11451145
if (task) {
11461146
task.api = buildApiHandler(providerSettings)
1147+
task.apiConfiguration = providerSettings
11471148
}
11481149
} else {
11491150
await this.updateGlobalState("listApiConfigMeta", await this.providerSettingsManager.listConfig())
@@ -1205,6 +1206,7 @@ export class ClineProvider
12051206

12061207
if (task) {
12071208
task.api = buildApiHandler(providerSettings)
1209+
task.apiConfiguration = providerSettings
12081210
}
12091211

12101212
await this.postStateToWebview()
Lines changed: 333 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,333 @@
1+
import * as os from "os"
2+
import * as path from "path"
3+
import * as vscode from "vscode"
4+
import { describe, it, expect, beforeEach, vi } from "vitest"
5+
6+
import { ClineProvider } from "../ClineProvider"
7+
import { ContextProxy } from "../../config/ContextProxy"
8+
import { Task } from "../../task/Task"
9+
import { buildApiHandler } from "../../../api"
10+
import type { ProviderSettings } from "@roo-code/types"
11+
12+
// Mock dependencies
13+
vi.mock("../../../api", () => ({
14+
buildApiHandler: vi.fn().mockImplementation((config) => ({
15+
getModel: () => ({ id: config.apiModelId || "gemini-1.5-pro" }),
16+
createMessage: vi.fn(),
17+
})),
18+
}))
19+
20+
vi.mock("../../task/Task", () => ({
21+
Task: vi.fn().mockImplementation(function (this: any, options: any) {
22+
this.api = options.apiConfiguration ? buildApiHandler(options.apiConfiguration) : null
23+
this.apiConfiguration = options.apiConfiguration
24+
this.taskId = "test-task-id"
25+
this.providerRef = { deref: () => options.provider }
26+
}),
27+
}))
28+
29+
vi.mock("@roo-code/telemetry", () => ({
30+
TelemetryService: {
31+
instance: {
32+
setProvider: vi.fn(),
33+
},
34+
hasInstance: () => true,
35+
createInstance: vi.fn(),
36+
},
37+
}))
38+
39+
vi.mock("../../config/ProviderSettingsManager", () => ({
40+
ProviderSettingsManager: vi.fn().mockImplementation(() => ({
41+
saveConfig: vi.fn().mockResolvedValue("config-id"),
42+
listConfig: vi.fn().mockResolvedValue([]),
43+
setModeConfig: vi.fn().mockResolvedValue(undefined),
44+
getModeConfigId: vi.fn().mockResolvedValue(undefined),
45+
activateProfile: vi.fn().mockResolvedValue({
46+
name: "test-profile",
47+
id: "test-id",
48+
apiProvider: "gemini",
49+
}),
50+
})),
51+
}))
52+
53+
vi.mock("../../config/CustomModesManager", () => ({
54+
CustomModesManager: vi.fn().mockImplementation(() => ({
55+
getCustomModes: vi.fn().mockResolvedValue([]),
56+
})),
57+
}))
58+
59+
vi.mock("../../../integrations/workspace/WorkspaceTracker", () => ({
60+
default: vi.fn().mockImplementation(() => ({
61+
dispose: vi.fn(),
62+
})),
63+
}))
64+
65+
vi.mock("../../../services/mcp/McpServerManager", () => ({
66+
McpServerManager: {
67+
getInstance: vi.fn().mockResolvedValue(undefined),
68+
unregisterProvider: vi.fn(),
69+
},
70+
}))
71+
72+
vi.mock("../../../services/marketplace", () => ({
73+
MarketplaceManager: vi.fn().mockImplementation(() => ({
74+
cleanup: vi.fn(),
75+
})),
76+
}))
77+
78+
describe("ClineProvider - Gemini API Key Update", () => {
79+
let provider: ClineProvider
80+
let mockExtensionContext: vscode.ExtensionContext
81+
let mockOutputChannel: any
82+
let mockContextProxy: ContextProxy
83+
84+
beforeEach(() => {
85+
// Setup mock extension context
86+
const storageUri = {
87+
fsPath: path.join(os.tmpdir(), "test-storage"),
88+
}
89+
90+
mockExtensionContext = {
91+
globalState: {
92+
get: vi.fn().mockImplementation(() => undefined),
93+
update: vi.fn().mockResolvedValue(undefined),
94+
keys: vi.fn().mockReturnValue([]),
95+
},
96+
globalStorageUri: storageUri,
97+
workspaceState: {
98+
get: vi.fn().mockImplementation(() => undefined),
99+
update: vi.fn().mockResolvedValue(undefined),
100+
keys: vi.fn().mockReturnValue([]),
101+
},
102+
secrets: {
103+
get: vi.fn().mockResolvedValue(undefined),
104+
store: vi.fn().mockResolvedValue(undefined),
105+
delete: vi.fn().mockResolvedValue(undefined),
106+
},
107+
extensionUri: {
108+
fsPath: "/mock/extension/path",
109+
},
110+
extension: {
111+
packageJSON: {
112+
version: "1.0.0",
113+
},
114+
},
115+
} as unknown as vscode.ExtensionContext
116+
117+
// Setup mock output channel
118+
mockOutputChannel = {
119+
appendLine: vi.fn(),
120+
append: vi.fn(),
121+
clear: vi.fn(),
122+
show: vi.fn(),
123+
hide: vi.fn(),
124+
dispose: vi.fn(),
125+
}
126+
127+
// Setup mock context proxy
128+
mockContextProxy = new ContextProxy(mockExtensionContext)
129+
mockContextProxy.setProviderSettings = vi.fn().mockResolvedValue(undefined)
130+
mockContextProxy.setValue = vi.fn().mockResolvedValue(undefined)
131+
mockContextProxy.getValues = vi.fn().mockReturnValue({
132+
mode: "code",
133+
listApiConfigMeta: [],
134+
})
135+
mockContextProxy.getProviderSettings = vi.fn().mockReturnValue({
136+
apiProvider: "gemini",
137+
geminiApiKey: "old-key",
138+
apiModelId: "gemini-1.5-pro",
139+
})
140+
141+
// Create provider instance
142+
provider = new ClineProvider(mockExtensionContext, mockOutputChannel, "sidebar", mockContextProxy)
143+
144+
// Mock provider methods
145+
provider.postMessageToWebview = vi.fn().mockResolvedValue(undefined)
146+
provider.postStateToWebview = vi.fn().mockResolvedValue(undefined)
147+
provider.getState = vi.fn().mockResolvedValue({
148+
mode: "code",
149+
apiConfiguration: {
150+
apiProvider: "gemini",
151+
geminiApiKey: "old-key",
152+
apiModelId: "gemini-1.5-pro",
153+
},
154+
})
155+
// Use public method instead of private one
156+
provider.setValue = vi.fn().mockResolvedValue(undefined)
157+
})
158+
159+
it("should update task API handler when Gemini API key is changed", async () => {
160+
// Create a mock task
161+
const mockTask = new Task({
162+
provider,
163+
apiConfiguration: {
164+
apiProvider: "gemini",
165+
geminiApiKey: "old-key",
166+
apiModelId: "gemini-1.5-pro",
167+
},
168+
task: "test task",
169+
}) as any
170+
171+
// Add the task to the provider's stack
172+
provider["clineStack"] = [mockTask]
173+
174+
// Prepare new provider settings with updated API key
175+
const newProviderSettings: ProviderSettings = {
176+
apiProvider: "gemini",
177+
geminiApiKey: "new-key",
178+
apiModelId: "gemini-1.5-pro",
179+
}
180+
181+
// Call upsertProviderProfile with the new settings
182+
await provider.upsertProviderProfile("test-profile", newProviderSettings, true)
183+
184+
// Verify that buildApiHandler was called with the new settings
185+
expect(buildApiHandler).toHaveBeenCalledWith(newProviderSettings)
186+
187+
// Verify that the task's API handler was updated
188+
expect(mockTask.api).toBeDefined()
189+
expect(mockTask.apiConfiguration).toEqual(newProviderSettings)
190+
191+
// Verify that context proxy was updated with new settings
192+
expect(mockContextProxy.setProviderSettings).toHaveBeenCalledWith(newProviderSettings)
193+
})
194+
195+
it("should update task API configuration when activating a different profile", async () => {
196+
// Create a mock task with initial configuration
197+
const mockTask = new Task({
198+
provider,
199+
apiConfiguration: {
200+
apiProvider: "gemini",
201+
geminiApiKey: "initial-key",
202+
apiModelId: "gemini-1.5-pro",
203+
},
204+
task: "test task",
205+
}) as any
206+
207+
// Add the task to the provider's stack
208+
provider["clineStack"] = [mockTask]
209+
210+
// Mock the provider settings manager's activateProfile method
211+
const newProviderSettings = {
212+
apiProvider: "gemini" as const,
213+
geminiApiKey: "activated-key",
214+
apiModelId: "gemini-1.5-flash",
215+
}
216+
217+
provider["providerSettingsManager"].activateProfile = vi.fn().mockResolvedValue({
218+
name: "activated-profile",
219+
id: "activated-id",
220+
...newProviderSettings,
221+
})
222+
223+
// Call activateProviderProfile
224+
await provider.activateProviderProfile({ name: "activated-profile" })
225+
226+
// Verify that buildApiHandler was called with the activated settings
227+
expect(buildApiHandler).toHaveBeenCalledWith(newProviderSettings)
228+
229+
// Verify that the task's API handler and configuration were updated
230+
expect(mockTask.api).toBeDefined()
231+
expect(mockTask.apiConfiguration).toEqual(newProviderSettings)
232+
233+
// Verify that context proxy was updated
234+
expect(mockContextProxy.setProviderSettings).toHaveBeenCalledWith(newProviderSettings)
235+
})
236+
237+
it("should not update API handler when activate is false", async () => {
238+
// Create a mock task
239+
const mockTask = new Task({
240+
provider,
241+
apiConfiguration: {
242+
apiProvider: "gemini",
243+
geminiApiKey: "old-key",
244+
apiModelId: "gemini-1.5-pro",
245+
},
246+
task: "test task",
247+
}) as any
248+
249+
const originalApi = mockTask.api
250+
const originalConfig = mockTask.apiConfiguration
251+
252+
// Add the task to the provider's stack
253+
provider["clineStack"] = [mockTask]
254+
255+
// Prepare new provider settings
256+
const newProviderSettings: ProviderSettings = {
257+
apiProvider: "gemini",
258+
geminiApiKey: "new-key",
259+
apiModelId: "gemini-1.5-pro",
260+
}
261+
262+
// Call upsertProviderProfile with activate = false
263+
await provider.upsertProviderProfile("test-profile", newProviderSettings, false)
264+
265+
// Verify that the task's API handler was NOT updated
266+
expect(mockTask.api).toBe(originalApi)
267+
expect(mockTask.apiConfiguration).toBe(originalConfig)
268+
269+
// Verify that context proxy was NOT updated
270+
expect(mockContextProxy.setProviderSettings).not.toHaveBeenCalled()
271+
})
272+
273+
it("should handle case when no task is active", async () => {
274+
// Ensure no task is in the stack
275+
provider["clineStack"] = []
276+
277+
// Prepare new provider settings
278+
const newProviderSettings: ProviderSettings = {
279+
apiProvider: "gemini",
280+
geminiApiKey: "new-key",
281+
apiModelId: "gemini-1.5-pro",
282+
}
283+
284+
// Call upsertProviderProfile - should not throw
285+
await expect(provider.upsertProviderProfile("test-profile", newProviderSettings, true)).resolves.toBeDefined()
286+
287+
// Verify that buildApiHandler was still called (for potential future tasks)
288+
expect(buildApiHandler).toHaveBeenCalledWith(newProviderSettings)
289+
290+
// Verify that context proxy was updated
291+
expect(mockContextProxy.setProviderSettings).toHaveBeenCalledWith(newProviderSettings)
292+
})
293+
294+
it("should preserve other provider settings when updating Gemini API key", async () => {
295+
// Create a mock task with additional settings
296+
const initialConfig = {
297+
apiProvider: "gemini" as const,
298+
geminiApiKey: "old-key",
299+
apiModelId: "gemini-1.5-pro",
300+
geminiBaseUrl: "https://custom.gemini.api",
301+
temperature: 0.7,
302+
maxTokens: 4096,
303+
}
304+
305+
const mockTask = new Task({
306+
provider,
307+
apiConfiguration: initialConfig,
308+
task: "test task",
309+
}) as any
310+
311+
// Add the task to the provider's stack
312+
provider["clineStack"] = [mockTask]
313+
314+
// Update only the API key, preserving other settings
315+
const newProviderSettings: ProviderSettings = {
316+
...initialConfig,
317+
geminiApiKey: "new-key",
318+
}
319+
320+
// Call upsertProviderProfile
321+
await provider.upsertProviderProfile("test-profile", newProviderSettings, true)
322+
323+
// Verify that all settings are preserved except the API key
324+
expect(mockTask.apiConfiguration).toEqual({
325+
apiProvider: "gemini",
326+
geminiApiKey: "new-key",
327+
apiModelId: "gemini-1.5-pro",
328+
geminiBaseUrl: "https://custom.gemini.api",
329+
temperature: 0.7,
330+
maxTokens: 4096,
331+
})
332+
})
333+
})

0 commit comments

Comments
 (0)