-
Notifications
You must be signed in to change notification settings - Fork 2.5k
fix: prevent urlContext from being added to Vertex AI requests #7969
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,302 @@ | ||
| // npx vitest run src/api/providers/__tests__/vertex-gemini-urlcontext.spec.ts | ||
|
|
||
| import { Anthropic } from "@anthropic-ai/sdk" | ||
| import { GeminiHandler } from "../gemini" | ||
| import { VertexHandler } from "../vertex" | ||
|
|
||
| describe("Vertex vs Gemini urlContext handling", () => { | ||
| describe("GeminiHandler", () => { | ||
| it("should include urlContext tool when enableUrlContext is true", async () => { | ||
| const mockGenerateContentStream = vitest.fn() | ||
|
|
||
| const handler = new GeminiHandler({ | ||
| geminiApiKey: "test-key", | ||
| apiModelId: "gemini-1.5-flash", | ||
| enableUrlContext: true, | ||
| }) | ||
|
|
||
| // Replace the client with our mock | ||
| handler["client"] = { | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I notice we're accessing private properties using bracket notation here and throughout the tests. While this works, is this the best approach for test isolation? We might want to consider:
What do you think would be the cleanest approach? |
||
| models: { | ||
| generateContentStream: mockGenerateContentStream, | ||
| }, | ||
| } as any | ||
|
|
||
| // Setup mock to return an async generator | ||
| mockGenerateContentStream.mockResolvedValue({ | ||
| [Symbol.asyncIterator]: async function* () { | ||
| yield { text: "Test response" } | ||
| yield { usageMetadata: { promptTokenCount: 10, candidatesTokenCount: 5 } } | ||
| }, | ||
| }) | ||
|
|
||
| const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Hello" }] | ||
|
|
||
| const stream = handler.createMessage("System prompt", messages) | ||
| const chunks = [] | ||
| for await (const chunk of stream) { | ||
| chunks.push(chunk) | ||
| } | ||
|
|
||
| // Verify that generateContentStream was called with urlContext in tools | ||
| expect(mockGenerateContentStream).toHaveBeenCalledWith( | ||
| expect.objectContaining({ | ||
| config: expect.objectContaining({ | ||
| tools: expect.arrayContaining([{ urlContext: {} }]), | ||
| }), | ||
| }), | ||
| ) | ||
| }) | ||
|
|
||
| it("should not include urlContext tool when enableUrlContext is false", async () => { | ||
| const mockGenerateContentStream = vitest.fn() | ||
|
|
||
| const handler = new GeminiHandler({ | ||
| geminiApiKey: "test-key", | ||
| apiModelId: "gemini-1.5-flash", | ||
| enableUrlContext: false, | ||
| }) | ||
|
|
||
| // Replace the client with our mock | ||
| handler["client"] = { | ||
| models: { | ||
| generateContentStream: mockGenerateContentStream, | ||
| }, | ||
| } as any | ||
|
|
||
| // Setup mock to return an async generator | ||
| mockGenerateContentStream.mockResolvedValue({ | ||
| [Symbol.asyncIterator]: async function* () { | ||
| yield { text: "Test response" } | ||
| yield { usageMetadata: { promptTokenCount: 10, candidatesTokenCount: 5 } } | ||
| }, | ||
| }) | ||
|
|
||
| const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Hello" }] | ||
|
|
||
| const stream = handler.createMessage("System prompt", messages) | ||
| const chunks = [] | ||
| for await (const chunk of stream) { | ||
| chunks.push(chunk) | ||
| } | ||
|
|
||
| // Verify that generateContentStream was called without urlContext in tools | ||
| expect(mockGenerateContentStream).toHaveBeenCalledWith( | ||
| expect.objectContaining({ | ||
| config: expect.not.objectContaining({ | ||
| tools: expect.anything(), | ||
| }), | ||
| }), | ||
| ) | ||
| }) | ||
|
|
||
| it("should include urlContext in completePrompt when enableUrlContext is true", async () => { | ||
| const mockGenerateContent = vitest.fn() | ||
|
|
||
| const handler = new GeminiHandler({ | ||
| geminiApiKey: "test-key", | ||
| apiModelId: "gemini-1.5-flash", | ||
| enableUrlContext: true, | ||
| }) | ||
|
|
||
| // Replace the client with our mock | ||
| handler["client"] = { | ||
| models: { | ||
| generateContent: mockGenerateContent, | ||
| }, | ||
| } as any | ||
|
|
||
| // Mock the response | ||
| mockGenerateContent.mockResolvedValue({ | ||
| text: "Test response", | ||
| }) | ||
|
|
||
| await handler.completePrompt("Test prompt") | ||
|
|
||
| // Verify that generateContent was called with urlContext in tools | ||
| expect(mockGenerateContent).toHaveBeenCalledWith( | ||
| expect.objectContaining({ | ||
| config: expect.objectContaining({ | ||
| tools: expect.arrayContaining([{ urlContext: {} }]), | ||
| }), | ||
| }), | ||
| ) | ||
| }) | ||
| }) | ||
|
|
||
| describe("VertexHandler", () => { | ||
| it("should NOT include urlContext tool even when enableUrlContext is true", async () => { | ||
| const mockGenerateContentStream = vitest.fn() | ||
|
|
||
| const handler = new VertexHandler({ | ||
| vertexProjectId: "test-project", | ||
| vertexRegion: "us-central1", | ||
| apiModelId: "gemini-1.5-pro-001", | ||
| enableUrlContext: true, // This should be ignored for Vertex | ||
| }) | ||
|
|
||
| // Replace the client with our mock | ||
| handler["client"] = { | ||
| models: { | ||
| generateContentStream: mockGenerateContentStream, | ||
| }, | ||
| } as any | ||
|
|
||
| // Setup mock to return an async generator | ||
| mockGenerateContentStream.mockResolvedValue({ | ||
| [Symbol.asyncIterator]: async function* () { | ||
| yield { text: "Test response" } | ||
| yield { usageMetadata: { promptTokenCount: 10, candidatesTokenCount: 5 } } | ||
| }, | ||
| }) | ||
|
|
||
| const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Hello" }] | ||
|
|
||
| const stream = handler.createMessage("System prompt", messages) | ||
| const chunks = [] | ||
| for await (const chunk of stream) { | ||
| chunks.push(chunk) | ||
| } | ||
|
|
||
| // Verify that generateContentStream was called WITHOUT urlContext in tools | ||
| // even though enableUrlContext was true | ||
| const callArgs = mockGenerateContentStream.mock.calls[0][0] | ||
| if (callArgs.config.tools) { | ||
| // If tools array exists, it should not contain urlContext | ||
| expect(callArgs.config.tools).not.toContainEqual({ urlContext: {} }) | ||
| } | ||
| }) | ||
|
|
||
| it("should NOT include urlContext in completePrompt even when enableUrlContext is true", async () => { | ||
| const mockGenerateContent = vitest.fn() | ||
|
|
||
| const handler = new VertexHandler({ | ||
| vertexProjectId: "test-project", | ||
| vertexRegion: "us-central1", | ||
| apiModelId: "gemini-1.5-pro-001", | ||
| enableUrlContext: true, // This should be ignored for Vertex | ||
| }) | ||
|
|
||
| // Replace the client with our mock | ||
| handler["client"] = { | ||
| models: { | ||
| generateContent: mockGenerateContent, | ||
| }, | ||
| } as any | ||
|
|
||
| // Mock the response | ||
| mockGenerateContent.mockResolvedValue({ | ||
| text: "Test response", | ||
| }) | ||
|
|
||
| await handler.completePrompt("Test prompt") | ||
|
|
||
| // Verify that generateContent was called WITHOUT urlContext in tools | ||
| const callArgs = mockGenerateContent.mock.calls[0][0] | ||
| if (callArgs.config.tools) { | ||
| // If tools array exists, it should not contain urlContext | ||
| expect(callArgs.config.tools).not.toContainEqual({ urlContext: {} }) | ||
| } | ||
| }) | ||
|
|
||
| it("should still include googleSearch tool when enableGrounding is true", async () => { | ||
| const mockGenerateContentStream = vitest.fn() | ||
|
|
||
| const handler = new VertexHandler({ | ||
| vertexProjectId: "test-project", | ||
| vertexRegion: "us-central1", | ||
| apiModelId: "gemini-1.5-pro-001", | ||
| enableUrlContext: true, // Should be ignored | ||
| enableGrounding: true, // Should be respected | ||
| }) | ||
|
|
||
| // Replace the client with our mock | ||
| handler["client"] = { | ||
| models: { | ||
| generateContentStream: mockGenerateContentStream, | ||
| }, | ||
| } as any | ||
|
|
||
| // Setup mock to return an async generator | ||
| mockGenerateContentStream.mockResolvedValue({ | ||
| [Symbol.asyncIterator]: async function* () { | ||
| yield { text: "Test response" } | ||
| yield { usageMetadata: { promptTokenCount: 10, candidatesTokenCount: 5 } } | ||
| }, | ||
| }) | ||
|
|
||
| const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Hello" }] | ||
|
|
||
| const stream = handler.createMessage("System prompt", messages) | ||
| const chunks = [] | ||
| for await (const chunk of stream) { | ||
| chunks.push(chunk) | ||
| } | ||
|
|
||
| // Verify that googleSearch is included but urlContext is not | ||
| const callArgs = mockGenerateContentStream.mock.calls[0][0] | ||
| expect(callArgs.config.tools).toContainEqual({ googleSearch: {} }) | ||
| expect(callArgs.config.tools).not.toContainEqual({ urlContext: {} }) | ||
| }) | ||
| }) | ||
|
|
||
| describe("Integration test - switching between providers", () => { | ||
| it("should correctly handle urlContext based on provider type", async () => { | ||
| const mockGenerateContentStream = vitest.fn() | ||
|
|
||
| // Setup mock to return an async generator | ||
| mockGenerateContentStream.mockResolvedValue({ | ||
| [Symbol.asyncIterator]: async function* () { | ||
| yield { text: "Test response" } | ||
| yield { usageMetadata: { promptTokenCount: 10, candidatesTokenCount: 5 } } | ||
| }, | ||
| }) | ||
|
|
||
| const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Hello" }] | ||
|
|
||
| // Test with Gemini handler | ||
| const geminiHandler = new GeminiHandler({ | ||
| geminiApiKey: "test-key", | ||
| apiModelId: "gemini-1.5-flash", | ||
| enableUrlContext: true, | ||
| }) | ||
| geminiHandler["client"] = { | ||
| models: { generateContentStream: mockGenerateContentStream }, | ||
| } as any | ||
|
|
||
| const geminiStream = geminiHandler.createMessage("System prompt", messages) | ||
| for await (const chunk of geminiStream) { | ||
| // Consume stream | ||
| } | ||
|
|
||
| // Verify Gemini includes urlContext | ||
| const geminiCall = mockGenerateContentStream.mock.calls[mockGenerateContentStream.mock.calls.length - 1][0] | ||
| expect(geminiCall.config.tools).toContainEqual({ urlContext: {} }) | ||
|
|
||
| // Clear mock calls | ||
| mockGenerateContentStream.mockClear() | ||
|
|
||
| // Test with Vertex handler using same options | ||
| const vertexHandler = new VertexHandler({ | ||
| vertexProjectId: "test-project", | ||
| vertexRegion: "us-central1", | ||
| apiModelId: "gemini-1.5-pro-001", | ||
| enableUrlContext: true, // Same setting, but should be ignored | ||
| }) | ||
| vertexHandler["client"] = { | ||
| models: { generateContentStream: mockGenerateContentStream }, | ||
| } as any | ||
|
|
||
| const vertexStream = vertexHandler.createMessage("System prompt", messages) | ||
| for await (const chunk of vertexStream) { | ||
| // Consume stream | ||
| } | ||
|
|
||
| // Verify Vertex does NOT include urlContext | ||
| const vertexCall = mockGenerateContentStream.mock.calls[0][0] | ||
| if (vertexCall.config.tools) { | ||
| expect(vertexCall.config.tools).not.toContainEqual({ urlContext: {} }) | ||
| } | ||
| }) | ||
| }) | ||
| }) | ||
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -27,13 +27,15 @@ type GeminiHandlerOptions = ApiHandlerOptions & { | |||||||||
|
|
||||||||||
| export class GeminiHandler extends BaseProvider implements SingleCompletionHandler { | ||||||||||
| protected options: ApiHandlerOptions | ||||||||||
| private isVertex: boolean | ||||||||||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Consider adding a comment here explaining why this differentiation is necessary:
Suggested change
This would help future maintainers understand the purpose of this flag. |
||||||||||
|
|
||||||||||
| private client: GoogleGenAI | ||||||||||
|
|
||||||||||
| constructor({ isVertex, ...options }: GeminiHandlerOptions) { | ||||||||||
| super() | ||||||||||
|
|
||||||||||
| this.options = options | ||||||||||
| this.isVertex = isVertex ?? false | ||||||||||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is the nullish coalescing operator necessary here? Since TypeScript already has the optional parameter, you could make the default more explicit in the constructor signature:
Suggested change
Then just use:
Suggested change
|
||||||||||
|
|
||||||||||
| const project = this.options.vertexProjectId ?? "not-provided" | ||||||||||
| const location = this.options.vertexRegion ?? "not-provided" | ||||||||||
|
|
@@ -70,7 +72,8 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl | |||||||||
| const contents = messages.map(convertAnthropicMessageToGemini) | ||||||||||
|
|
||||||||||
| const tools: GenerateContentConfig["tools"] = [] | ||||||||||
| if (this.options.enableUrlContext) { | ||||||||||
| // urlContext is only supported in regular Gemini, not Vertex AI | ||||||||||
| if (this.options.enableUrlContext && !this.isVertex) { | ||||||||||
| tools.push({ urlContext: {} }) | ||||||||||
| } | ||||||||||
|
|
||||||||||
|
|
@@ -214,7 +217,8 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl | |||||||||
| const { id: model } = this.getModel() | ||||||||||
|
|
||||||||||
| const tools: GenerateContentConfig["tools"] = [] | ||||||||||
| if (this.options.enableUrlContext) { | ||||||||||
| // urlContext is only supported in regular Gemini, not Vertex AI | ||||||||||
| if (this.options.enableUrlContext && !this.isVertex) { | ||||||||||
| tools.push({ urlContext: {} }) | ||||||||||
| } | ||||||||||
| if (this.options.enableGrounding) { | ||||||||||
|
|
||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The test file comment could be more descriptive about what this test suite validates. Consider: