302 changes: 302 additions & 0 deletions src/api/providers/__tests__/vertex-gemini-urlcontext.spec.ts
@@ -0,0 +1,302 @@
// npx vitest run src/api/providers/__tests__/vertex-gemini-urlcontext.spec.ts
Contributor Author

The test file comment could be more descriptive about what this test suite validates. Consider:

Suggested change
- // npx vitest run src/api/providers/__tests__/vertex-gemini-urlcontext.spec.ts
+ // Tests for ensuring urlContext parameter is correctly handled between Gemini and Vertex AI providers
+ // npx vitest run src/api/providers/__tests__/vertex-gemini-urlcontext.spec.ts


import { Anthropic } from "@anthropic-ai/sdk"
import { GeminiHandler } from "../gemini"
import { VertexHandler } from "../vertex"

describe("Vertex vs Gemini urlContext handling", () => {
describe("GeminiHandler", () => {
it("should include urlContext tool when enableUrlContext is true", async () => {
const mockGenerateContentStream = vitest.fn()

const handler = new GeminiHandler({
geminiApiKey: "test-key",
apiModelId: "gemini-1.5-flash",
enableUrlContext: true,
})

// Replace the client with our mock
handler["client"] = {
Contributor Author

I notice we're accessing private properties using bracket notation here and throughout the tests. While this works, is this the best approach for test isolation? We might want to consider:

  1. Using a proper mocking library that can handle private properties
  2. Exposing a test-friendly method
  3. Making the client property protected instead of private (see the sketch below)

What do you think would be the cleanest approach?
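
For illustration, option 3 might look roughly like the sketch below. This is only a sketch: it assumes `client` were changed from private to protected, and that the test imports the `GoogleGenAI` type the handler already uses; the `TestableGeminiHandler` subclass and `setClient` helper are hypothetical names, not part of this PR.

	class TestableGeminiHandler extends GeminiHandler {
		setClient(client: GoogleGenAI) {
			this.client = client // allowed only if `client` is protected rather than private
		}
	}

	const handler = new TestableGeminiHandler({
		geminiApiKey: "test-key",
		apiModelId: "gemini-1.5-flash",
		enableUrlContext: true,
	})

	handler.setClient({
		models: { generateContentStream: mockGenerateContentStream },
	} as unknown as GoogleGenAI)

That would keep bracket-notation access to private members out of the tests without exposing the client on the production API.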

models: {
generateContentStream: mockGenerateContentStream,
},
} as any

// Setup mock to return an async generator
mockGenerateContentStream.mockResolvedValue({
[Symbol.asyncIterator]: async function* () {
yield { text: "Test response" }
yield { usageMetadata: { promptTokenCount: 10, candidatesTokenCount: 5 } }
},
})

const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Hello" }]

const stream = handler.createMessage("System prompt", messages)
const chunks = []
for await (const chunk of stream) {
chunks.push(chunk)
}

// Verify that generateContentStream was called with urlContext in tools
expect(mockGenerateContentStream).toHaveBeenCalledWith(
expect.objectContaining({
config: expect.objectContaining({
tools: expect.arrayContaining([{ urlContext: {} }]),
}),
}),
)
})

it("should not include urlContext tool when enableUrlContext is false", async () => {
const mockGenerateContentStream = vitest.fn()

const handler = new GeminiHandler({
geminiApiKey: "test-key",
apiModelId: "gemini-1.5-flash",
enableUrlContext: false,
})

// Replace the client with our mock
handler["client"] = {
models: {
generateContentStream: mockGenerateContentStream,
},
} as any

// Setup mock to return an async generator
mockGenerateContentStream.mockResolvedValue({
[Symbol.asyncIterator]: async function* () {
yield { text: "Test response" }
yield { usageMetadata: { promptTokenCount: 10, candidatesTokenCount: 5 } }
},
})

const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Hello" }]

const stream = handler.createMessage("System prompt", messages)
const chunks = []
for await (const chunk of stream) {
chunks.push(chunk)
}

// Verify that generateContentStream was called without urlContext in tools
expect(mockGenerateContentStream).toHaveBeenCalledWith(
expect.objectContaining({
config: expect.not.objectContaining({
tools: expect.anything(),
}),
}),
)
})

it("should include urlContext in completePrompt when enableUrlContext is true", async () => {
const mockGenerateContent = vitest.fn()

const handler = new GeminiHandler({
geminiApiKey: "test-key",
apiModelId: "gemini-1.5-flash",
enableUrlContext: true,
})

// Replace the client with our mock
handler["client"] = {
models: {
generateContent: mockGenerateContent,
},
} as any

// Mock the response
mockGenerateContent.mockResolvedValue({
text: "Test response",
})

await handler.completePrompt("Test prompt")

// Verify that generateContent was called with urlContext in tools
expect(mockGenerateContent).toHaveBeenCalledWith(
expect.objectContaining({
config: expect.objectContaining({
tools: expect.arrayContaining([{ urlContext: {} }]),
}),
}),
)
})
})

describe("VertexHandler", () => {
it("should NOT include urlContext tool even when enableUrlContext is true", async () => {
const mockGenerateContentStream = vitest.fn()

const handler = new VertexHandler({
vertexProjectId: "test-project",
vertexRegion: "us-central1",
apiModelId: "gemini-1.5-pro-001",
enableUrlContext: true, // This should be ignored for Vertex
})

// Replace the client with our mock
handler["client"] = {
models: {
generateContentStream: mockGenerateContentStream,
},
} as any

// Setup mock to return an async generator
mockGenerateContentStream.mockResolvedValue({
[Symbol.asyncIterator]: async function* () {
yield { text: "Test response" }
yield { usageMetadata: { promptTokenCount: 10, candidatesTokenCount: 5 } }
},
})

const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Hello" }]

const stream = handler.createMessage("System prompt", messages)
const chunks = []
for await (const chunk of stream) {
chunks.push(chunk)
}

// Verify that generateContentStream was called WITHOUT urlContext in tools
// even though enableUrlContext was true
const callArgs = mockGenerateContentStream.mock.calls[0][0]
if (callArgs.config.tools) {
// If tools array exists, it should not contain urlContext
expect(callArgs.config.tools).not.toContainEqual({ urlContext: {} })
}
})

it("should NOT include urlContext in completePrompt even when enableUrlContext is true", async () => {
const mockGenerateContent = vitest.fn()

const handler = new VertexHandler({
vertexProjectId: "test-project",
vertexRegion: "us-central1",
apiModelId: "gemini-1.5-pro-001",
enableUrlContext: true, // This should be ignored for Vertex
})

// Replace the client with our mock
handler["client"] = {
models: {
generateContent: mockGenerateContent,
},
} as any

// Mock the response
mockGenerateContent.mockResolvedValue({
text: "Test response",
})

await handler.completePrompt("Test prompt")

// Verify that generateContent was called WITHOUT urlContext in tools
const callArgs = mockGenerateContent.mock.calls[0][0]
if (callArgs.config.tools) {
// If tools array exists, it should not contain urlContext
expect(callArgs.config.tools).not.toContainEqual({ urlContext: {} })
}
})

it("should still include googleSearch tool when enableGrounding is true", async () => {
const mockGenerateContentStream = vitest.fn()

const handler = new VertexHandler({
vertexProjectId: "test-project",
vertexRegion: "us-central1",
apiModelId: "gemini-1.5-pro-001",
enableUrlContext: true, // Should be ignored
enableGrounding: true, // Should be respected
})

// Replace the client with our mock
handler["client"] = {
models: {
generateContentStream: mockGenerateContentStream,
},
} as any

// Setup mock to return an async generator
mockGenerateContentStream.mockResolvedValue({
[Symbol.asyncIterator]: async function* () {
yield { text: "Test response" }
yield { usageMetadata: { promptTokenCount: 10, candidatesTokenCount: 5 } }
},
})

const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Hello" }]

const stream = handler.createMessage("System prompt", messages)
const chunks = []
for await (const chunk of stream) {
chunks.push(chunk)
}

// Verify that googleSearch is included but urlContext is not
const callArgs = mockGenerateContentStream.mock.calls[0][0]
expect(callArgs.config.tools).toContainEqual({ googleSearch: {} })
expect(callArgs.config.tools).not.toContainEqual({ urlContext: {} })
})
})

describe("Integration test - switching between providers", () => {
it("should correctly handle urlContext based on provider type", async () => {
const mockGenerateContentStream = vitest.fn()

// Setup mock to return an async generator
mockGenerateContentStream.mockResolvedValue({
[Symbol.asyncIterator]: async function* () {
yield { text: "Test response" }
yield { usageMetadata: { promptTokenCount: 10, candidatesTokenCount: 5 } }
},
})

const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Hello" }]

// Test with Gemini handler
const geminiHandler = new GeminiHandler({
geminiApiKey: "test-key",
apiModelId: "gemini-1.5-flash",
enableUrlContext: true,
})
geminiHandler["client"] = {
models: { generateContentStream: mockGenerateContentStream },
} as any

const geminiStream = geminiHandler.createMessage("System prompt", messages)
for await (const chunk of geminiStream) {
// Consume stream
}

// Verify Gemini includes urlContext
const geminiCall = mockGenerateContentStream.mock.calls[mockGenerateContentStream.mock.calls.length - 1][0]
expect(geminiCall.config.tools).toContainEqual({ urlContext: {} })

// Clear mock calls
mockGenerateContentStream.mockClear()

// Test with Vertex handler using same options
const vertexHandler = new VertexHandler({
vertexProjectId: "test-project",
vertexRegion: "us-central1",
apiModelId: "gemini-1.5-pro-001",
enableUrlContext: true, // Same setting, but should be ignored
})
vertexHandler["client"] = {
models: { generateContentStream: mockGenerateContentStream },
} as any

const vertexStream = vertexHandler.createMessage("System prompt", messages)
for await (const chunk of vertexStream) {
// Consume stream
}

// Verify Vertex does NOT include urlContext
const vertexCall = mockGenerateContentStream.mock.calls[0][0]
if (vertexCall.config.tools) {
expect(vertexCall.config.tools).not.toContainEqual({ urlContext: {} })
}
})
})
})
8 changes: 6 additions & 2 deletions src/api/providers/gemini.ts
@@ -27,13 +27,15 @@ type GeminiHandlerOptions = ApiHandlerOptions & {

export class GeminiHandler extends BaseProvider implements SingleCompletionHandler {
protected options: ApiHandlerOptions
private isVertex: boolean
Contributor Author

Consider adding a comment here explaining why this differentiation is necessary:

Suggested change
- private isVertex: boolean
+ // Track whether this instance is for Vertex AI to prevent unsupported parameters like urlContext
+ private isVertex: boolean

This would help future maintainers understand the purpose of this flag.


private client: GoogleGenAI

constructor({ isVertex, ...options }: GeminiHandlerOptions) {
super()

this.options = options
this.isVertex = isVertex ?? false
Contributor Author

Is the nullish coalescing operator necessary here? Since the parameter is already optional, you could make the default explicit in the constructor signature instead:

Suggested change
- this.isVertex = isVertex ?? false
+ constructor({ isVertex = false, ...options }: GeminiHandlerOptions) {

Then just use:

Suggested change
- this.isVertex = isVertex ?? false
+ this.isVertex = isVertex
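
Taken together, the two suggestions would make the constructor read roughly like this (a sketch only, mirroring the lines already shown in this diff):

	constructor({ isVertex = false, ...options }: GeminiHandlerOptions) {
		super()

		this.options = options
		this.isVertex = isVertex

		const project = this.options.vertexProjectId ?? "not-provided"
		const location = this.options.vertexRegion ?? "not-provided"
		// ...
	}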


const project = this.options.vertexProjectId ?? "not-provided"
const location = this.options.vertexRegion ?? "not-provided"
@@ -70,7 +72,8 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandler {
const contents = messages.map(convertAnthropicMessageToGemini)

const tools: GenerateContentConfig["tools"] = []
if (this.options.enableUrlContext) {
// urlContext is only supported in regular Gemini, not Vertex AI
if (this.options.enableUrlContext && !this.isVertex) {
tools.push({ urlContext: {} })
}

@@ -214,7 +217,8 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandler {
const { id: model } = this.getModel()

const tools: GenerateContentConfig["tools"] = []
if (this.options.enableUrlContext) {
// urlContext is only supported in regular Gemini, not Vertex AI
if (this.options.enableUrlContext && !this.isVertex) {
tools.push({ urlContext: {} })
}
if (this.options.enableGrounding) {