Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 17 additions & 7 deletions src/api/providers/__tests__/gemini.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,15 @@ describe("GeminiHandler", () => {
})

// Verify the model configuration
expect(mockGetGenerativeModel).toHaveBeenCalledWith({
model: "gemini-2.0-flash-thinking-exp-1219",
systemInstruction: systemPrompt,
})
expect(mockGetGenerativeModel).toHaveBeenCalledWith(
{
model: "gemini-2.0-flash-thinking-exp-1219",
systemInstruction: systemPrompt,
},
{
baseUrl: undefined,
},
)

// Verify generation config
expect(mockGenerateContentStream).toHaveBeenCalledWith(
Expand Down Expand Up @@ -149,9 +154,14 @@ describe("GeminiHandler", () => {

const result = await handler.completePrompt("Test prompt")
expect(result).toBe("Test response")
expect(mockGetGenerativeModel).toHaveBeenCalledWith({
model: "gemini-2.0-flash-thinking-exp-1219",
})
expect(mockGetGenerativeModel).toHaveBeenCalledWith(
{
model: "gemini-2.0-flash-thinking-exp-1219",
},
{
baseUrl: undefined,
},
)
expect(mockGenerateContent).toHaveBeenCalledWith({
contents: [{ role: "user", parts: [{ text: "Test prompt" }] }],
generationConfig: {
Expand Down
24 changes: 17 additions & 7 deletions src/api/providers/gemini.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,15 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl
}

override async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
const model = this.client.getGenerativeModel({
model: this.getModel().id,
systemInstruction: systemPrompt,
})
const model = this.client.getGenerativeModel(
{
model: this.getModel().id,
systemInstruction: systemPrompt,
},
{
baseUrl: this.options.googleGeminiBaseUrl || undefined,
},
)
const result = await model.generateContentStream({
contents: messages.map(convertAnthropicMessageToGemini),
generationConfig: {
Expand Down Expand Up @@ -57,9 +62,14 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl

async completePrompt(prompt: string): Promise<string> {
try {
const model = this.client.getGenerativeModel({
model: this.getModel().id,
})
const model = this.client.getGenerativeModel(
{
model: this.getModel().id,
},
{
baseUrl: this.options.googleGeminiBaseUrl || undefined,
},
)

const result = await model.generateContent({
contents: [{ role: "user", parts: [{ text: prompt }] }],
Expand Down
1 change: 1 addition & 0 deletions src/exports/roo-code.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ export type GlobalStateKey =
| "openRouterModelInfo"
| "openRouterBaseUrl"
| "openRouterUseMiddleOutTransform"
| "googleGeminiBaseUrl"
| "allowedCommands"
| "soundEnabled"
| "soundVolume"
Expand Down
2 changes: 2 additions & 0 deletions src/shared/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ export interface ApiHandlerOptions {
lmStudioDraftModelId?: string
lmStudioSpeculativeDecodingEnabled?: boolean
geminiApiKey?: string
googleGeminiBaseUrl?: string
openAiNativeApiKey?: string
mistralApiKey?: string
mistralCodestralUrl?: string // New option for Codestral URL
Expand Down Expand Up @@ -115,6 +116,7 @@ export const API_CONFIG_KEYS: GlobalStateKey[] = [
"lmStudioBaseUrl",
"lmStudioDraftModelId",
"lmStudioSpeculativeDecodingEnabled",
"googleGeminiBaseUrl",
"mistralCodestralUrl",
"azureApiVersion",
"openRouterUseMiddleOutTransform",
Expand Down
1 change: 1 addition & 0 deletions src/shared/globalState.ts
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ export const GLOBAL_STATE_KEYS = [
"openRouterModelInfo",
"openRouterBaseUrl",
"openRouterUseMiddleOutTransform",
"googleGeminiBaseUrl",
"allowedCommands",
"soundEnabled",
"soundVolume",
Expand Down
25 changes: 25 additions & 0 deletions webview-ui/src/components/settings/ApiOptions.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,9 @@ const ApiOptions = ({
const [anthropicBaseUrlSelected, setAnthropicBaseUrlSelected] = useState(!!apiConfiguration?.anthropicBaseUrl)
const [azureApiVersionSelected, setAzureApiVersionSelected] = useState(!!apiConfiguration?.azureApiVersion)
const [openRouterBaseUrlSelected, setOpenRouterBaseUrlSelected] = useState(!!apiConfiguration?.openRouterBaseUrl)
const [googleGeminiBaseUrlSelected, setGoogleGeminiBaseUrlSelected] = useState(
!!apiConfiguration?.googleGeminiBaseUrl,
)
const [isDescriptionExpanded, setIsDescriptionExpanded] = useState(false)

const noTransform = <T,>(value: T) => value
Expand Down Expand Up @@ -646,6 +649,28 @@ const ApiOptions = ({
Get Gemini API Key
</VSCodeButtonLink>
)}
<div>
<Checkbox
checked={googleGeminiBaseUrlSelected}
onChange={(checked: boolean) => {
setGoogleGeminiBaseUrlSelected(checked)

if (!checked) {
setApiConfigurationField("googleGeminiBaseUrl", "")
}
}}>
Use custom base URL
</Checkbox>
{googleGeminiBaseUrlSelected && (
<VSCodeTextField
value={apiConfiguration?.googleGeminiBaseUrl || ""}
type="url"
onInput={handleInputChange("googleGeminiBaseUrl")}
placeholder="https://generativelanguage.googleapis.com"
className="w-full mt-1"
/>
)}
</div>
</>
)}

Expand Down
Loading