Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/curly-frogs-pull.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"roo-cline": patch
---

Use the Gemini provider when using Gemini models on Vertex AI
13 changes: 0 additions & 13 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,6 @@
"@anthropic-ai/sdk": "^0.37.0",
"@anthropic-ai/vertex-sdk": "^0.7.0",
"@aws-sdk/client-bedrock-runtime": "^3.779.0",
"@google-cloud/vertexai": "^1.9.3",
"@google/genai": "^0.9.0",
"@mistralai/mistralai": "^1.3.6",
"@modelcontextprotocol/sdk": "^1.7.0",
Expand Down
167 changes: 67 additions & 100 deletions src/api/providers/__tests__/vertex.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import { AnthropicVertex } from "@anthropic-ai/vertex-sdk"

import { VertexHandler } from "../vertex"
import { ApiStreamChunk } from "../../transform/stream"
import { VertexAI } from "@google-cloud/vertexai"
import { GeminiHandler } from "../gemini"

// Mock Vertex SDK
jest.mock("@anthropic-ai/vertex-sdk", () => ({
Expand Down Expand Up @@ -49,58 +49,40 @@ jest.mock("@anthropic-ai/vertex-sdk", () => ({
})),
}))

// Mock Vertex Gemini SDK
jest.mock("@google-cloud/vertexai", () => {
const mockGenerateContentStream = jest.fn().mockImplementation(() => {
return {
stream: {
async *[Symbol.asyncIterator]() {
yield {
candidates: [
{
content: {
parts: [{ text: "Test Gemini response" }],
},
},
],
}
},
jest.mock("../gemini", () => {
const mockGeminiHandler = jest.fn()

mockGeminiHandler.prototype.createMessage = jest.fn().mockImplementation(async function* () {
const mockStream: ApiStreamChunk[] = [
{
type: "usage",
inputTokens: 10,
outputTokens: 0,
},
response: {
usageMetadata: {
promptTokenCount: 5,
candidatesTokenCount: 10,
},
{
type: "text",
text: "Gemini response part 1",
},
}
})

const mockGenerateContent = jest.fn().mockResolvedValue({
response: {
candidates: [
{
content: {
parts: [{ text: "Test Gemini response" }],
},
},
],
},
})
{
type: "text",
text: " part 2",
},
{
type: "usage",
inputTokens: 0,
outputTokens: 5,
},
]

const mockGenerativeModel = jest.fn().mockImplementation(() => {
return {
generateContentStream: mockGenerateContentStream,
generateContent: mockGenerateContent,
for (const chunk of mockStream) {
yield chunk
}
})

mockGeminiHandler.prototype.completePrompt = jest.fn().mockResolvedValue("Test Gemini response")

return {
VertexAI: jest.fn().mockImplementation(() => {
return {
getGenerativeModel: mockGenerativeModel,
}
}),
GenerativeModel: mockGenerativeModel,
GeminiHandler: mockGeminiHandler,
}
})

Expand Down Expand Up @@ -128,9 +110,11 @@ describe("VertexHandler", () => {
vertexRegion: "us-central1",
})

expect(VertexAI).toHaveBeenCalledWith({
project: "test-project",
location: "us-central1",
expect(GeminiHandler).toHaveBeenCalledWith({
isVertex: true,
apiModelId: "gemini-1.5-pro-001",
vertexProjectId: "test-project",
vertexRegion: "us-central1",
})
})

Expand Down Expand Up @@ -270,48 +254,48 @@ describe("VertexHandler", () => {
})

it("should handle streaming responses correctly for Gemini", async () => {
const mockGemini = require("@google-cloud/vertexai")
const mockGenerateContentStream = mockGemini.VertexAI().getGenerativeModel().generateContentStream
handler = new VertexHandler({
apiModelId: "gemini-1.5-pro-001",
vertexProjectId: "test-project",
vertexRegion: "us-central1",
})

const stream = handler.createMessage(systemPrompt, mockMessages)
const mockCacheKey = "cacheKey"
const mockGeminiHandlerInstance = (GeminiHandler as jest.Mock).mock.instances[0]

const stream = handler.createMessage(systemPrompt, mockMessages, mockCacheKey)

const chunks: ApiStreamChunk[] = []

for await (const chunk of stream) {
chunks.push(chunk)
}

expect(chunks.length).toBe(2)
expect(chunks.length).toBe(4)
expect(chunks[0]).toEqual({
type: "text",
text: "Test Gemini response",
type: "usage",
inputTokens: 10,
outputTokens: 0,
})
expect(chunks[1]).toEqual({
type: "text",
text: "Gemini response part 1",
})
expect(chunks[2]).toEqual({
type: "text",
text: " part 2",
})
expect(chunks[3]).toEqual({
type: "usage",
inputTokens: 5,
outputTokens: 10,
inputTokens: 0,
outputTokens: 5,
})

expect(mockGenerateContentStream).toHaveBeenCalledWith({
contents: [
{
role: "user",
parts: [{ text: "Hello" }],
},
{
role: "model",
parts: [{ text: "Hi there!" }],
},
],
generationConfig: {
maxOutputTokens: 8192,
temperature: 0,
},
})
expect(mockGeminiHandlerInstance.createMessage).toHaveBeenCalledWith(
systemPrompt,
mockMessages,
mockCacheKey,
)
})

it("should handle multiple content blocks with line breaks for Claude", async () => {
Expand Down Expand Up @@ -753,9 +737,6 @@ describe("VertexHandler", () => {
})

it("should complete prompt successfully for Gemini", async () => {
const mockGemini = require("@google-cloud/vertexai")
const mockGenerateContent = mockGemini.VertexAI().getGenerativeModel().generateContent

handler = new VertexHandler({
apiModelId: "gemini-1.5-pro-001",
vertexProjectId: "test-project",
Expand All @@ -764,13 +745,9 @@ describe("VertexHandler", () => {

const result = await handler.completePrompt("Test prompt")
expect(result).toBe("Test Gemini response")
expect(mockGenerateContent).toHaveBeenCalled()
expect(mockGenerateContent).toHaveBeenCalledWith({
contents: [{ role: "user", parts: [{ text: "Test prompt" }] }],
generationConfig: {
temperature: 0,
},
})

const mockGeminiHandlerInstance = (GeminiHandler as jest.Mock).mock.instances[0]
expect(mockGeminiHandlerInstance.completePrompt).toHaveBeenCalledWith("Test prompt")
})

it("should handle API errors for Claude", async () => {
Expand All @@ -790,17 +767,17 @@ describe("VertexHandler", () => {
})

it("should handle API errors for Gemini", async () => {
const mockGemini = require("@google-cloud/vertexai")
const mockGenerateContent = mockGemini.VertexAI().getGenerativeModel().generateContent
mockGenerateContent.mockRejectedValue(new Error("Vertex API error"))
const mockGeminiHandlerInstance = (GeminiHandler as jest.Mock).mock.instances[0]
mockGeminiHandlerInstance.completePrompt.mockRejectedValue(new Error("Vertex API error"))

handler = new VertexHandler({
apiModelId: "gemini-1.5-pro-001",
vertexProjectId: "test-project",
vertexRegion: "us-central1",
})

await expect(handler.completePrompt("Test prompt")).rejects.toThrow(
"Vertex completion error: Vertex API error",
"Vertex API error", // Expecting the raw error message from the mock
)
})

Expand Down Expand Up @@ -837,19 +814,9 @@ describe("VertexHandler", () => {
})

it("should handle empty response for Gemini", async () => {
const mockGemini = require("@google-cloud/vertexai")
const mockGenerateContent = mockGemini.VertexAI().getGenerativeModel().generateContent
mockGenerateContent.mockResolvedValue({
response: {
candidates: [
{
content: {
parts: [{ text: "" }],
},
},
],
},
})
const mockGeminiHandlerInstance = (GeminiHandler as jest.Mock).mock.instances[0]
mockGeminiHandlerInstance.completePrompt.mockResolvedValue("")

handler = new VertexHandler({
apiModelId: "gemini-1.5-pro-001",
vertexProjectId: "test-project",
Expand Down
72 changes: 69 additions & 3 deletions src/api/providers/gemini.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ import {
import NodeCache from "node-cache"

import { SingleCompletionHandler } from "../"
import type { ApiHandlerOptions, GeminiModelId, ModelInfo } from "../../shared/api"
import { geminiDefaultModelId, geminiModels } from "../../shared/api"
import type { ApiHandlerOptions, GeminiModelId, VertexModelId, ModelInfo } from "../../shared/api"
import { geminiDefaultModelId, geminiModels, vertexDefaultModelId, vertexModels } from "../../shared/api"
import {
convertAnthropicContentToGemini,
convertAnthropicMessageToGemini,
Expand Down Expand Up @@ -37,10 +37,43 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl
constructor(options: ApiHandlerOptions) {
super()
this.options = options
this.client = new GoogleGenAI({ apiKey: options.geminiApiKey ?? "not-provided" })

this.client = this.initializeClient()
this.contentCaches = new NodeCache({ stdTTL: 5 * 60, checkperiod: 5 * 60 })
}

/**
 * Builds the GoogleGenAI client for either the public Gemini API or
 * Vertex AI, depending on `options.isVertex`.
 *
 * Vertex credential resolution order:
 *   1. `vertexJsonCredentials` — inline JSON service-account credentials
 *   2. `vertexKeyFile`         — path to a service-account key file
 *   3. neither                 — Application Default Credentials
 *
 * @returns a configured GoogleGenAI client
 * @throws Error when `vertexJsonCredentials` is present but is not valid JSON
 */
private initializeClient(): GoogleGenAI {
	if (this.options.isVertex !== true) {
		// Plain Gemini API — authenticated with an API key only.
		return new GoogleGenAI({ apiKey: this.options.geminiApiKey ?? "not-provided" })
	}

	const project = this.options.vertexProjectId ?? "not-provided"
	const location = this.options.vertexRegion ?? "not-provided"

	if (this.options.vertexJsonCredentials) {
		let credentials: unknown
		try {
			credentials = JSON.parse(this.options.vertexJsonCredentials)
		} catch (error: unknown) {
			// Surface a descriptive error instead of a bare SyntaxError so
			// misconfigured inline credentials are easy to diagnose.
			throw new Error(
				`Invalid JSON in vertexJsonCredentials: ${error instanceof Error ? error.message : String(error)}`,
			)
		}

		return new GoogleGenAI({
			vertexai: true,
			project,
			location,
			googleAuthOptions: { credentials },
		})
	}

	if (this.options.vertexKeyFile) {
		return new GoogleGenAI({
			vertexai: true,
			project,
			location,
			googleAuthOptions: { keyFile: this.options.vertexKeyFile },
		})
	}

	// No explicit credentials — rely on Application Default Credentials
	// resolved by the environment (e.g. GOOGLE_APPLICATION_CREDENTIALS).
	return new GoogleGenAI({ vertexai: true, project, location })
}

async *createMessage(
systemInstruction: string,
messages: Anthropic.Messages.MessageParam[],
Expand Down Expand Up @@ -170,6 +203,10 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl
}

override getModel() {
if (this.options.isVertex === true) {
return this.getVertexModel()
}

let id = this.options.apiModelId ? (this.options.apiModelId as GeminiModelId) : geminiDefaultModelId
let info: ModelInfo = geminiModels[id]

Expand Down Expand Up @@ -198,6 +235,35 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl
return { id, info }
}

/**
 * Resolves the Vertex AI model id and its metadata from the configured
 * `apiModelId`, falling back to the Vertex default model when the id is
 * absent or unknown.
 *
 * Ids ending in ":thinking" resolve to their base model and additionally
 * carry a thinking budget and a max-output-token setting when configured.
 */
private getVertexModel() {
	const thinkingSuffix = ":thinking"
	let modelId = this.options.apiModelId ? (this.options.apiModelId as VertexModelId) : vertexDefaultModelId
	let modelInfo: ModelInfo = vertexModels[modelId]

	if (modelId?.endsWith(thinkingSuffix)) {
		const baseId = modelId.slice(0, -thinkingSuffix.length) as VertexModelId
		const baseInfo = vertexModels[baseId]

		// Only the base model carries metadata; if it exists, return it
		// together with the optional thinking/output-token configuration.
		if (baseInfo) {
			return {
				id: baseId,
				info: baseInfo,
				thinkingConfig: this.options.modelMaxThinkingTokens
					? { thinkingBudget: this.options.modelMaxThinkingTokens }
					: undefined,
				maxOutputTokens: this.options.modelMaxTokens ?? baseInfo.maxTokens ?? undefined,
			}
		}
	}

	// Unknown or unresolved id — fall back to the Vertex default model.
	if (!modelInfo) {
		modelId = vertexDefaultModelId
		modelInfo = vertexModels[vertexDefaultModelId]
	}

	return { id: modelId, info: modelInfo }
}

async completePrompt(prompt: string): Promise<string> {
try {
const { id: model } = this.getModel()
Expand Down
Loading