diff --git a/packages/types/src/provider-settings.ts b/packages/types/src/provider-settings.ts
index 3fa7094d87..2b323e21cb 100644
--- a/packages/types/src/provider-settings.ts
+++ b/packages/types/src/provider-settings.ts
@@ -12,6 +12,7 @@ import {
 	doubaoModels,
 	featherlessModels,
 	fireworksModels,
+	geminiCliModels,
 	geminiModels,
 	groqModels,
 	ioIntelligenceModels,
@@ -440,7 +441,7 @@ export const getApiProtocol = (provider: ProviderName | undefined, modelId?: str
 }
 
 export const MODELS_BY_PROVIDER: Record<
-	Exclude<ProviderName, "fake-ai" | "human-relay" | "gemini-cli">,
+	Exclude<ProviderName, "fake-ai" | "human-relay">,
 	{ id: ProviderName; label: string; models: string[] }
 > = {
 	anthropic: {
@@ -485,6 +486,11 @@ export const MODELS_BY_PROVIDER: Record<
 		label: "Google Gemini",
 		models: Object.keys(geminiModels),
 	},
+	"gemini-cli": {
+		id: "gemini-cli",
+		label: "Gemini CLI",
+		models: Object.keys(geminiCliModels),
+	},
 	groq: { id: "groq", label: "Groq", models: Object.keys(groqModels) },
 	"io-intelligence": {
 		id: "io-intelligence",
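
Aside: the `MODELS_BY_PROVIDER` entry added above is what exposes the static model list for a provider to consumers such as the settings UI. A minimal sketch of how an entry might be consumed (`buildModelPicker` is hypothetical, not part of this diff; only the registry shape comes from the PR):

```ts
import { MODELS_BY_PROVIDER } from "@roo-code/types"

// Hypothetical helper: turn a registry entry into dropdown options.
function buildModelPicker(provider: keyof typeof MODELS_BY_PROVIDER) {
	const { label, models } = MODELS_BY_PROVIDER[provider]
	return { label, options: models.map((id) => ({ value: id, title: id })) }
}

const picker = buildModelPicker("gemini-cli")
// picker.label === "Gemini CLI"; picker.options lists the IDs from geminiCliModels
```
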
diff --git a/packages/types/src/providers/gemini-cli.ts b/packages/types/src/providers/gemini-cli.ts
new file mode 100644
index 0000000000..231fd3c57b
--- /dev/null
+++ b/packages/types/src/providers/gemini-cli.ts
@@ -0,0 +1,207 @@
+import type { ModelInfo } from "../model.js"
+
+// Gemini CLI models - using the same models as regular Gemini
+// but accessed through the @google/gemini-cli-core library
+export type GeminiCliModelId = keyof typeof geminiCliModels
+
+export const geminiCliDefaultModelId: GeminiCliModelId = "gemini-2.0-flash-001"
+
+// Re-use the same model definitions as regular Gemini since they're the same models
+// just accessed through a different authentication mechanism (OAuth via CLI)
+export const geminiCliModels = {
+	"gemini-2.5-flash-preview-04-17:thinking": {
+		maxTokens: 65_535,
+		contextWindow: 1_048_576,
+		supportsImages: true,
+		supportsPromptCache: false,
+		inputPrice: 0.15,
+		outputPrice: 3.5,
+		maxThinkingTokens: 24_576,
+		supportsReasoningBudget: true,
+		requiredReasoningBudget: true,
+	},
+	"gemini-2.5-flash-preview-04-17": {
+		maxTokens: 65_535,
+		contextWindow: 1_048_576,
+		supportsImages: true,
+		supportsPromptCache: false,
+		inputPrice: 0.15,
+		outputPrice: 0.6,
+	},
+	"gemini-2.5-flash-preview-05-20:thinking": {
+		maxTokens: 65_535,
+		contextWindow: 1_048_576,
+		supportsImages: true,
+		supportsPromptCache: true,
+		inputPrice: 0.15,
+		outputPrice: 3.5,
+		cacheReadsPrice: 0.0375,
+		cacheWritesPrice: 1.0,
+		maxThinkingTokens: 24_576,
+		supportsReasoningBudget: true,
+		requiredReasoningBudget: true,
+	},
+	"gemini-2.5-flash-preview-05-20": {
+		maxTokens: 65_535,
+		contextWindow: 1_048_576,
+		supportsImages: true,
+		supportsPromptCache: true,
+		inputPrice: 0.15,
+		outputPrice: 0.6,
+		cacheReadsPrice: 0.0375,
+		cacheWritesPrice: 1.0,
+	},
+	"gemini-2.5-flash": {
+		maxTokens: 64_000,
+		contextWindow: 1_048_576,
+		supportsImages: true,
+		supportsPromptCache: true,
+		inputPrice: 0.3,
+		outputPrice: 2.5,
+		cacheReadsPrice: 0.075,
+		cacheWritesPrice: 1.0,
+		maxThinkingTokens: 24_576,
+		supportsReasoningBudget: true,
+	},
+	"gemini-2.5-pro-exp-03-25": {
+		maxTokens: 65_535,
+		contextWindow: 1_048_576,
+		supportsImages: true,
+		supportsPromptCache: false,
+		inputPrice: 0,
+		outputPrice: 0,
+	},
+	"gemini-2.5-pro-preview-03-25": {
+		maxTokens: 65_535,
+		contextWindow: 1_048_576,
+		supportsImages: true,
+		supportsPromptCache: true,
+		inputPrice: 2.5,
+		outputPrice: 15,
+		cacheReadsPrice: 0.625,
+		cacheWritesPrice: 4.5,
+		tiers: [
+			{
+				contextWindow: 200_000,
+				inputPrice: 1.25,
+				outputPrice: 10,
+				cacheReadsPrice: 0.31,
+			},
+			{
+				contextWindow: Infinity,
+				inputPrice: 2.5,
+				outputPrice: 15,
+				cacheReadsPrice: 0.625,
+			},
+		],
+	},
+	"gemini-2.5-pro": {
+		maxTokens: 64_000,
+		contextWindow: 1_048_576,
+		supportsImages: true,
+		supportsPromptCache: true,
+		inputPrice: 2.5,
+		outputPrice: 15,
+		cacheReadsPrice: 0.625,
+		cacheWritesPrice: 4.5,
+		maxThinkingTokens: 32_768,
+		supportsReasoningBudget: true,
+		requiredReasoningBudget: true,
+		tiers: [
+			{
+				contextWindow: 200_000,
+				inputPrice: 1.25,
+				outputPrice: 10,
+				cacheReadsPrice: 0.31,
+			},
+			{
+				contextWindow: Infinity,
+				inputPrice: 2.5,
+				outputPrice: 15,
+				cacheReadsPrice: 0.625,
+			},
+		],
+	},
+	"gemini-2.0-flash-001": {
+		maxTokens: 8192,
+		contextWindow: 1_048_576,
+		supportsImages: true,
+		supportsPromptCache: true,
+		inputPrice: 0.1,
+		outputPrice: 0.4,
+		cacheReadsPrice: 0.025,
+		cacheWritesPrice: 1.0,
+	},
+	"gemini-2.0-flash-lite-preview-02-05": {
+		maxTokens: 8192,
+		contextWindow: 1_048_576,
+		supportsImages: true,
+		supportsPromptCache: false,
+		inputPrice: 0,
+		outputPrice: 0,
+	},
+	"gemini-2.0-pro-exp-02-05": {
+		maxTokens: 8192,
+		contextWindow: 2_097_152,
+		supportsImages: true,
+		supportsPromptCache: false,
+		inputPrice: 0,
+		outputPrice: 0,
+	},
+	"gemini-2.0-flash-thinking-exp-01-21": {
+		maxTokens: 65_536,
+		contextWindow: 1_048_576,
+		supportsImages: true,
+		supportsPromptCache: false,
+		inputPrice: 0,
+		outputPrice: 0,
+	},
+	"gemini-2.0-flash-thinking-exp-1219": {
+		maxTokens: 8192,
+		contextWindow: 32_767,
+		supportsImages: true,
+		supportsPromptCache: false,
+		inputPrice: 0,
+		outputPrice: 0,
+	},
+	"gemini-2.0-flash-exp": {
+		maxTokens: 8192,
+		contextWindow: 1_048_576,
+		supportsImages: true,
+		supportsPromptCache: false,
+		inputPrice: 0,
+		outputPrice: 0,
+	},
+	"gemini-1.5-flash-002": {
+		maxTokens: 8192,
+		contextWindow: 1_048_576,
+		supportsImages: true,
+		supportsPromptCache: true,
+		inputPrice: 0.15,
+		outputPrice: 0.6,
+		cacheReadsPrice: 0.0375,
+		cacheWritesPrice: 1.0,
+		tiers: [
+			{
+				contextWindow: 128_000,
+				inputPrice: 0.075,
+				outputPrice: 0.3,
+				cacheReadsPrice: 0.01875,
+			},
+			{
+				contextWindow: Infinity,
+				inputPrice: 0.15,
+				outputPrice: 0.6,
+				cacheReadsPrice: 0.0375,
+			},
+		],
+	},
+	"gemini-1.5-pro-002": {
+		maxTokens: 8192,
+		contextWindow: 2_097_152,
+		supportsImages: true,
+		supportsPromptCache: false,
+		inputPrice: 0,
+		outputPrice: 0,
+	},
+} as const satisfies Record<string, ModelInfo>
diff --git a/packages/types/src/providers/index.ts b/packages/types/src/providers/index.ts
index 8ca9c2c9b2..d28ceb7884 100644
--- a/packages/types/src/providers/index.ts
+++ b/packages/types/src/providers/index.ts
@@ -7,6 +7,7 @@ export * from "./deepseek.js"
 export * from "./doubao.js"
 export * from "./featherless.js"
 export * from "./fireworks.js"
+export * from "./gemini-cli.js"
 export * from "./gemini.js"
 export * from "./glama.js"
 export * from "./groq.js"
diff --git a/src/api/index.ts b/src/api/index.ts
index 6c70a1485d..a956875dfe 100644
--- a/src/api/index.ts
+++ b/src/api/index.ts
@@ -14,6 +14,7 @@ import {
 	AnthropicVertexHandler,
 	OpenAiHandler,
 	LmStudioHandler,
+	GeminiCliHandler,
 	GeminiHandler,
 	OpenAiNativeHandler,
 	DeepSeekHandler,
@@ -102,6 +103,8 @@ export function buildApiHandler(configuration: ProviderSettings): ApiHandler {
 			return new LmStudioHandler(options)
 		case "gemini":
 			return new GeminiHandler(options)
+		case "gemini-cli":
+			return new GeminiCliHandler(options)
 		case "openai-native":
 			return new OpenAiNativeHandler(options)
 		case "deepseek":
@@ -149,7 +152,7 @@ export function buildApiHandler(configuration: ProviderSettings): ApiHandler {
 		case "featherless":
 			return new FeatherlessHandler(options)
 		default:
-			apiProvider satisfies "gemini-cli" | undefined
+			apiProvider satisfies undefined
 			return new AnthropicHandler(options)
 	}
 }
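
With the new `case "gemini-cli"` in place, a provider-settings object routes through `buildApiHandler` like any other provider. A quick sketch (values are placeholders; `geminiCliOAuthPath` and `geminiCliProjectId` are the option names the handler and tests below use):

```ts
import { buildApiHandler } from "../../src/api" // import path illustrative

const handler = buildApiHandler({
	apiProvider: "gemini-cli",
	apiModelId: "gemini-2.0-flash-001",
	geminiCliOAuthPath: "~/.config/gemini-cli/oauth.json", // placeholder path
	geminiCliProjectId: "my-gcp-project", // placeholder project
})

// The handler resolves the model from geminiCliModels, falling back to the default ID.
console.log(handler.getModel().id) // "gemini-2.0-flash-001"
```
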
diff --git a/src/api/providers/__tests__/gemini-cli.spec.ts b/src/api/providers/__tests__/gemini-cli.spec.ts
new file mode 100644
index 0000000000..3d033998d7
--- /dev/null
+++ b/src/api/providers/__tests__/gemini-cli.spec.ts
@@ -0,0 +1,482 @@
+// Mocks must come first, before imports
+const mockGenerateContentStream = vi.fn()
+const mockGenerateContent = vi.fn()
+const mockCountTokens = vi.fn()
+
+vi.mock("@google/gemini-cli-core", () => {
+	return {
+		GeminiCLI: vi.fn().mockImplementation(() => ({
+			models: {
+				generateContentStream: mockGenerateContentStream,
+				generateContent: mockGenerateContent,
+				countTokens: mockCountTokens,
+			},
+		})),
+	}
+})
+
+import type { Anthropic } from "@anthropic-ai/sdk"
+
+import { geminiCliDefaultModelId } from "@roo-code/types"
+
+import type { ApiHandlerOptions } from "../../../shared/api"
+
+import { GeminiCliHandler } from "../gemini-cli"
+
+describe("GeminiCliHandler", () => {
+	let handler: GeminiCliHandler
+	let mockOptions: ApiHandlerOptions
+
+	beforeEach(() => {
+		mockOptions = {
+			geminiCliOAuthPath: "~/.config/gemini-cli/oauth.json",
+			geminiCliProjectId: "test-project",
+			apiModelId: "gemini-2.0-flash-001",
+		}
+		// Reset mocks
+		vi.clearAllMocks()
+
+		// Setup default mock responses
+		mockGenerateContentStream.mockImplementation(async () => {
+			return {
+				[Symbol.asyncIterator]: async function* () {
+					yield {
+						candidates: [
+							{
+								content: {
+									parts: [{ text: "Test response" }],
+								},
+							},
+						],
+						usageMetadata: {
+							promptTokenCount: 10,
+							candidatesTokenCount: 5,
+							cachedContentTokenCount: 2,
+						},
+					}
+				},
+			}
+		})
+
+		mockGenerateContent.mockResolvedValue({
+			text: "Test prompt response",
+			candidates: [
+				{
+					content: {
+						parts: [{ text: "Test prompt response" }],
+					},
+				},
+			],
+		})
+
+		mockCountTokens.mockResolvedValue({
+			totalTokens: 15,
+		})
+	})
+
+	describe("constructor", () => {
+		it("should initialize with provided options", async () => {
+			handler = new GeminiCliHandler(mockOptions)
+			// Wait for async initialization
+			await new Promise((resolve) => setTimeout(resolve, 150))
+			expect(handler).toBeInstanceOf(GeminiCliHandler)
+			expect(handler.getModel().id).toBe(mockOptions.apiModelId)
+		})
+
+		it("should use default OAuth path if not provided", async () => {
+			const handlerWithoutPath = new GeminiCliHandler({
+				...mockOptions,
+				geminiCliOAuthPath: undefined,
+			})
+			// Wait for async initialization
+			await new Promise((resolve) => setTimeout(resolve, 150))
+			expect(handlerWithoutPath).toBeInstanceOf(GeminiCliHandler)
+		})
+
+		it("should use default model ID if not provided", async () => {
+			const handlerWithoutModel = new GeminiCliHandler({
+				...mockOptions,
+				apiModelId: undefined,
+			})
+			// Wait for async initialization
+			await new Promise((resolve) => setTimeout(resolve, 150))
+			expect(handlerWithoutModel.getModel().id).toBe(geminiCliDefaultModelId)
+		})
+
+		it("should handle project ID configuration", async () => {
+			const handlerWithProject = new GeminiCliHandler({
+				...mockOptions,
+				geminiCliProjectId: "custom-project",
+			})
+			// Wait for async initialization
+			await new Promise((resolve) => setTimeout(resolve, 150))
+			expect(handlerWithProject).toBeInstanceOf(GeminiCliHandler)
+		})
+	})
+
+	describe("getModel", () => {
+		beforeEach(async () => {
+			handler = new GeminiCliHandler(mockOptions)
+			// Wait for async initialization
+			await new Promise((resolve) => setTimeout(resolve, 150))
+		})
+
+		it("should return model info for valid model ID", () => {
+			const model = handler.getModel()
+			expect(model.id).toBe(mockOptions.apiModelId)
+			expect(model.info).toBeDefined()
+			expect(model.info.maxTokens).toBe(8192)
+			expect(model.info.contextWindow).toBe(1_048_576)
+			expect(model.info.supportsImages).toBe(true)
+			expect(model.info.supportsPromptCache).toBe(true)
+		})
+
+		it("should handle thinking models by removing :thinking suffix", () => {
+			const handlerWithThinking = new GeminiCliHandler({
+				...mockOptions,
+				apiModelId: "gemini-2.5-flash-preview-04-17:thinking",
+			})
+			const model = handlerWithThinking.getModel()
+			expect(model.id).toBe("gemini-2.5-flash-preview-04-17") // :thinking suffix removed
+			expect(model.info.maxThinkingTokens).toBe(24_576)
+			expect(model.info.supportsReasoningBudget).toBe(true)
+			expect(model.info.requiredReasoningBudget).toBe(true)
+		})
+
+		it("should return default model if invalid model ID is provided", () => {
+			const handlerWithInvalidModel = new GeminiCliHandler({
+				...mockOptions,
+				apiModelId: "invalid-model",
+			})
+			const model = handlerWithInvalidModel.getModel()
+			expect(model.id).toBe(geminiCliDefaultModelId)
+			expect(model.info).toBeDefined()
+		})
+
+		it("should include model parameters from getModelParams", () => {
+			const model = handler.getModel()
+			expect(model).toHaveProperty("temperature")
+			expect(model).toHaveProperty("maxTokens")
+		})
+	})
+
+	describe("createMessage", () => {
+		const systemPrompt = "You are a helpful assistant."
+		const messages: Anthropic.Messages.MessageParam[] = [
+			{
+				role: "user",
+				content: [
+					{
+						type: "text" as const,
+						text: "Hello!",
+					},
+				],
+			},
+		]
+
+		beforeEach(async () => {
+			handler = new GeminiCliHandler(mockOptions)
+			// Wait for async initialization
+			await new Promise((resolve) => setTimeout(resolve, 150))
+		})
+
+		it("should handle streaming responses", async () => {
+			const stream = handler.createMessage(systemPrompt, messages)
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			expect(chunks.length).toBeGreaterThan(0)
+			const textChunks = chunks.filter((chunk) => chunk.type === "text")
+			expect(textChunks).toHaveLength(1)
+			expect(textChunks[0].text).toBe("Test response")
+		})
+
+		it("should include usage information", async () => {
+			const stream = handler.createMessage(systemPrompt, messages)
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			const usageChunks = chunks.filter((chunk) => chunk.type === "usage")
+			expect(usageChunks.length).toBeGreaterThan(0)
+			expect(usageChunks[0].inputTokens).toBe(10)
+			expect(usageChunks[0].outputTokens).toBe(5)
+			expect(usageChunks[0].cacheReadTokens).toBe(2)
+		})
+
+		it("should handle reasoning/thinking parts", async () => {
+			mockGenerateContentStream.mockImplementationOnce(async () => {
+				return {
+					[Symbol.asyncIterator]: async function* () {
+						yield {
+							candidates: [
+								{
+									content: {
+										parts: [
+											{ thought: true, text: "Let me think..." },
+											{ text: "Here's the answer" },
+										],
+									},
+								},
+							],
+							usageMetadata: {
+								promptTokenCount: 10,
+								candidatesTokenCount: 5,
+								thoughtsTokenCount: 3,
+							},
+						}
+					},
+				}
+			})
+
+			const stream = handler.createMessage(systemPrompt, messages)
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			const reasoningChunks = chunks.filter((chunk) => chunk.type === "reasoning")
+			expect(reasoningChunks).toHaveLength(1)
+			expect(reasoningChunks[0].text).toBe("Let me think...")
+
+			const textChunks = chunks.filter((chunk) => chunk.type === "text")
+			expect(textChunks).toHaveLength(1)
+			expect(textChunks[0].text).toBe("Here's the answer")
+
+			const usageChunks = chunks.filter((chunk) => chunk.type === "usage")
+			expect(usageChunks[0].reasoningTokens).toBe(3)
+		})
+
+		it("should handle grounding metadata with citations", async () => {
+			mockGenerateContentStream.mockImplementationOnce(async () => {
+				return {
+					[Symbol.asyncIterator]: async function* () {
+						yield {
+							candidates: [
+								{
+									content: {
+										parts: [{ text: "Test response" }],
+									},
+									groundingMetadata: {
+										groundingChunks: [
+											{ web: { uri: "https://example.com/1" } },
+											{ web: { uri: "https://example.com/2" } },
+										],
+									},
+								},
+							],
+							usageMetadata: {
+								promptTokenCount: 10,
+								candidatesTokenCount: 5,
+							},
+						}
+					},
+				}
+			})
+
+			const stream = handler.createMessage(systemPrompt, messages)
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			const textChunks = chunks.filter((chunk) => chunk.type === "text")
+			// Should have response text and citation text
+			expect(textChunks.length).toBeGreaterThan(1)
+			const citationChunk = textChunks.find((chunk) => chunk.text.includes("[1]"))
+			expect(citationChunk).toBeDefined()
+			expect(citationChunk?.text).toContain("https://example.com/1")
+			expect(citationChunk?.text).toContain("https://example.com/2")
+		})
+	})
+
+	describe("completePrompt", () => {
+		beforeEach(async () => {
+			handler = new GeminiCliHandler(mockOptions)
+			// Wait for async initialization
+			await new Promise((resolve) => setTimeout(resolve, 150))
+		})
+
+		it("should complete a prompt successfully", async () => {
+			const result = await handler.completePrompt("Test prompt")
+			expect(result).toBe("Test prompt response")
+			expect(mockGenerateContent).toHaveBeenCalledWith(
+				expect.objectContaining({
+					model: mockOptions.apiModelId,
+					contents: [{ role: "user", parts: [{ text: "Test prompt" }] }],
+				}),
+			)
+		})
+
+		it("should handle grounding metadata in prompt completion", async () => {
+			mockGenerateContent.mockResolvedValueOnce({
+				text: "Test response",
+				candidates: [
+					{
+						content: {
+							parts: [{ text: "Test response" }],
+						},
+						groundingMetadata: {
+							groundingChunks: [{ web: { uri: "https://example.com" } }],
+						},
+					},
+				],
+			})
+
+			const result = await handler.completePrompt("Test prompt")
+			expect(result).toContain("Test response")
+			expect(result).toContain("https://example.com")
+		})
+
+		it("should handle errors gracefully", async () => {
+			mockGenerateContent.mockRejectedValueOnce(new Error("API error"))
+			await expect(handler.completePrompt("Test prompt")).rejects.toThrow()
+		})
+	})
+
+	describe("countTokens", () => {
+		beforeEach(async () => {
+			handler = new GeminiCliHandler(mockOptions)
+			// Wait for async initialization
+			await new Promise((resolve) => setTimeout(resolve, 150))
+		})
+
+		it("should count tokens successfully", async () => {
+			const content: Anthropic.Messages.ContentBlockParam[] = [
+				{
+					type: "text",
+					text: "Test content",
+				},
+			]
+
+			const result = await handler.countTokens(content)
+			expect(result).toBe(15)
+			expect(mockCountTokens).toHaveBeenCalledWith(
+				expect.objectContaining({
+					model: mockOptions.apiModelId,
+				}),
+			)
+		})
+
+		it("should fall back to base implementation if counting fails", async () => {
+			mockCountTokens.mockRejectedValueOnce(new Error("Count error"))
+
+			const content: Anthropic.Messages.ContentBlockParam[] = [
+				{
+					type: "text",
+					text: "Test content",
+				},
+			]
+
+			const result = await handler.countTokens(content)
+			// Should fall back to tiktoken-based counting
+			expect(result).toBeGreaterThan(0)
+		})
+
+		it("should fall back if totalTokens is undefined", async () => {
+			mockCountTokens.mockResolvedValueOnce({
+				totalTokens: undefined,
+			})
+
+			const content: Anthropic.Messages.ContentBlockParam[] = [
+				{
+					type: "text",
+					text: "Test content",
+				},
+			]
+
+			const result = await handler.countTokens(content)
+			// Should fall back to tiktoken-based counting
+			expect(result).toBeGreaterThan(0)
+		})
+	})
+
+	describe("calculateCost", () => {
+		beforeEach(async () => {
+			handler = new GeminiCliHandler(mockOptions)
+			// Wait for async initialization
+			await new Promise((resolve) => setTimeout(resolve, 150))
+		})
+
+		it("should calculate cost correctly", () => {
+			const model = handler.getModel()
+			const cost = (handler as any).calculateCost({
+				info: model.info,
+				inputTokens: 1000,
+				outputTokens: 500,
+				cacheReadTokens: 100,
+			})
+
+			expect(cost).toBeDefined()
+			expect(cost).toBeGreaterThan(0)
+		})
+
+		it("should handle tiered pricing", () => {
+			const handlerWithTiered = new GeminiCliHandler({
+				...mockOptions,
+				apiModelId: "gemini-2.5-pro",
+			})
+			const model = handlerWithTiered.getModel()
+
+			// Test with tokens below tier threshold
+			const costLowTier = (handlerWithTiered as any).calculateCost({
+				info: model.info,
+				inputTokens: 100_000,
+				outputTokens: 5000,
+				cacheReadTokens: 1000,
+			})
+
+			// Test with tokens above tier threshold
+			const costHighTier = (handlerWithTiered as any).calculateCost({
+				info: model.info,
+				inputTokens: 300_000,
+				outputTokens: 5000,
+				cacheReadTokens: 1000,
+			})
+
+			expect(costLowTier).toBeDefined()
+			expect(costHighTier).toBeDefined()
+			// High tier should cost more due to higher input token count and different pricing
+			expect(costHighTier).toBeGreaterThan(costLowTier)
+		})
+
+		it("should return undefined if pricing info is missing", () => {
+			const model = handler.getModel()
+			const modifiedInfo = { ...model.info, inputPrice: undefined }
+
+			const cost = (handler as any).calculateCost({
+				info: modifiedInfo,
+				inputTokens: 1000,
+				outputTokens: 500,
+			})
+
+			expect(cost).toBeUndefined()
+		})
+	})
+
+	describe("error handling", () => {
+		it("should handle missing gemini-cli-core package gracefully", async () => {
+			// Mock the dynamic import to fail
+			vi.doMock("@google/gemini-cli-core", () => {
+				throw new Error("Module not found")
+			})
+
+			// This will throw during initialization
+			const handlerPromise = new GeminiCliHandler(mockOptions)
+
+			// Give it time to attempt initialization
+			await new Promise((resolve) => setTimeout(resolve, 150))
+
+			// Try to use the handler - should throw error about missing package
+			await expect(async () => {
+				const stream = (handlerPromise as any).createMessage("test", [])
+				for await (const _chunk of stream) {
+					// Should not reach here
+				}
+			}).rejects.toThrow()
+		})
+	})
+})
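
A note on the repeated `setTimeout(resolve, 150)` waits: the constructor kicks off `initializeClient` without awaiting it (see the handler below), so the tests sleep to let the dynamic import settle. A sketch of a deterministic alternative the suite could use instead, polling the same private `client` field the handler sets (helper is hypothetical, not part of this diff):

```ts
// Hypothetical test helper: wait until the handler's private `client` is set,
// rather than sleeping a fixed 150 ms.
async function waitForClient(handler: object, timeoutMs = 1_000): Promise<void> {
	const start = Date.now()
	while ((handler as any).client === undefined) {
		if (Date.now() - start > timeoutMs) {
			throw new Error("Gemini CLI client never initialized")
		}
		await new Promise((resolve) => setTimeout(resolve, 10))
	}
}

// Usage in a test: await waitForClient(handler)
```
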
diff --git a/src/api/providers/gemini-cli.ts b/src/api/providers/gemini-cli.ts
new file mode 100644
index 0000000000..2a062f7ce7
--- /dev/null
+++ b/src/api/providers/gemini-cli.ts
@@ -0,0 +1,296 @@
+import type { Anthropic } from "@anthropic-ai/sdk"
+import type { GenerateContentResponseUsageMetadata, GroundingMetadata } from "@google/genai"
+
+import { type ModelInfo, type GeminiCliModelId, geminiCliDefaultModelId, geminiCliModels } from "@roo-code/types"
+
+import type { ApiHandlerOptions } from "../../shared/api"
+
+import { convertAnthropicContentToGemini, convertAnthropicMessageToGemini } from "../transform/gemini-format"
+import { t } from "i18next"
+import type { ApiStream } from "../transform/stream"
+import { getModelParams } from "../transform/model-params"
+
+import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
+import { BaseProvider } from "./base-provider"
+
+/**
+ * GeminiCliHandler provides integration with Google's Gemini models through the
+ * @google/gemini-cli-core library, which uses OAuth authentication via the Gemini CLI.
+ *
+ * This handler reuses much of the logic from the regular GeminiHandler but uses
+ * the Gemini CLI core library for authentication instead of API keys or service accounts.
+ */
+export class GeminiCliHandler extends BaseProvider implements SingleCompletionHandler {
+	protected options: ApiHandlerOptions
+	private client: any // Will be typed when @google/gemini-cli-core is available
+
+	constructor(options: ApiHandlerOptions) {
+		super()
+		this.options = options
+
+		// Initialize the Gemini CLI client with OAuth configuration
+		// The OAuth path and project ID come from the provider settings
+		const oauthPath = this.options.geminiCliOAuthPath || "~/.config/gemini-cli/oauth.json"
+		const projectId = this.options.geminiCliProjectId
+
+		// Dynamically import the Gemini CLI library
+		// This allows the code to compile even if the package isn't installed yet
+		this.initializeClient(oauthPath, projectId)
+	}
+
+	private async initializeClient(oauthPath: string, projectId?: string) {
+		try {
+			// Dynamic import to handle missing package gracefully
+			// @ts-ignore - Package will be available at runtime
+			const { GeminiCLI } = await import("@google/gemini-cli-core")
+			this.client = new GeminiCLI({
+				oauthPath,
+				projectId,
+			})
+		} catch (error) {
+			throw new Error(
+				"@google/gemini-cli-core is not installed. Please install it with: npm install @google/gemini-cli-core",
+			)
+		}
+	}
+
+	private async ensureClientInitialized() {
+		if (!this.client) {
+			// Wait a bit for initialization to complete
+			await new Promise((resolve) => setTimeout(resolve, 100))
+			if (!this.client) {
+				throw new Error("Gemini CLI client not initialized. Please check your configuration.")
+			}
+		}
+	}
+
+	async *createMessage(
+		systemInstruction: string,
+		messages: Anthropic.Messages.MessageParam[],
+		metadata?: ApiHandlerCreateMessageMetadata,
+	): ApiStream {
+		await this.ensureClientInitialized()
+		const { id: model, info, reasoning: thinkingConfig, maxTokens } = this.getModel()
+
+		const contents = messages.map(convertAnthropicMessageToGemini)
+
+		const config = {
+			systemInstruction,
+			thinkingConfig,
+			maxOutputTokens: this.options.modelMaxTokens ?? maxTokens ?? undefined,
+			temperature: this.options.modelTemperature ?? 0,
+		}
+
+		const params = { model, contents, config }
+
+		try {
+			// Use the Gemini CLI client to generate content
+			// The actual API will depend on @google/gemini-cli-core implementation
+			const result = await this.client.models.generateContentStream(params)
+
+			let lastUsageMetadata: GenerateContentResponseUsageMetadata | undefined
+			let pendingGroundingMetadata: GroundingMetadata | undefined
+
+			for await (const chunk of result) {
+				// Process candidates and their parts to separate thoughts from content
+				if (chunk.candidates && chunk.candidates.length > 0) {
+					const candidate = chunk.candidates[0]
+
+					if (candidate.groundingMetadata) {
+						pendingGroundingMetadata = candidate.groundingMetadata
+					}
+
+					if (candidate.content && candidate.content.parts) {
+						for (const part of candidate.content.parts) {
+							if (part.thought) {
+								// This is a thinking/reasoning part
+								if (part.text) {
+									yield { type: "reasoning", text: part.text }
+								}
+							} else {
+								// This is regular content
+								if (part.text) {
+									yield { type: "text", text: part.text }
+								}
+							}
+						}
+					}
+				}
+
+				// Fallback to the original text property if no candidates structure
+				else if (chunk.text) {
+					yield { type: "text", text: chunk.text }
+				}
+
+				if (chunk.usageMetadata) {
+					lastUsageMetadata = chunk.usageMetadata
+				}
+			}
+
+			if (pendingGroundingMetadata) {
+				const citations = this.extractCitationsOnly(pendingGroundingMetadata)
+				if (citations) {
+					yield { type: "text", text: `\n\n${t("common:errors.gemini.sources")} ${citations}` }
+				}
+			}
+
+			if (lastUsageMetadata) {
+				const inputTokens = lastUsageMetadata.promptTokenCount ?? 0
+				const outputTokens = lastUsageMetadata.candidatesTokenCount ?? 0
+				const cacheReadTokens = lastUsageMetadata.cachedContentTokenCount
+				const reasoningTokens = lastUsageMetadata.thoughtsTokenCount
+
+				yield {
+					type: "usage",
+					inputTokens,
+					outputTokens,
+					cacheReadTokens,
+					reasoningTokens,
+					totalCost: this.calculateCost({ info, inputTokens, outputTokens, cacheReadTokens }),
+				}
+			}
+		} catch (error) {
+			if (error instanceof Error) {
+				throw new Error(t("common:errors.gemini.generate_stream", { error: error.message }))
+			}
+
+			throw error
+		}
+	}
+
+	override getModel() {
+		const modelId = this.options.apiModelId
+		let id = modelId && modelId in geminiCliModels ? (modelId as GeminiCliModelId) : geminiCliDefaultModelId
+		let info: ModelInfo = geminiCliModels[id]
+		const params = getModelParams({ format: "gemini", modelId: id, model: info, settings: this.options })
+
+		// The `:thinking` suffix indicates that the model is a "Hybrid"
+		// reasoning model and that reasoning is required to be enabled.
+		// The actual model ID honored by Gemini's API does not have this suffix.
+		return { id: id.endsWith(":thinking") ? id.replace(":thinking", "") : id, info, ...params }
+	}
+
+	private extractCitationsOnly(groundingMetadata?: GroundingMetadata): string | null {
+		const chunks = groundingMetadata?.groundingChunks
+
+		if (!chunks) {
+			return null
+		}
+
+		const citationLinks = chunks
+			.map((chunk, i) => {
+				const uri = chunk.web?.uri
+				if (uri) {
+					return `[${i + 1}](${uri})`
+				}
+				return null
+			})
+			.filter((link): link is string => link !== null)
+
+		if (citationLinks.length > 0) {
+			return citationLinks.join(", ")
+		}
+
+		return null
+	}
+
+	async completePrompt(prompt: string): Promise<string> {
+		try {
+			await this.ensureClientInitialized()
+			const { id: model } = this.getModel()
+
+			const promptConfig = {
+				temperature: this.options.modelTemperature ?? 0,
+			}
+
+			const result = await this.client.models.generateContent({
+				model,
+				contents: [{ role: "user", parts: [{ text: prompt }] }],
+				config: promptConfig,
+			})
+
+			let text = result.text ?? ""
+
+			const candidate = result.candidates?.[0]
+			if (candidate?.groundingMetadata) {
+				const citations = this.extractCitationsOnly(candidate.groundingMetadata)
+				if (citations) {
+					text += `\n\n${t("common:errors.gemini.sources")} ${citations}`
+				}
+			}
+
+			return text
+		} catch (error) {
+			if (error instanceof Error) {
+				throw new Error(t("common:errors.gemini.generate_complete_prompt", { error: error.message }))
+			}
+
+			throw error
+		}
+	}
+
+	override async countTokens(content: Array<Anthropic.Messages.ContentBlockParam>): Promise<number> {
+		try {
+			await this.ensureClientInitialized()
+			const { id: model } = this.getModel()
+
+			const response = await this.client.models.countTokens({
+				model,
+				contents: convertAnthropicContentToGemini(content),
+			})
+
+			if (response.totalTokens === undefined) {
+				console.warn("Gemini CLI token counting returned undefined, using fallback")
+				return super.countTokens(content)
+			}
+
+			return response.totalTokens
+		} catch (error) {
+			console.warn("Gemini CLI token counting failed, using fallback", error)
+			return super.countTokens(content)
+		}
+	}
+
+	public calculateCost({
+		info,
+		inputTokens,
+		outputTokens,
+		cacheReadTokens = 0,
+	}: {
+		info: ModelInfo
+		inputTokens: number
+		outputTokens: number
+		cacheReadTokens?: number
+	}) {
+		if (!info.inputPrice || !info.outputPrice || !info.cacheReadsPrice) {
+			return undefined
+		}
+
+		let inputPrice = info.inputPrice
+		let outputPrice = info.outputPrice
+		let cacheReadsPrice = info.cacheReadsPrice
+
+		// If there's tiered pricing then adjust the input and output token prices
+		// based on the input tokens used.
+		if (info.tiers) {
+			const tier = info.tiers.find((tier) => inputTokens <= tier.contextWindow)
+
+			if (tier) {
+				inputPrice = tier.inputPrice ?? inputPrice
+				outputPrice = tier.outputPrice ?? outputPrice
+				cacheReadsPrice = tier.cacheReadsPrice ?? cacheReadsPrice
+			}
+		}
+
+		// Subtract the cached input tokens from the total input tokens.
+		const uncachedInputTokens = inputTokens - cacheReadTokens
+
+		let cacheReadCost = cacheReadTokens > 0 ? cacheReadsPrice * (cacheReadTokens / 1_000_000) : 0
+
+		const inputTokensCost = inputPrice * (uncachedInputTokens / 1_000_000)
+		const outputTokensCost = outputPrice * (outputTokens / 1_000_000)
+		const totalCost = inputTokensCost + outputTokensCost + cacheReadCost
+
+		return totalCost
+	}
+}
diff --git a/src/api/providers/index.ts b/src/api/providers/index.ts
index d256fbbe55..0cd7f596b8 100644
--- a/src/api/providers/index.ts
+++ b/src/api/providers/index.ts
@@ -8,6 +8,7 @@ export { DeepSeekHandler } from "./deepseek"
 export { DoubaoHandler } from "./doubao"
 export { MoonshotHandler } from "./moonshot"
 export { FakeAIHandler } from "./fake-ai"
+export { GeminiCliHandler } from "./gemini-cli"
 export { GeminiHandler } from "./gemini"
 export { GlamaHandler } from "./glama"
 export { GroqHandler } from "./groq"
diff --git a/src/package.json b/src/package.json
index 52949a006a..dbee7cdf1c 100644
--- a/src/package.json
+++ b/src/package.json
@@ -427,6 +427,7 @@
 		"@anthropic-ai/vertex-sdk": "^0.7.0",
 		"@aws-sdk/client-bedrock-runtime": "^3.848.0",
 		"@aws-sdk/credential-providers": "^3.848.0",
+		"@google/gemini-cli-core": "^1.0.0",
 		"@google/genai": "^1.0.0",
 		"@lmstudio/sdk": "^1.1.1",
 		"@mistralai/mistralai": "^1.9.18",
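
End to end, the pieces above compose like this — a hedged sketch, assuming OAuth credentials already exist at the configured path (settings values are placeholders; the chunk shapes `text`, `reasoning`, and `usage` are the ones `createMessage` yields):

```ts
import { buildApiHandler } from "../../src/api" // import path illustrative

async function main() {
	const handler = buildApiHandler({
		apiProvider: "gemini-cli",
		apiModelId: "gemini-2.5-pro",
		geminiCliOAuthPath: "~/.config/gemini-cli/oauth.json", // placeholder path
	})

	const stream = handler.createMessage("You are a helpful assistant.", [
		{ role: "user", content: [{ type: "text", text: "Hello!" }] },
	])

	let reply = ""
	for await (const chunk of stream) {
		if (chunk.type === "text") reply += chunk.text
		else if (chunk.type === "reasoning") console.log("thinking:", chunk.text)
		else if (chunk.type === "usage") console.log("cost (USD):", chunk.totalCost)
	}
	console.log(reply)
}

main().catch(console.error)
```
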