diff --git a/src/api/providers/__tests__/gemini.test.ts b/src/api/providers/__tests__/gemini.test.ts index 1e536eaecfc..d12c261b790 100644 --- a/src/api/providers/__tests__/gemini.test.ts +++ b/src/api/providers/__tests__/gemini.test.ts @@ -101,10 +101,15 @@ describe("GeminiHandler", () => { }) // Verify the model configuration - expect(mockGetGenerativeModel).toHaveBeenCalledWith({ - model: "gemini-2.0-flash-thinking-exp-1219", - systemInstruction: systemPrompt, - }) + expect(mockGetGenerativeModel).toHaveBeenCalledWith( + { + model: "gemini-2.0-flash-thinking-exp-1219", + systemInstruction: systemPrompt, + }, + { + baseUrl: undefined, + }, + ) // Verify generation config expect(mockGenerateContentStream).toHaveBeenCalledWith( @@ -149,9 +154,14 @@ describe("GeminiHandler", () => { const result = await handler.completePrompt("Test prompt") expect(result).toBe("Test response") - expect(mockGetGenerativeModel).toHaveBeenCalledWith({ - model: "gemini-2.0-flash-thinking-exp-1219", - }) + expect(mockGetGenerativeModel).toHaveBeenCalledWith( + { + model: "gemini-2.0-flash-thinking-exp-1219", + }, + { + baseUrl: undefined, + }, + ) expect(mockGenerateContent).toHaveBeenCalledWith({ contents: [{ role: "user", parts: [{ text: "Test prompt" }] }], generationConfig: { diff --git a/src/api/providers/gemini.ts b/src/api/providers/gemini.ts index 4e522b3fcb9..98117e99a9d 100644 --- a/src/api/providers/gemini.ts +++ b/src/api/providers/gemini.ts @@ -19,10 +19,15 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl } override async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream { - const model = this.client.getGenerativeModel({ - model: this.getModel().id, - systemInstruction: systemPrompt, - }) + const model = this.client.getGenerativeModel( + { + model: this.getModel().id, + systemInstruction: systemPrompt, + }, + { + baseUrl: this.options.googleGeminiBaseUrl || undefined, + }, + ) const result = await model.generateContentStream({ contents: messages.map(convertAnthropicMessageToGemini), generationConfig: { @@ -57,9 +62,14 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl async completePrompt(prompt: string): Promise { try { - const model = this.client.getGenerativeModel({ - model: this.getModel().id, - }) + const model = this.client.getGenerativeModel( + { + model: this.getModel().id, + }, + { + baseUrl: this.options.googleGeminiBaseUrl || undefined, + }, + ) const result = await model.generateContent({ contents: [{ role: "user", parts: [{ text: prompt }] }], diff --git a/src/exports/roo-code.d.ts b/src/exports/roo-code.d.ts index 4442cdcf34e..ed19fc59ea5 100644 --- a/src/exports/roo-code.d.ts +++ b/src/exports/roo-code.d.ts @@ -154,6 +154,7 @@ export type GlobalStateKey = | "openRouterModelInfo" | "openRouterBaseUrl" | "openRouterUseMiddleOutTransform" + | "googleGeminiBaseUrl" | "allowedCommands" | "soundEnabled" | "soundVolume" diff --git a/src/shared/api.ts b/src/shared/api.ts index d497e0d75f5..38a6c3ec6dc 100644 --- a/src/shared/api.ts +++ b/src/shared/api.ts @@ -56,6 +56,7 @@ export interface ApiHandlerOptions { lmStudioDraftModelId?: string lmStudioSpeculativeDecodingEnabled?: boolean geminiApiKey?: string + googleGeminiBaseUrl?: string openAiNativeApiKey?: string mistralApiKey?: string mistralCodestralUrl?: string // New option for Codestral URL @@ -115,6 +116,7 @@ export const API_CONFIG_KEYS: GlobalStateKey[] = [ "lmStudioBaseUrl", "lmStudioDraftModelId", "lmStudioSpeculativeDecodingEnabled", + "googleGeminiBaseUrl", "mistralCodestralUrl", "azureApiVersion", "openRouterUseMiddleOutTransform", diff --git a/src/shared/globalState.ts b/src/shared/globalState.ts index 40c4b17f923..9896c376974 100644 --- a/src/shared/globalState.ts +++ b/src/shared/globalState.ts @@ -72,6 +72,7 @@ export const GLOBAL_STATE_KEYS = [ "openRouterModelInfo", "openRouterBaseUrl", "openRouterUseMiddleOutTransform", + "googleGeminiBaseUrl", "allowedCommands", "soundEnabled", "soundVolume", diff --git a/webview-ui/src/components/settings/ApiOptions.tsx b/webview-ui/src/components/settings/ApiOptions.tsx index 049292767fe..49756796427 100644 --- a/webview-ui/src/components/settings/ApiOptions.tsx +++ b/webview-ui/src/components/settings/ApiOptions.tsx @@ -116,6 +116,9 @@ const ApiOptions = ({ const [anthropicBaseUrlSelected, setAnthropicBaseUrlSelected] = useState(!!apiConfiguration?.anthropicBaseUrl) const [azureApiVersionSelected, setAzureApiVersionSelected] = useState(!!apiConfiguration?.azureApiVersion) const [openRouterBaseUrlSelected, setOpenRouterBaseUrlSelected] = useState(!!apiConfiguration?.openRouterBaseUrl) + const [googleGeminiBaseUrlSelected, setGoogleGeminiBaseUrlSelected] = useState( + !!apiConfiguration?.googleGeminiBaseUrl, + ) const [isDescriptionExpanded, setIsDescriptionExpanded] = useState(false) const noTransform = (value: T) => value @@ -646,6 +649,28 @@ const ApiOptions = ({ Get Gemini API Key )} +
+ { + setGoogleGeminiBaseUrlSelected(checked) + + if (!checked) { + setApiConfigurationField("googleGeminiBaseUrl", "") + } + }}> + Use custom base URL + + {googleGeminiBaseUrlSelected && ( + + )} +
)}