
Commit 96b7921

feat: vertex/gemini prompt caching
Moves to @google/genai and reuses the Gemini provider when using Gemini on Vertex AI.
1 parent: 1924e10
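
In outline, the change makes VertexHandler delegate Gemini models to the existing GeminiHandler instead of talking to @google-cloud/vertexai directly. A minimal sketch of that delegation, inferred from the test expectations below; the field name and the model-id check are assumptions, not the commit's exact code:

import { Anthropic } from "@anthropic-ai/sdk"
import { GeminiHandler } from "./gemini"
import type { ApiHandlerOptions } from "../../shared/api"

export class VertexHandler {
	private geminiHandler?: GeminiHandler // hypothetical field name

	constructor(private options: ApiHandlerOptions) {
		// Assumption: Gemini model ids are detected by prefix.
		if (options.apiModelId?.startsWith("gemini")) {
			this.geminiHandler = new GeminiHandler({
				isVertex: true,
				apiModelId: options.apiModelId,
				vertexProjectId: options.vertexProjectId,
				vertexRegion: options.vertexRegion,
			})
		}
		// Otherwise keep the existing @anthropic-ai/vertex-sdk path for Claude.
	}

	createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[], cacheKey?: string) {
		// The stream and the prompt-cache key are passed straight through.
		return this.geminiHandler!.createMessage(systemPrompt, messages, cacheKey)
	}

	completePrompt(prompt: string) {
		return this.geminiHandler!.completePrompt(prompt)
	}
}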

File tree: 12 files changed (+174, -636 lines)


.changeset/curly-frogs-pull.md

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+---
+"roo-cline": patch
+---
+
+Use gemini provider when using Gemini on vertex ai

package-lock.json

Lines changed: 0 additions & 13 deletions
Some generated files are not rendered by default.

package.json

Lines changed: 0 additions & 1 deletion
@@ -404,7 +404,6 @@
 		"@anthropic-ai/sdk": "^0.37.0",
 		"@anthropic-ai/vertex-sdk": "^0.7.0",
 		"@aws-sdk/client-bedrock-runtime": "^3.779.0",
-		"@google-cloud/vertexai": "^1.9.3",
 		"@google/genai": "^0.9.0",
 		"@mistralai/mistralai": "^1.3.6",
 		"@modelcontextprotocol/sdk": "^1.7.0",

src/api/providers/__tests__/vertex.test.ts

Lines changed: 67 additions & 100 deletions
@@ -5,7 +5,7 @@ import { AnthropicVertex } from "@anthropic-ai/vertex-sdk"
 
 import { VertexHandler } from "../vertex"
 import { ApiStreamChunk } from "../../transform/stream"
-import { VertexAI } from "@google-cloud/vertexai"
+import { GeminiHandler } from "../gemini"
 
 // Mock Vertex SDK
 jest.mock("@anthropic-ai/vertex-sdk", () => ({
@@ -49,58 +49,40 @@ jest.mock("@anthropic-ai/vertex-sdk", () => ({
 	})),
 }))
 
-// Mock Vertex Gemini SDK
-jest.mock("@google-cloud/vertexai", () => {
-	const mockGenerateContentStream = jest.fn().mockImplementation(() => {
-		return {
-			stream: {
-				async *[Symbol.asyncIterator]() {
-					yield {
-						candidates: [
-							{
-								content: {
-									parts: [{ text: "Test Gemini response" }],
-								},
-							},
-						],
-					}
-				},
+jest.mock("../gemini", () => {
+	const mockGeminiHandler = jest.fn()
+
+	mockGeminiHandler.prototype.createMessage = jest.fn().mockImplementation(async function* () {
+		const mockStream: ApiStreamChunk[] = [
+			{
+				type: "usage",
+				inputTokens: 10,
+				outputTokens: 0,
 			},
-			response: {
-				usageMetadata: {
-					promptTokenCount: 5,
-					candidatesTokenCount: 10,
-				},
+			{
+				type: "text",
+				text: "Gemini response part 1",
 			},
-		}
-	})
-
-	const mockGenerateContent = jest.fn().mockResolvedValue({
-		response: {
-			candidates: [
-				{
-					content: {
-						parts: [{ text: "Test Gemini response" }],
-					},
-				},
-			],
-		},
-	})
+			{
+				type: "text",
+				text: " part 2",
+			},
+			{
+				type: "usage",
+				inputTokens: 0,
+				outputTokens: 5,
+			},
+		]
 
-	const mockGenerativeModel = jest.fn().mockImplementation(() => {
-		return {
-			generateContentStream: mockGenerateContentStream,
-			generateContent: mockGenerateContent,
+		for (const chunk of mockStream) {
+			yield chunk
 		}
 	})
 
+	mockGeminiHandler.prototype.completePrompt = jest.fn().mockResolvedValue("Test Gemini response")
+
 	return {
-		VertexAI: jest.fn().mockImplementation(() => {
-			return {
-				getGenerativeModel: mockGenerativeModel,
-			}
-		}),
-		GenerativeModel: mockGenerativeModel,
+		GeminiHandler: mockGeminiHandler,
 	}
 })
 
@@ -128,9 +110,11 @@ describe("VertexHandler", () => {
 				vertexRegion: "us-central1",
 			})
 
-			expect(VertexAI).toHaveBeenCalledWith({
-				project: "test-project",
-				location: "us-central1",
+			expect(GeminiHandler).toHaveBeenCalledWith({
+				isVertex: true,
+				apiModelId: "gemini-1.5-pro-001",
+				vertexProjectId: "test-project",
+				vertexRegion: "us-central1",
 			})
 		})
 
@@ -270,48 +254,48 @@ describe("VertexHandler", () => {
 		})
 
 		it("should handle streaming responses correctly for Gemini", async () => {
-			const mockGemini = require("@google-cloud/vertexai")
-			const mockGenerateContentStream = mockGemini.VertexAI().getGenerativeModel().generateContentStream
 			handler = new VertexHandler({
 				apiModelId: "gemini-1.5-pro-001",
 				vertexProjectId: "test-project",
 				vertexRegion: "us-central1",
 			})
 
-			const stream = handler.createMessage(systemPrompt, mockMessages)
+			const mockCacheKey = "cacheKey"
+			const mockGeminiHandlerInstance = (GeminiHandler as jest.Mock).mock.instances[0]
+
+			const stream = handler.createMessage(systemPrompt, mockMessages, mockCacheKey)
+
 			const chunks: ApiStreamChunk[] = []
 
 			for await (const chunk of stream) {
 				chunks.push(chunk)
 			}
 
-			expect(chunks.length).toBe(2)
+			expect(chunks.length).toBe(4)
 			expect(chunks[0]).toEqual({
-				type: "text",
-				text: "Test Gemini response",
+				type: "usage",
+				inputTokens: 10,
+				outputTokens: 0,
 			})
 			expect(chunks[1]).toEqual({
+				type: "text",
+				text: "Gemini response part 1",
+			})
+			expect(chunks[2]).toEqual({
+				type: "text",
+				text: " part 2",
+			})
+			expect(chunks[3]).toEqual({
 				type: "usage",
-				inputTokens: 5,
-				outputTokens: 10,
+				inputTokens: 0,
+				outputTokens: 5,
 			})
 
-			expect(mockGenerateContentStream).toHaveBeenCalledWith({
-				contents: [
-					{
-						role: "user",
-						parts: [{ text: "Hello" }],
-					},
-					{
-						role: "model",
-						parts: [{ text: "Hi there!" }],
-					},
-				],
-				generationConfig: {
-					maxOutputTokens: 8192,
-					temperature: 0,
-				},
-			})
+			expect(mockGeminiHandlerInstance.createMessage).toHaveBeenCalledWith(
+				systemPrompt,
+				mockMessages,
+				mockCacheKey,
+			)
 		})
 
 		it("should handle multiple content blocks with line breaks for Claude", async () => {
@@ -753,9 +737,6 @@ describe("VertexHandler", () => {
 		})
 
 		it("should complete prompt successfully for Gemini", async () => {
-			const mockGemini = require("@google-cloud/vertexai")
-			const mockGenerateContent = mockGemini.VertexAI().getGenerativeModel().generateContent
-
 			handler = new VertexHandler({
 				apiModelId: "gemini-1.5-pro-001",
 				vertexProjectId: "test-project",
@@ -764,13 +745,9 @@ describe("VertexHandler", () => {
 
 			const result = await handler.completePrompt("Test prompt")
 			expect(result).toBe("Test Gemini response")
-			expect(mockGenerateContent).toHaveBeenCalled()
-			expect(mockGenerateContent).toHaveBeenCalledWith({
-				contents: [{ role: "user", parts: [{ text: "Test prompt" }] }],
-				generationConfig: {
-					temperature: 0,
-				},
-			})
+
+			const mockGeminiHandlerInstance = (GeminiHandler as jest.Mock).mock.instances[0]
+			expect(mockGeminiHandlerInstance.completePrompt).toHaveBeenCalledWith("Test prompt")
 		})
 
 		it("should handle API errors for Claude", async () => {
@@ -790,17 +767,17 @@ describe("VertexHandler", () => {
 		})
 
 		it("should handle API errors for Gemini", async () => {
-			const mockGemini = require("@google-cloud/vertexai")
-			const mockGenerateContent = mockGemini.VertexAI().getGenerativeModel().generateContent
-			mockGenerateContent.mockRejectedValue(new Error("Vertex API error"))
+			const mockGeminiHandlerInstance = (GeminiHandler as jest.Mock).mock.instances[0]
+			mockGeminiHandlerInstance.completePrompt.mockRejectedValue(new Error("Vertex API error"))
+
 			handler = new VertexHandler({
 				apiModelId: "gemini-1.5-pro-001",
 				vertexProjectId: "test-project",
 				vertexRegion: "us-central1",
 			})
 
 			await expect(handler.completePrompt("Test prompt")).rejects.toThrow(
-				"Vertex completion error: Vertex API error",
+				"Vertex API error", // Expecting the raw error message from the mock
 			)
 		})
 
@@ -837,19 +814,9 @@ describe("VertexHandler", () => {
 		})
 
 		it("should handle empty response for Gemini", async () => {
-			const mockGemini = require("@google-cloud/vertexai")
-			const mockGenerateContent = mockGemini.VertexAI().getGenerativeModel().generateContent
-			mockGenerateContent.mockResolvedValue({
-				response: {
-					candidates: [
-						{
-							content: {
-								parts: [{ text: "" }],
-							},
-						},
-					],
-				},
-			})
+			const mockGeminiHandlerInstance = (GeminiHandler as jest.Mock).mock.instances[0]
+			mockGeminiHandlerInstance.completePrompt.mockResolvedValue("")
+
 			handler = new VertexHandler({
 				apiModelId: "gemini-1.5-pro-001",
 				vertexProjectId: "test-project",
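
These tests lean on Jest's constructor bookkeeping: because the mocked GeminiHandler is a jest.fn() used as a constructor, every `new GeminiHandler(...)` performed inside VertexHandler is recorded, and `(GeminiHandler as jest.Mock).mock.instances[0]` is the very instance the handler created. A standalone sketch of the pattern; the module "./foo" and class Foo are hypothetical:

jest.mock("./foo", () => {
	const MockFoo = jest.fn()
	// Methods live on the prototype, so every constructed instance shares them.
	MockFoo.prototype.greet = jest.fn().mockReturnValue("hello")
	return { Foo: MockFoo }
})

import { Foo } from "./foo" // jest.mock is hoisted above this import

it("exposes internally constructed instances", () => {
	const foo = new Foo() // stands in for the `new GeminiHandler(...)` inside VertexHandler
	const instance = (Foo as unknown as jest.Mock).mock.instances[0]

	expect(instance).toBe(foo)
	expect(instance.greet()).toBe("hello")
})

One caveat this implies: overriding a method through one instance (as the error and empty-response tests do with completePrompt) mutates the shared prototype mock, so the suite presumably resets mocks between tests.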

src/api/providers/gemini.ts

Lines changed: 69 additions & 3 deletions
@@ -8,8 +8,8 @@ import {
 import NodeCache from "node-cache"
 
 import { SingleCompletionHandler } from "../"
-import type { ApiHandlerOptions, GeminiModelId, ModelInfo } from "../../shared/api"
-import { geminiDefaultModelId, geminiModels } from "../../shared/api"
+import type { ApiHandlerOptions, GeminiModelId, VertexModelId, ModelInfo } from "../../shared/api"
+import { geminiDefaultModelId, geminiModels, vertexDefaultModelId, vertexModels } from "../../shared/api"
 import {
 	convertAnthropicContentToGemini,
 	convertAnthropicMessageToGemini,
@@ -37,10 +37,43 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl
 	constructor(options: ApiHandlerOptions) {
 		super()
 		this.options = options
-		this.client = new GoogleGenAI({ apiKey: options.geminiApiKey ?? "not-provided" })
+
+		this.client = this.initializeClient()
 		this.contentCaches = new NodeCache({ stdTTL: 5 * 60, checkperiod: 5 * 60 })
 	}
 
+	private initializeClient(): GoogleGenAI {
+		if (this.options.isVertex !== true) {
+			return new GoogleGenAI({ apiKey: this.options.geminiApiKey ?? "not-provided" })
+		}
+
+		if (this.options.vertexJsonCredentials) {
+			return new GoogleGenAI({
+				vertexai: true,
+				project: this.options.vertexProjectId ?? "not-provided",
+				location: this.options.vertexRegion ?? "not-provided",
+				googleAuthOptions: {
+					credentials: JSON.parse(this.options.vertexJsonCredentials),
+				},
+			})
+		} else if (this.options.vertexKeyFile) {
+			return new GoogleGenAI({
+				vertexai: true,
+				project: this.options.vertexProjectId ?? "not-provided",
+				location: this.options.vertexRegion ?? "not-provided",
+				googleAuthOptions: {
+					keyFile: this.options.vertexKeyFile,
+				},
+			})
+		} else {
+			return new GoogleGenAI({
+				vertexai: true,
+				project: this.options.vertexProjectId ?? "not-provided",
+				location: this.options.vertexRegion ?? "not-provided",
+			})
+		}
+	}
+
 	async *createMessage(
 		systemInstruction: string,
 		messages: Anthropic.Messages.MessageParam[],
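
initializeClient picks Vertex credentials in priority order: inline JSON credentials first, then a key file path, then no explicit credentials, in which case @google/genai falls back to Application Default Credentials. A usage sketch; project, region, key, and model ids are placeholder values:

// Plain Gemini API: only an API key.
const gemini = new GeminiHandler({
	apiModelId: "gemini-2.0-flash-001",
	geminiApiKey: "AIza...", // placeholder
})

// Vertex AI with inline service-account JSON (highest priority).
const vertexJson = new GeminiHandler({
	isVertex: true,
	apiModelId: "gemini-1.5-pro-001",
	vertexProjectId: "my-project",
	vertexRegion: "us-central1",
	vertexJsonCredentials: JSON.stringify({ client_email: "...", private_key: "..." }),
})

// Vertex AI with no explicit credentials: resolved via Application Default
// Credentials (e.g. after `gcloud auth application-default login`).
const vertexAdc = new GeminiHandler({
	isVertex: true,
	apiModelId: "gemini-1.5-pro-001",
	vertexProjectId: "my-project",
	vertexRegion: "us-central1",
})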
@@ -170,6 +203,10 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl
 	}
 
 	override getModel() {
+		if (this.options.isVertex === true) {
+			return this.getVertexModel()
+		}
+
 		let id = this.options.apiModelId ? (this.options.apiModelId as GeminiModelId) : geminiDefaultModelId
 		let info: ModelInfo = geminiModels[id]
 
@@ -198,6 +235,35 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl
 		return { id, info }
 	}
 
+	private getVertexModel() {
+		let id = this.options.apiModelId ? (this.options.apiModelId as VertexModelId) : vertexDefaultModelId
+		let info: ModelInfo = vertexModels[id]
+
+		if (id?.endsWith(":thinking")) {
+			id = id.slice(0, -":thinking".length) as VertexModelId
+
+			if (vertexModels[id]) {
+				info = vertexModels[id]
+
+				return {
+					id,
+					info,
+					thinkingConfig: this.options.modelMaxThinkingTokens
+						? { thinkingBudget: this.options.modelMaxThinkingTokens }
+						: undefined,
+					maxOutputTokens: this.options.modelMaxTokens ?? info.maxTokens ?? undefined,
+				}
+			}
+		}
+
+		if (!info) {
+			id = vertexDefaultModelId
+			info = vertexModels[vertexDefaultModelId]
+		}
+
+		return { id, info }
+	}
+
 	async completePrompt(prompt: string): Promise<string> {
 		try {
 			const { id: model } = this.getModel()
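
getVertexModel mirrors the Gemini-side model lookup but reads from vertexModels, and it understands the `:thinking` id suffix: the suffix is stripped, and if the base id exists, the returned model carries a thinkingConfig budget plus a maxOutputTokens override. A short walkthrough; the model id here is hypothetical:

const handler = new GeminiHandler({
	isVertex: true,
	apiModelId: "gemini-2.5-flash-preview:thinking", // hypothetical id
	vertexProjectId: "my-project",
	vertexRegion: "us-central1",
	modelMaxThinkingTokens: 4096,
})

// getModel() routes to getVertexModel(), which:
//   1. strips ":thinking" -> "gemini-2.5-flash-preview"
//   2. if vertexModels has that base id, returns { id, info,
//      thinkingConfig: { thinkingBudget: 4096 },
//      maxOutputTokens: modelMaxTokens ?? info.maxTokens }
//   3. otherwise falls back to vertexDefaultModelId
const model = handler.getModel()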
