diff --git a/packages/types/src/provider-settings.ts b/packages/types/src/provider-settings.ts index 5262e7602d68..328274eb07c3 100644 --- a/packages/types/src/provider-settings.ts +++ b/packages/types/src/provider-settings.ts @@ -429,6 +429,7 @@ export const providerSettingsSchemaDiscriminated = z.discriminatedUnion("apiProv lmStudioSchema.merge(z.object({ apiProvider: z.literal("lmstudio") })), geminiSchema.merge(z.object({ apiProvider: z.literal("gemini") })), geminiCliSchema.merge(z.object({ apiProvider: z.literal("gemini-cli") })), + geminiCliSchema.merge(z.object({ apiProvider: z.literal("gemini-cli") })), openAiNativeSchema.merge(z.object({ apiProvider: z.literal("openai-native") })), mistralSchema.merge(z.object({ apiProvider: z.literal("mistral") })), deepSeekSchema.merge(z.object({ apiProvider: z.literal("deepseek") })), @@ -470,6 +471,7 @@ export const providerSettingsSchema = z.object({ ...lmStudioSchema.shape, ...geminiSchema.shape, ...geminiCliSchema.shape, + ...geminiCliSchema.shape, ...openAiNativeSchema.shape, ...mistralSchema.shape, ...deepSeekSchema.shape, diff --git a/packages/types/src/providers/gemini-cli.ts b/packages/types/src/providers/gemini-cli.ts new file mode 100644 index 000000000000..4ef498220e5f --- /dev/null +++ b/packages/types/src/providers/gemini-cli.ts @@ -0,0 +1,110 @@ +import type { ModelInfo } from "../model.js" + +// Gemini CLI models with free tier pricing (all $0) +export type GeminiCliModelId = keyof typeof geminiCliModels + +export const geminiCliDefaultModelId: GeminiCliModelId = "gemini-2.0-flash-001" + +export const geminiCliModels = { + "gemini-2.0-flash-001": { + maxTokens: 8192, + contextWindow: 1_048_576, + supportsImages: true, + supportsPromptCache: false, + inputPrice: 0, + outputPrice: 0, + }, + "gemini-2.0-flash-thinking-exp-01-21": { + maxTokens: 65_536, + contextWindow: 1_048_576, + supportsImages: true, + supportsPromptCache: false, + inputPrice: 0, + outputPrice: 0, + }, + "gemini-2.0-flash-thinking-exp-1219": { + maxTokens: 8192, + contextWindow: 32_767, + supportsImages: true, + supportsPromptCache: false, + inputPrice: 0, + outputPrice: 0, + }, + "gemini-2.0-flash-exp": { + maxTokens: 8192, + contextWindow: 1_048_576, + supportsImages: true, + supportsPromptCache: false, + inputPrice: 0, + outputPrice: 0, + }, + "gemini-1.5-flash-002": { + maxTokens: 8192, + contextWindow: 1_048_576, + supportsImages: true, + supportsPromptCache: false, + inputPrice: 0, + outputPrice: 0, + }, + "gemini-1.5-flash-exp-0827": { + maxTokens: 8192, + contextWindow: 1_048_576, + supportsImages: true, + supportsPromptCache: false, + inputPrice: 0, + outputPrice: 0, + }, + "gemini-1.5-flash-8b-exp-0827": { + maxTokens: 8192, + contextWindow: 1_048_576, + supportsImages: true, + supportsPromptCache: false, + inputPrice: 0, + outputPrice: 0, + }, + "gemini-1.5-pro-002": { + maxTokens: 8192, + contextWindow: 2_097_152, + supportsImages: true, + supportsPromptCache: false, + inputPrice: 0, + outputPrice: 0, + }, + "gemini-1.5-pro-exp-0827": { + maxTokens: 8192, + contextWindow: 2_097_152, + supportsImages: true, + supportsPromptCache: false, + inputPrice: 0, + outputPrice: 0, + }, + "gemini-exp-1206": { + maxTokens: 8192, + contextWindow: 2_097_152, + supportsImages: true, + supportsPromptCache: false, + inputPrice: 0, + outputPrice: 0, + }, + "gemini-2.5-flash": { + maxTokens: 64_000, + contextWindow: 1_048_576, + supportsImages: true, + supportsPromptCache: false, + inputPrice: 0, + outputPrice: 0, + maxThinkingTokens: 24_576, + supportsReasoningBudget: true, + }, + "gemini-2.5-pro": { + maxTokens: 64_000, + contextWindow: 1_048_576, + supportsImages: true, + supportsPromptCache: false, + inputPrice: 0, + outputPrice: 0, + maxThinkingTokens: 32_768, + supportsReasoningBudget: true, + requiredReasoningBudget: true, + }, +} as const satisfies Record diff --git a/packages/types/src/providers/index.ts b/packages/types/src/providers/index.ts index 21e43aaa99a6..fb58d052c678 100644 --- a/packages/types/src/providers/index.ts +++ b/packages/types/src/providers/index.ts @@ -8,6 +8,7 @@ export * from "./doubao.js" export * from "./featherless.js" export * from "./fireworks.js" export * from "./gemini.js" +export * from "./gemini-cli.js" export * from "./glama.js" export * from "./groq.js" export * from "./huggingface.js" diff --git a/src/api/index.ts b/src/api/index.ts index ac0096767624..3a4afa63a461 100644 --- a/src/api/index.ts +++ b/src/api/index.ts @@ -15,6 +15,7 @@ import { OpenAiHandler, LmStudioHandler, GeminiHandler, + GeminiCliHandler, OpenAiNativeHandler, DeepSeekHandler, MoonshotHandler, @@ -113,6 +114,8 @@ export function buildApiHandler(configuration: ProviderSettings): ApiHandler { return new LmStudioHandler(options) case "gemini": return new GeminiHandler(options) + case "gemini-cli": + return new GeminiCliHandler(options) case "openai-native": return new OpenAiNativeHandler(options) case "deepseek": diff --git a/src/api/providers/__tests__/gemini-cli.spec.ts b/src/api/providers/__tests__/gemini-cli.spec.ts new file mode 100644 index 000000000000..a84e940e969b --- /dev/null +++ b/src/api/providers/__tests__/gemini-cli.spec.ts @@ -0,0 +1,329 @@ +import { describe, it, expect, vi, beforeEach } from "vitest" +import { GeminiCliHandler } from "../gemini-cli" +import { geminiCliDefaultModelId, geminiCliModels } from "@roo-code/types" +import * as fs from "fs/promises" +import axios from "axios" + +vi.mock("fs/promises") +vi.mock("axios") +vi.mock("google-auth-library", () => ({ + OAuth2Client: vi.fn().mockImplementation(() => ({ + setCredentials: vi.fn(), + refreshAccessToken: vi.fn().mockResolvedValue({ + credentials: { + access_token: "refreshed-token", + refresh_token: "refresh-token", + token_type: "Bearer", + expiry_date: Date.now() + 3600 * 1000, + }, + }), + request: vi.fn(), + })), +})) + +describe("GeminiCliHandler", () => { + let handler: GeminiCliHandler + const mockCredentials = { + access_token: "test-access-token", + refresh_token: "test-refresh-token", + token_type: "Bearer", + expiry_date: Date.now() + 3600 * 1000, + } + + beforeEach(() => { + vi.clearAllMocks() + ;(fs.readFile as any).mockResolvedValue(JSON.stringify(mockCredentials)) + ;(fs.writeFile as any).mockResolvedValue(undefined) + + // Set up default mock + ;(axios.post as any).mockResolvedValue({ + data: {}, + }) + + handler = new GeminiCliHandler({ + apiModelId: geminiCliDefaultModelId, + }) + + // Set up default mock for OAuth2Client request + handler["authClient"].request = vi.fn().mockResolvedValue({ + data: {}, + }) + + // Mock the discoverProjectId to avoid real API calls in tests + handler["projectId"] = "test-project-123" + vi.spyOn(handler as any, "discoverProjectId").mockResolvedValue("test-project-123") + }) + + describe("constructor", () => { + it("should initialize with provided config", () => { + expect(handler["options"].apiModelId).toBe(geminiCliDefaultModelId) + }) + }) + + describe("getModel", () => { + it("should return correct model info", () => { + const modelInfo = handler.getModel() + expect(modelInfo.id).toBe(geminiCliDefaultModelId) + expect(modelInfo.info).toBeDefined() + expect(modelInfo.info.inputPrice).toBe(0) + expect(modelInfo.info.outputPrice).toBe(0) + }) + + it("should return default model if invalid model specified", () => { + const invalidHandler = new GeminiCliHandler({ + apiModelId: "invalid-model", + }) + const modelInfo = invalidHandler.getModel() + expect(modelInfo.id).toBe(geminiCliDefaultModelId) + }) + + it("should handle :thinking suffix", () => { + const thinkingHandler = new GeminiCliHandler({ + apiModelId: "gemini-2.5-pro:thinking", + }) + const modelInfo = thinkingHandler.getModel() + // The :thinking suffix should be removed from the ID + expect(modelInfo.id).toBe("gemini-2.5-pro") + // But the model should still have reasoning support + expect(modelInfo.info.supportsReasoningBudget).toBe(true) + expect(modelInfo.info.requiredReasoningBudget).toBe(true) + }) + }) + + describe("OAuth authentication", () => { + it("should load OAuth credentials from default path", async () => { + await handler["loadOAuthCredentials"]() + expect(fs.readFile).toHaveBeenCalledWith(expect.stringMatching(/\.gemini[/\\]oauth_creds\.json$/), "utf-8") + }) + + it("should load OAuth credentials from custom path", async () => { + const customHandler = new GeminiCliHandler({ + apiModelId: geminiCliDefaultModelId, + geminiCliOAuthPath: "/custom/path/oauth.json", + }) + await customHandler["loadOAuthCredentials"]() + expect(fs.readFile).toHaveBeenCalledWith("/custom/path/oauth.json", "utf-8") + }) + + it("should refresh expired tokens", async () => { + const expiredCredentials = { + ...mockCredentials, + expiry_date: Date.now() - 1000, // Expired + } + ;(fs.readFile as any).mockResolvedValueOnce(JSON.stringify(expiredCredentials)) + + await handler["ensureAuthenticated"]() + + expect(handler["authClient"].refreshAccessToken).toHaveBeenCalled() + expect(fs.writeFile).toHaveBeenCalledWith( + expect.stringMatching(/\.gemini[/\\]oauth_creds\.json$/), + expect.stringContaining("refreshed-token"), + ) + }) + + it("should throw error if credentials file not found", async () => { + ;(fs.readFile as any).mockRejectedValueOnce(new Error("ENOENT")) + + await expect(handler["loadOAuthCredentials"]()).rejects.toThrow("errors.geminiCli.oauthLoadFailed") + }) + }) + + describe("project ID discovery", () => { + it("should use provided project ID", async () => { + const customHandler = new GeminiCliHandler({ + apiModelId: geminiCliDefaultModelId, + geminiCliProjectId: "custom-project", + }) + + const projectId = await customHandler["discoverProjectId"]() + expect(projectId).toBe("custom-project") + expect(customHandler["projectId"]).toBe("custom-project") + }) + + it("should discover project ID through API", async () => { + // Create a new handler without the mocked discoverProjectId + const testHandler = new GeminiCliHandler({ + apiModelId: geminiCliDefaultModelId, + }) + testHandler["authClient"].request = vi.fn().mockResolvedValue({ + data: {}, + }) + + // Mock the callEndpoint method + testHandler["callEndpoint"] = vi.fn().mockResolvedValueOnce({ + cloudaicompanionProject: "discovered-project-123", + }) + + const projectId = await testHandler["discoverProjectId"]() + expect(projectId).toBe("discovered-project-123") + expect(testHandler["projectId"]).toBe("discovered-project-123") + }) + + it("should onboard user if no existing project", async () => { + // Create a new handler without the mocked discoverProjectId + const testHandler = new GeminiCliHandler({ + apiModelId: geminiCliDefaultModelId, + }) + testHandler["authClient"].request = vi.fn().mockResolvedValue({ + data: {}, + }) + + // Mock the callEndpoint method + testHandler["callEndpoint"] = vi + .fn() + .mockResolvedValueOnce({ + allowedTiers: [{ id: "free-tier", isDefault: true }], + }) + .mockResolvedValueOnce({ + done: false, + }) + .mockResolvedValueOnce({ + done: true, + response: { + cloudaicompanionProject: { + id: "onboarded-project-456", + }, + }, + }) + + const projectId = await testHandler["discoverProjectId"]() + expect(projectId).toBe("onboarded-project-456") + expect(testHandler["projectId"]).toBe("onboarded-project-456") + expect(testHandler["callEndpoint"]).toHaveBeenCalledTimes(3) + }) + }) + + describe("completePrompt", () => { + it("should complete prompt successfully", async () => { + handler["authClient"].request = vi.fn().mockResolvedValue({ + data: { + candidates: [ + { + content: { + parts: [{ text: "Test response" }], + }, + }, + ], + }, + }) + + const result = await handler.completePrompt("Test prompt") + expect(result).toBe("Test response") + }) + + it("should handle empty response", async () => { + handler["authClient"].request = vi.fn().mockResolvedValue({ + data: { + candidates: [], + }, + }) + + const result = await handler.completePrompt("Test prompt") + expect(result).toBe("") + }) + + it("should filter out thinking parts", async () => { + handler["authClient"].request = vi.fn().mockResolvedValue({ + data: { + candidates: [ + { + content: { + parts: [{ text: "Thinking...", thought: true }, { text: "Actual response" }], + }, + }, + ], + }, + }) + + const result = await handler.completePrompt("Test prompt") + expect(result).toBe("Actual response") + }) + + it("should handle API errors", async () => { + handler["authClient"].request = vi.fn().mockRejectedValue(new Error("API Error")) + + await expect(handler.completePrompt("Test prompt")).rejects.toThrow("errors.geminiCli.completionError") + }) + }) + + describe("createMessage streaming", () => { + it("should handle streaming response with reasoning", async () => { + // Create a mock Node.js readable stream + const { Readable } = require("stream") + const mockStream = new Readable({ + read() { + this.push('data: {"candidates":[{"content":{"parts":[{"text":"Hello"}]}}]}\n\n') + this.push( + 'data: {"candidates":[{"content":{"parts":[{"thought":true,"text":"thinking..."}]}}]}\n\n', + ) + this.push( + 'data: {"candidates":[{"content":{"parts":[{"text":" world"}]}}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":5}}\n\n', + ) + this.push("data: [DONE]\n\n") + this.push(null) // End the stream + }, + }) + + handler["authClient"].request = vi.fn().mockResolvedValue({ + data: mockStream, + }) + + const stream = handler.createMessage("System", []) + const chunks: any[] = [] + + for await (const chunk of stream) { + chunks.push(chunk) + } + + // Check we got the expected chunks + expect(chunks).toHaveLength(4) // 2 text chunks, 1 reasoning chunk, 1 usage chunk + + // Filter out only text chunks (not reasoning chunks) + const textChunks = chunks.filter((c) => c.type === "text").map((c) => c.text) + expect(textChunks).toEqual(["Hello", " world"]) + + // Check reasoning chunk + const reasoningChunks = chunks.filter((c) => c.type === "reasoning") + expect(reasoningChunks).toHaveLength(1) + expect(reasoningChunks[0].text).toBe("thinking...") + + // Check usage chunk + const usageChunks = chunks.filter((c) => c.type === "usage") + expect(usageChunks).toHaveLength(1) + expect(usageChunks[0]).toMatchObject({ + type: "usage", + inputTokens: 10, + outputTokens: 5, + totalCost: 0, + }) + }) + + it("should handle rate limit errors", async () => { + handler["authClient"].request = vi.fn().mockRejectedValue({ + response: { + status: 429, + data: { error: { message: "Rate limit exceeded" } }, + }, + }) + + const stream = handler.createMessage("System", []) + + await expect(async () => { + for await (const _chunk of stream) { + // Should throw before yielding + } + }).rejects.toThrow("errors.geminiCli.rateLimitExceeded") + }) + }) + + describe("countTokens", () => { + it("should fall back to base provider implementation", async () => { + const content = [{ type: "text" as const, text: "Hello world" }] + const tokenCount = await handler.countTokens(content) + + // Should return a number (tiktoken fallback) + expect(typeof tokenCount).toBe("number") + expect(tokenCount).toBeGreaterThan(0) + }) + }) +}) diff --git a/src/api/providers/gemini-cli.ts b/src/api/providers/gemini-cli.ts new file mode 100644 index 000000000000..6e265511e4d9 --- /dev/null +++ b/src/api/providers/gemini-cli.ts @@ -0,0 +1,419 @@ +import type { Anthropic } from "@anthropic-ai/sdk" +import { OAuth2Client } from "google-auth-library" +import * as fs from "fs/promises" +import * as path from "path" +import * as os from "os" +import axios from "axios" + +import { type ModelInfo, type GeminiCliModelId, geminiCliDefaultModelId, geminiCliModels } from "@roo-code/types" + +import type { ApiHandlerOptions } from "../../shared/api" +import { t } from "../../i18n" + +import { convertAnthropicContentToGemini, convertAnthropicMessageToGemini } from "../transform/gemini-format" +import type { ApiStream } from "../transform/stream" +import { getModelParams } from "../transform/model-params" + +import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index" +import { BaseProvider } from "./base-provider" + +// OAuth2 Configuration (from Cline implementation) +const OAUTH_CLIENT_ID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com" +const OAUTH_CLIENT_SECRET = "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" +const OAUTH_REDIRECT_URI = "http://localhost:45289" + +// Code Assist API Configuration +const CODE_ASSIST_ENDPOINT = "https://cloudcode-pa.googleapis.com" +const CODE_ASSIST_API_VERSION = "v1internal" + +interface OAuthCredentials { + access_token: string + refresh_token: string + token_type: string + expiry_date: number +} + +export class GeminiCliHandler extends BaseProvider implements SingleCompletionHandler { + protected options: ApiHandlerOptions + private authClient: OAuth2Client + private projectId: string | null = null + private credentials: OAuthCredentials | null = null + + constructor(options: ApiHandlerOptions) { + super() + this.options = options + + // Initialize OAuth2 client + this.authClient = new OAuth2Client(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET, OAUTH_REDIRECT_URI) + } + + private async loadOAuthCredentials(): Promise { + try { + const credPath = this.options.geminiCliOAuthPath || path.join(os.homedir(), ".gemini", "oauth_creds.json") + const credData = await fs.readFile(credPath, "utf-8") + this.credentials = JSON.parse(credData) + + // Set credentials on the OAuth2 client + if (this.credentials) { + this.authClient.setCredentials({ + access_token: this.credentials.access_token, + refresh_token: this.credentials.refresh_token, + expiry_date: this.credentials.expiry_date, + }) + } + } catch (error) { + throw new Error(t("common:errors.geminiCli.oauthLoadFailed", { error })) + } + } + + private async ensureAuthenticated(): Promise { + if (!this.credentials) { + await this.loadOAuthCredentials() + } + + // Check if token needs refresh + if (this.credentials && this.credentials.expiry_date < Date.now()) { + try { + const { credentials } = await this.authClient.refreshAccessToken() + if (credentials.access_token) { + this.credentials = { + access_token: credentials.access_token!, + refresh_token: credentials.refresh_token || this.credentials.refresh_token, + token_type: credentials.token_type || "Bearer", + expiry_date: credentials.expiry_date || Date.now() + 3600 * 1000, + } + // Optionally save refreshed credentials back to file + const credPath = + this.options.geminiCliOAuthPath || path.join(os.homedir(), ".gemini", "oauth_creds.json") + await fs.writeFile(credPath, JSON.stringify(this.credentials, null, 2)) + } + } catch (error) { + throw new Error(t("common:errors.geminiCli.tokenRefreshFailed", { error })) + } + } + } + + /** + * Call a Code Assist API endpoint + */ + private async callEndpoint(method: string, body: any, retryAuth: boolean = true): Promise { + try { + const res = await this.authClient.request({ + url: `${CODE_ASSIST_ENDPOINT}/${CODE_ASSIST_API_VERSION}:${method}`, + method: "POST", + headers: { + "Content-Type": "application/json", + }, + responseType: "json", + data: JSON.stringify(body), + }) + return res.data + } catch (error: any) { + console.error(`[GeminiCLI] Error calling ${method}:`, error) + console.error(`[GeminiCLI] Error response:`, error.response?.data) + console.error(`[GeminiCLI] Error status:`, error.response?.status) + console.error(`[GeminiCLI] Error message:`, error.message) + + // If we get a 401 and haven't retried yet, try refreshing auth + if (error.response?.status === 401 && retryAuth) { + await this.ensureAuthenticated() // This will refresh the token + return this.callEndpoint(method, body, false) // Retry without further auth retries + } + + throw error + } + } + + /** + * Discover or retrieve the project ID + */ + private async discoverProjectId(): Promise { + // If we already have a project ID, use it + if (this.options.geminiCliProjectId) { + this.projectId = this.options.geminiCliProjectId + return this.projectId + } + + // If we've already discovered it, return it + if (this.projectId) { + return this.projectId + } + + // Start with a default project ID (can be anything for personal OAuth) + const initialProjectId = "default" + + // Prepare client metadata + const clientMetadata = { + ideType: "IDE_UNSPECIFIED", + platform: "PLATFORM_UNSPECIFIED", + pluginType: "GEMINI", + duetProject: initialProjectId, + } + + try { + // Call loadCodeAssist to discover the actual project ID + const loadRequest = { + cloudaicompanionProject: initialProjectId, + metadata: clientMetadata, + } + + const loadResponse = await this.callEndpoint("loadCodeAssist", loadRequest) + + // Check if we already have a project ID from the response + if (loadResponse.cloudaicompanionProject) { + this.projectId = loadResponse.cloudaicompanionProject + return this.projectId as string + } + + // If no existing project, we need to onboard + const defaultTier = loadResponse.allowedTiers?.find((tier: any) => tier.isDefault) + const tierId = defaultTier?.id || "free-tier" + + const onboardRequest = { + tierId: tierId, + cloudaicompanionProject: initialProjectId, + metadata: clientMetadata, + } + + let lroResponse = await this.callEndpoint("onboardUser", onboardRequest) + + // Poll until operation is complete with timeout protection + const MAX_RETRIES = 30 // Maximum number of retries (60 seconds total) + let retryCount = 0 + + while (!lroResponse.done && retryCount < MAX_RETRIES) { + await new Promise((resolve) => setTimeout(resolve, 2000)) + lroResponse = await this.callEndpoint("onboardUser", onboardRequest) + retryCount++ + } + + if (!lroResponse.done) { + throw new Error(t("common:errors.geminiCli.onboardingTimeout")) + } + + const discoveredProjectId = lroResponse.response?.cloudaicompanionProject?.id || initialProjectId + this.projectId = discoveredProjectId + return this.projectId as string + } catch (error: any) { + console.error("Failed to discover project ID:", error.response?.data || error.message) + throw new Error(t("common:errors.geminiCli.projectDiscoveryFailed")) + } + } + + /** + * Parse Server-Sent Events from a stream + */ + private async *parseSSEStream(stream: NodeJS.ReadableStream): AsyncGenerator { + let buffer = "" + + for await (const chunk of stream) { + buffer += chunk.toString() + const lines = buffer.split("\n") + buffer = lines.pop() || "" + + for (const line of lines) { + if (line.startsWith("data: ")) { + const data = line.slice(6).trim() + if (data === "[DONE]") continue + + try { + const parsed = JSON.parse(data) + yield parsed + } catch (e) { + console.error("Error parsing SSE data:", e) + } + } + } + } + } + + async *createMessage( + systemInstruction: string, + messages: Anthropic.Messages.MessageParam[], + metadata?: ApiHandlerCreateMessageMetadata, + ): ApiStream { + await this.ensureAuthenticated() + const projectId = await this.discoverProjectId() + + const { id: model, info, reasoning: thinkingConfig, maxTokens } = this.getModel() + + // Convert messages to Gemini format + const contents = messages.map(convertAnthropicMessageToGemini) + + // Prepare request body for Code Assist API - matching Cline's structure + const requestBody: any = { + model: model, + project: projectId, + request: { + contents: [ + { + role: "user", + parts: [{ text: systemInstruction }], + }, + ...contents, + ], + generationConfig: { + temperature: this.options.modelTemperature ?? 0.7, + maxOutputTokens: this.options.modelMaxTokens ?? maxTokens ?? 8192, + }, + }, + } + + // Add thinking config if applicable + if (thinkingConfig) { + requestBody.request.generationConfig.thinkingConfig = thinkingConfig + } + + try { + // Call Code Assist streaming endpoint using OAuth2Client + const response = await this.authClient.request({ + url: `${CODE_ASSIST_ENDPOINT}/${CODE_ASSIST_API_VERSION}:streamGenerateContent`, + method: "POST", + params: { alt: "sse" }, + headers: { + "Content-Type": "application/json", + }, + responseType: "stream", + data: JSON.stringify(requestBody), + }) + + // Process the SSE stream + let lastUsageMetadata: any = undefined + + for await (const jsonData of this.parseSSEStream(response.data as NodeJS.ReadableStream)) { + // Extract content from the response + const responseData = jsonData.response || jsonData + const candidate = responseData.candidates?.[0] + + if (candidate?.content?.parts) { + for (const part of candidate.content.parts) { + if (part.text) { + // Check if this is a thinking/reasoning part + if (part.thought === true) { + yield { + type: "reasoning", + text: part.text, + } + } else { + yield { + type: "text", + text: part.text, + } + } + } + } + } + + // Store usage metadata for final reporting + if (responseData.usageMetadata) { + lastUsageMetadata = responseData.usageMetadata + } + + // Check if this is the final chunk + if (candidate?.finishReason) { + break + } + } + + // Yield final usage information + if (lastUsageMetadata) { + const inputTokens = lastUsageMetadata.promptTokenCount ?? 0 + const outputTokens = lastUsageMetadata.candidatesTokenCount ?? 0 + const cacheReadTokens = lastUsageMetadata.cachedContentTokenCount + const reasoningTokens = lastUsageMetadata.thoughtsTokenCount + + yield { + type: "usage", + inputTokens, + outputTokens, + cacheReadTokens, + reasoningTokens, + totalCost: 0, // Free tier - all costs are 0 + } + } + } catch (error: any) { + console.error("[GeminiCLI] API Error:", error.response?.status, error.response?.statusText) + console.error("[GeminiCLI] Error Response:", error.response?.data) + + if (error.response?.status === 429) { + throw new Error(t("common:errors.geminiCli.rateLimitExceeded")) + } + if (error.response?.status === 400) { + throw new Error( + t("common:errors.geminiCli.badRequest", { + details: JSON.stringify(error.response?.data) || error.message, + }), + ) + } + throw new Error(t("common:errors.geminiCli.apiError", { error: error.message })) + } + } + + override getModel() { + const modelId = this.options.apiModelId + // Handle :thinking suffix before checking if model exists + const baseModelId = modelId?.endsWith(":thinking") ? modelId.replace(":thinking", "") : modelId + let id = + baseModelId && baseModelId in geminiCliModels ? (baseModelId as GeminiCliModelId) : geminiCliDefaultModelId + const info: ModelInfo = geminiCliModels[id] + const params = getModelParams({ format: "gemini", modelId: id, model: info, settings: this.options }) + + // Return the cleaned model ID + return { id, info, ...params } + } + + async completePrompt(prompt: string): Promise { + await this.ensureAuthenticated() + const projectId = await this.discoverProjectId() + + try { + const { id: model } = this.getModel() + + const requestBody = { + model: model, + project: projectId, + request: { + contents: [{ role: "user", parts: [{ text: prompt }] }], + generationConfig: { + temperature: this.options.modelTemperature ?? 0.7, + }, + }, + } + + const response = await this.authClient.request({ + url: `${CODE_ASSIST_ENDPOINT}/${CODE_ASSIST_API_VERSION}:generateContent`, + method: "POST", + headers: { + "Content-Type": "application/json", + }, + data: JSON.stringify(requestBody), + }) + + // Extract text from response + const responseData = response.data as any + if (responseData.candidates && responseData.candidates.length > 0) { + const candidate = responseData.candidates[0] + if (candidate.content && candidate.content.parts) { + const textParts = candidate.content.parts + .filter((part: any) => part.text && !part.thought) + .map((part: any) => part.text) + .join("") + return textParts + } + } + + return "" + } catch (error) { + if (error instanceof Error) { + throw new Error(t("common:errors.geminiCli.completionError", { error: error.message })) + } + throw error + } + } + + override async countTokens(content: Array): Promise { + // For OAuth/free tier, we can't use the token counting API + // Fall back to the base provider's tiktoken implementation + return super.countTokens(content) + } +} diff --git a/src/api/providers/index.ts b/src/api/providers/index.ts index 85d877b6bc78..a315f61eddbd 100644 --- a/src/api/providers/index.ts +++ b/src/api/providers/index.ts @@ -9,6 +9,7 @@ export { DoubaoHandler } from "./doubao" export { MoonshotHandler } from "./moonshot" export { FakeAIHandler } from "./fake-ai" export { GeminiHandler } from "./gemini" +export { GeminiCliHandler } from "./gemini-cli" export { GlamaHandler } from "./glama" export { GroqHandler } from "./groq" export { HuggingFaceHandler } from "./huggingface" diff --git a/src/i18n/locales/en/common.json b/src/i18n/locales/en/common.json index 3a613cc1c21f..d3004c2f29a7 100644 --- a/src/i18n/locales/en/common.json +++ b/src/i18n/locales/en/common.json @@ -121,7 +121,17 @@ "manual_url_no_query": "Invalid callback URL: missing query parameters", "manual_url_missing_params": "Invalid callback URL: missing required parameters (code and state)", "manual_url_auth_failed": "Manual URL authentication failed", - "manual_url_auth_error": "Authentication failed" + "manual_url_auth_error": "Authentication failed", + "geminiCli": { + "oauthLoadFailed": "Failed to load OAuth credentials. Please authenticate first: {{error}}", + "tokenRefreshFailed": "Failed to refresh OAuth token: {{error}}", + "onboardingTimeout": "Onboarding operation timed out after 60 seconds. Please try again later.", + "projectDiscoveryFailed": "Could not discover project ID. Make sure you're authenticated with 'gemini auth'.", + "rateLimitExceeded": "Rate limit exceeded. Free tier limits have been reached.", + "badRequest": "Bad request: {{details}}", + "apiError": "Gemini CLI API error: {{error}}", + "completionError": "Gemini CLI completion error: {{error}}" + } }, "warnings": { "no_terminal_content": "No terminal content selected", diff --git a/src/shared/checkExistApiConfig.ts b/src/shared/checkExistApiConfig.ts index 4b9af08d5afe..61910aadc456 100644 --- a/src/shared/checkExistApiConfig.ts +++ b/src/shared/checkExistApiConfig.ts @@ -5,11 +5,8 @@ export function checkExistKey(config: ProviderSettings | undefined) { return false } - // Special case for human-relay, fake-ai, claude-code, qwen-code, and roo providers which don't need any configuration. - if ( - config.apiProvider && - ["human-relay", "fake-ai", "claude-code", "qwen-code", "roo"].includes(config.apiProvider) - ) { + // Special case for human-relay, fake-ai, and claude-code providers which don't need any configuration. + if (config.apiProvider && ["human-relay", "fake-ai", "claude-code"].includes(config.apiProvider)) { return true } diff --git a/webview-ui/src/components/settings/ApiOptions.tsx b/webview-ui/src/components/settings/ApiOptions.tsx index 37c1c286b983..671902e03200 100644 --- a/webview-ui/src/components/settings/ApiOptions.tsx +++ b/webview-ui/src/components/settings/ApiOptions.tsx @@ -19,6 +19,7 @@ import { claudeCodeDefaultModelId, qwenCodeDefaultModelId, geminiDefaultModelId, + geminiCliDefaultModelId, deepSeekDefaultModelId, moonshotDefaultModelId, mistralDefaultModelId, @@ -71,6 +72,7 @@ import { DeepSeek, Doubao, Gemini, + GeminiCli, Glama, Groq, HuggingFace, @@ -332,6 +334,7 @@ const ApiOptions = ({ "qwen-code": { field: "apiModelId", default: qwenCodeDefaultModelId }, "openai-native": { field: "apiModelId", default: openAiNativeDefaultModelId }, gemini: { field: "apiModelId", default: geminiDefaultModelId }, + "gemini-cli": { field: "apiModelId", default: geminiCliDefaultModelId }, deepseek: { field: "apiModelId", default: deepSeekDefaultModelId }, doubao: { field: "apiModelId", default: doubaoDefaultModelId }, moonshot: { field: "apiModelId", default: moonshotDefaultModelId }, @@ -558,6 +561,10 @@ const ApiOptions = ({ /> )} + {selectedProvider === "gemini-cli" && ( + + )} + {selectedProvider === "openai" && ( void +} + +export const GeminiCli = ({ apiConfiguration, setApiConfigurationField }: GeminiCliProps) => { + const { t } = useAppTranslation() + + const handleInputChange = useCallback( + ( + field: K, + transform: (event: E) => ProviderSettings[K] = inputEventTransform, + ) => + (event: E | Event) => { + setApiConfigurationField(field, transform(event as E)) + }, + [setApiConfigurationField], + ) + + return ( + <> + + + +
+ {t("settings:providers.geminiCli.oauthPathDescription")} +
+ +
+ {t("settings:providers.geminiCli.description")} +
+ +
+ {t("settings:providers.geminiCli.instructions")}{" "} + gemini{" "} + {t("settings:providers.geminiCli.instructionsContinued")} +
+ + + {t("settings:providers.geminiCli.setupLink")} + + +
+
+ + {t("settings:providers.geminiCli.requirementsTitle")} +
+
    +
  • {t("settings:providers.geminiCli.requirement1")}
  • +
  • {t("settings:providers.geminiCli.requirement2")}
  • +
  • {t("settings:providers.geminiCli.requirement3")}
  • +
  • {t("settings:providers.geminiCli.requirement4")}
  • +
  • {t("settings:providers.geminiCli.requirement5")}
  • +
+
+ +
+ + + {t("settings:providers.geminiCli.freeAccess")} + +
+ + ) +} diff --git a/webview-ui/src/components/settings/providers/__tests__/GeminiCli.spec.tsx b/webview-ui/src/components/settings/providers/__tests__/GeminiCli.spec.tsx new file mode 100644 index 000000000000..dc9c82f200d9 --- /dev/null +++ b/webview-ui/src/components/settings/providers/__tests__/GeminiCli.spec.tsx @@ -0,0 +1,169 @@ +import { render, screen, fireEvent } from "@testing-library/react" +import { describe, it, expect, vi } from "vitest" + +import type { ProviderSettings } from "@roo-code/types" + +import { GeminiCli } from "../GeminiCli" + +// Mock the translation hook +vi.mock("@src/i18n/TranslationContext", () => ({ + useAppTranslation: () => ({ + t: (key: string) => key, + }), +})) + +// Mock VSCodeLink to render as a regular anchor tag +vi.mock("@vscode/webview-ui-toolkit/react", async () => { + const actual = await vi.importActual("@vscode/webview-ui-toolkit/react") + return { + ...actual, + VSCodeLink: ({ children, href, ...props }: any) => ( + + {children} + + ), + } +}) + +describe("GeminiCli", () => { + const mockSetApiConfigurationField = vi.fn() + const defaultProps = { + apiConfiguration: {} as ProviderSettings, + setApiConfigurationField: mockSetApiConfigurationField, + } + + beforeEach(() => { + vi.clearAllMocks() + }) + + it("renders all required elements", () => { + render() + + // Check for OAuth path input + expect(screen.getByText("settings:providers.geminiCli.oauthPath")).toBeInTheDocument() + expect(screen.getByPlaceholderText("~/.gemini/oauth_creds.json")).toBeInTheDocument() + + // Check for description text + expect(screen.getByText("settings:providers.geminiCli.description")).toBeInTheDocument() + + // Check for instructions - they're in the same div but broken up by the code element + // Find all elements that contain the instruction parts + const instructionsDivs = screen.getAllByText((_content, element) => { + // Check if this element contains all the expected text parts + const fullText = element?.textContent || "" + return ( + fullText.includes("settings:providers.geminiCli.instructions") && + fullText.includes("gemini") && + fullText.includes("settings:providers.geminiCli.instructionsContinued") + ) + }) + // Find the div with the correct classes + const instructionsDiv = instructionsDivs.find( + (div) => + div.classList.contains("text-sm") && + div.classList.contains("text-vscode-descriptionForeground") && + div.classList.contains("mt-2"), + ) + expect(instructionsDiv).toBeDefined() + expect(instructionsDiv).toBeInTheDocument() + + // Also verify the code element exists + const codeElement = screen.getByText("gemini") + expect(codeElement).toBeInTheDocument() + expect(codeElement.tagName).toBe("CODE") + + // Check for setup link + expect(screen.getByText("settings:providers.geminiCli.setupLink")).toBeInTheDocument() + + // Check for requirements + expect(screen.getByText("settings:providers.geminiCli.requirementsTitle")).toBeInTheDocument() + expect(screen.getByText("settings:providers.geminiCli.requirement1")).toBeInTheDocument() + expect(screen.getByText("settings:providers.geminiCli.requirement2")).toBeInTheDocument() + expect(screen.getByText("settings:providers.geminiCli.requirement3")).toBeInTheDocument() + expect(screen.getByText("settings:providers.geminiCli.requirement4")).toBeInTheDocument() + expect(screen.getByText("settings:providers.geminiCli.requirement5")).toBeInTheDocument() + + // Check for free access note + expect(screen.getByText("settings:providers.geminiCli.freeAccess")).toBeInTheDocument() + }) + + it("displays OAuth path from configuration", () => { + const apiConfiguration: ProviderSettings = { + geminiCliOAuthPath: "/custom/path/oauth.json", + } + + render() + + const oauthInput = screen.getByDisplayValue("/custom/path/oauth.json") + expect(oauthInput).toBeInTheDocument() + }) + + it("calls setApiConfigurationField when OAuth path is changed", () => { + render() + + const oauthInput = screen.getByPlaceholderText("~/.gemini/oauth_creds.json") + + // Simulate input event with VSCodeTextField + fireEvent.input(oauthInput, { target: { value: "/new/path.json" } }) + + // Check that setApiConfigurationField was called + expect(mockSetApiConfigurationField).toHaveBeenCalledWith("geminiCliOAuthPath", "/new/path.json") + }) + + it("renders setup link with correct href", () => { + render() + + const setupLink = screen.getByText("settings:providers.geminiCli.setupLink") + expect(setupLink).toHaveAttribute( + "href", + "https://github.com/google-gemini/gemini-cli?tab=readme-ov-file#quickstart", + ) + }) + + it("shows OAuth path description", () => { + render() + + expect(screen.getByText("settings:providers.geminiCli.oauthPathDescription")).toBeInTheDocument() + }) + + it("renders all requirements in a list", () => { + render() + + const listItems = screen.getAllByRole("listitem") + expect(listItems).toHaveLength(5) + expect(listItems[0]).toHaveTextContent("settings:providers.geminiCli.requirement1") + expect(listItems[1]).toHaveTextContent("settings:providers.geminiCli.requirement2") + expect(listItems[2]).toHaveTextContent("settings:providers.geminiCli.requirement3") + expect(listItems[3]).toHaveTextContent("settings:providers.geminiCli.requirement4") + expect(listItems[4]).toHaveTextContent("settings:providers.geminiCli.requirement5") + }) + + it("applies correct styling classes", () => { + render() + + // Check for styled warning box + const warningBox = screen.getByText("settings:providers.geminiCli.requirementsTitle").closest("div.mt-3") + expect(warningBox).toHaveClass("bg-vscode-editorWidget-background") + expect(warningBox).toHaveClass("border-vscode-editorWidget-border") + expect(warningBox).toHaveClass("rounded") + expect(warningBox).toHaveClass("p-3") + + // Check for warning icon + const warningIcon = screen.getByText("settings:providers.geminiCli.requirementsTitle").previousElementSibling + expect(warningIcon).toHaveClass("codicon-warning") + expect(warningIcon).toHaveClass("text-vscode-notificationsWarningIcon-foreground") + + // Check for check icon + const checkIcon = screen.getByText("settings:providers.geminiCli.freeAccess").previousElementSibling + expect(checkIcon).toHaveClass("codicon-check") + expect(checkIcon).toHaveClass("text-vscode-notificationsInfoIcon-foreground") + }) + + it("renders instructions with code element", () => { + render() + + const codeElement = screen.getByText("gemini") + expect(codeElement.tagName).toBe("CODE") + expect(codeElement).toHaveClass("text-vscode-textPreformat-foreground") + }) +}) diff --git a/webview-ui/src/components/settings/providers/index.ts b/webview-ui/src/components/settings/providers/index.ts index fe0e6cecf961..a27b9db0be1b 100644 --- a/webview-ui/src/components/settings/providers/index.ts +++ b/webview-ui/src/components/settings/providers/index.ts @@ -6,6 +6,7 @@ export { ClaudeCode } from "./ClaudeCode" export { DeepSeek } from "./DeepSeek" export { Doubao } from "./Doubao" export { Gemini } from "./Gemini" +export { GeminiCli } from "./GeminiCli" export { Glama } from "./Glama" export { Groq } from "./Groq" export { HuggingFace } from "./HuggingFace" diff --git a/webview-ui/src/components/ui/hooks/useSelectedModel.ts b/webview-ui/src/components/ui/hooks/useSelectedModel.ts index a3ce1e63e4e1..6a6ae84feae3 100644 --- a/webview-ui/src/components/ui/hooks/useSelectedModel.ts +++ b/webview-ui/src/components/ui/hooks/useSelectedModel.ts @@ -14,6 +14,8 @@ import { moonshotModels, geminiDefaultModelId, geminiModels, + geminiCliDefaultModelId, + geminiCliModels, mistralDefaultModelId, mistralModels, openAiModelInfoSaneDefaults, @@ -222,6 +224,11 @@ function getSelectedModel({ const info = geminiModels[id as keyof typeof geminiModels] return { id, info } } + case "gemini-cli": { + const id = apiConfiguration.apiModelId ?? geminiCliDefaultModelId + const info = geminiCliModels[id as keyof typeof geminiCliModels] + return { id, info } + } case "deepseek": { const id = apiConfiguration.apiModelId ?? deepSeekDefaultModelId const info = deepSeekModels[id as keyof typeof deepSeekModels] diff --git a/webview-ui/src/i18n/locales/en/settings.json b/webview-ui/src/i18n/locales/en/settings.json index dfccc49cc4ce..926c8782fe88 100644 --- a/webview-ui/src/i18n/locales/en/settings.json +++ b/webview-ui/src/i18n/locales/en/settings.json @@ -484,6 +484,21 @@ "placeholder": "Default: claude", "maxTokensLabel": "Max Output Tokens", "maxTokensDescription": "Maximum number of output tokens for Claude Code responses. Default is 8000." + }, + "geminiCli": { + "description": "This provider uses OAuth authentication from the Gemini CLI tool and does not require API keys.", + "oauthPath": "OAuth Credentials Path (optional)", + "oauthPathDescription": "Path to the OAuth credentials file. Leave empty to use the default location (~/.gemini/oauth_creds.json).", + "instructions": "If you haven't authenticated yet, please run", + "instructionsContinued": "in your terminal first.", + "setupLink": "Gemini CLI Setup Instructions", + "requirementsTitle": "Important Requirements", + "requirement1": "First, you need to install the Gemini CLI tool", + "requirement2": "Then, run gemini in your terminal and make sure you Log in with Google", + "requirement3": "Only works with personal Google accounts (not Google Workspace accounts)", + "requirement4": "Does not use API keys - authentication is handled via OAuth", + "requirement5": "Requires the Gemini CLI tool to be installed and authenticated first", + "freeAccess": "Free tier access via OAuth authentication" } }, "browser": { diff --git a/webview-ui/src/utils/validate.ts b/webview-ui/src/utils/validate.ts index d15f82e4cafc..53218e2a2001 100644 --- a/webview-ui/src/utils/validate.ts +++ b/webview-ui/src/utils/validate.ts @@ -88,6 +88,9 @@ function validateModelsAndKeysProvided(apiConfiguration: ProviderSettings): stri return i18next.t("settings:validation.apiKey") } break + case "gemini-cli": + // OAuth-based provider, no API key validation needed + break case "openai-native": if (!apiConfiguration.openAiNativeApiKey) { return i18next.t("settings:validation.apiKey")