@@ -5,7 +5,7 @@ import { AnthropicVertex } from "@anthropic-ai/vertex-sdk"
55
66import { VertexHandler } from "../vertex"
77import { ApiStreamChunk } from "../../transform/stream"
8- import { VertexAI } from "@google-cloud/vertexai "
8+ import { GeminiHandler } from "../gemini "
99
1010// Mock Vertex SDK
1111jest . mock ( "@anthropic-ai/vertex-sdk" , ( ) => ( {
@@ -49,58 +49,40 @@ jest.mock("@anthropic-ai/vertex-sdk", () => ({
4949 } ) ) ,
5050} ) )
5151
52- // Mock Vertex Gemini SDK
53- jest . mock ( "@google-cloud/vertexai" , ( ) => {
54- const mockGenerateContentStream = jest . fn ( ) . mockImplementation ( ( ) => {
55- return {
56- stream : {
57- async * [ Symbol . asyncIterator ] ( ) {
58- yield {
59- candidates : [
60- {
61- content : {
62- parts : [ { text : "Test Gemini response" } ] ,
63- } ,
64- } ,
65- ] ,
66- }
67- } ,
52+ jest . mock ( "../gemini" , ( ) => {
53+ const mockGeminiHandler = jest . fn ( )
54+
55+ mockGeminiHandler . prototype . createMessage = jest . fn ( ) . mockImplementation ( async function * ( ) {
56+ const mockStream : ApiStreamChunk [ ] = [
57+ {
58+ type : "usage" ,
59+ inputTokens : 10 ,
60+ outputTokens : 0 ,
6861 } ,
69- response : {
70- usageMetadata : {
71- promptTokenCount : 5 ,
72- candidatesTokenCount : 10 ,
73- } ,
62+ {
63+ type : "text" ,
64+ text : "Gemini response part 1" ,
7465 } ,
75- }
76- } )
77-
78- const mockGenerateContent = jest . fn ( ) . mockResolvedValue ( {
79- response : {
80- candidates : [
81- {
82- content : {
83- parts : [ { text : "Test Gemini response" } ] ,
84- } ,
85- } ,
86- ] ,
87- } ,
88- } )
66+ {
67+ type : "text" ,
68+ text : " part 2" ,
69+ } ,
70+ {
71+ type : "usage" ,
72+ inputTokens : 0 ,
73+ outputTokens : 5 ,
74+ } ,
75+ ]
8976
90- const mockGenerativeModel = jest . fn ( ) . mockImplementation ( ( ) => {
91- return {
92- generateContentStream : mockGenerateContentStream ,
93- generateContent : mockGenerateContent ,
77+ for ( const chunk of mockStream ) {
78+ yield chunk
9479 }
9580 } )
9681
82+ mockGeminiHandler . prototype . completePrompt = jest . fn ( ) . mockResolvedValue ( "Test Gemini response" )
83+
9784 return {
98- VertexAI : jest . fn ( ) . mockImplementation ( ( ) => {
99- return {
100- getGenerativeModel : mockGenerativeModel ,
101- }
102- } ) ,
103- GenerativeModel : mockGenerativeModel ,
85+ GeminiHandler : mockGeminiHandler ,
10486 }
10587} )
10688
@@ -128,9 +110,11 @@ describe("VertexHandler", () => {
128110 vertexRegion : "us-central1" ,
129111 } )
130112
131- expect ( VertexAI ) . toHaveBeenCalledWith ( {
132- project : "test-project" ,
133- location : "us-central1" ,
113+ expect ( GeminiHandler ) . toHaveBeenCalledWith ( {
114+ isVertex : true ,
115+ apiModelId : "gemini-1.5-pro-001" ,
116+ vertexProjectId : "test-project" ,
117+ vertexRegion : "us-central1" ,
134118 } )
135119 } )
136120
@@ -270,48 +254,48 @@ describe("VertexHandler", () => {
270254 } )
271255
272256 it ( "should handle streaming responses correctly for Gemini" , async ( ) => {
273- const mockGemini = require ( "@google-cloud/vertexai" )
274- const mockGenerateContentStream = mockGemini . VertexAI ( ) . getGenerativeModel ( ) . generateContentStream
275257 handler = new VertexHandler ( {
276258 apiModelId : "gemini-1.5-pro-001" ,
277259 vertexProjectId : "test-project" ,
278260 vertexRegion : "us-central1" ,
279261 } )
280262
281- const stream = handler . createMessage ( systemPrompt , mockMessages )
263+ const mockCacheKey = "cacheKey"
264+ const mockGeminiHandlerInstance = ( GeminiHandler as jest . Mock ) . mock . instances [ 0 ]
265+
266+ const stream = handler . createMessage ( systemPrompt , mockMessages , mockCacheKey )
267+
282268 const chunks : ApiStreamChunk [ ] = [ ]
283269
284270 for await ( const chunk of stream ) {
285271 chunks . push ( chunk )
286272 }
287273
288- expect ( chunks . length ) . toBe ( 2 )
274+ expect ( chunks . length ) . toBe ( 4 )
289275 expect ( chunks [ 0 ] ) . toEqual ( {
290- type : "text" ,
291- text : "Test Gemini response" ,
276+ type : "usage" ,
277+ inputTokens : 10 ,
278+ outputTokens : 0 ,
292279 } )
293280 expect ( chunks [ 1 ] ) . toEqual ( {
281+ type : "text" ,
282+ text : "Gemini response part 1" ,
283+ } )
284+ expect ( chunks [ 2 ] ) . toEqual ( {
285+ type : "text" ,
286+ text : " part 2" ,
287+ } )
288+ expect ( chunks [ 3 ] ) . toEqual ( {
294289 type : "usage" ,
295- inputTokens : 5 ,
296- outputTokens : 10 ,
290+ inputTokens : 0 ,
291+ outputTokens : 5 ,
297292 } )
298293
299- expect ( mockGenerateContentStream ) . toHaveBeenCalledWith ( {
300- contents : [
301- {
302- role : "user" ,
303- parts : [ { text : "Hello" } ] ,
304- } ,
305- {
306- role : "model" ,
307- parts : [ { text : "Hi there!" } ] ,
308- } ,
309- ] ,
310- generationConfig : {
311- maxOutputTokens : 8192 ,
312- temperature : 0 ,
313- } ,
314- } )
294+ expect ( mockGeminiHandlerInstance . createMessage ) . toHaveBeenCalledWith (
295+ systemPrompt ,
296+ mockMessages ,
297+ mockCacheKey ,
298+ )
315299 } )
316300
317301 it ( "should handle multiple content blocks with line breaks for Claude" , async ( ) => {
@@ -753,9 +737,6 @@ describe("VertexHandler", () => {
753737 } )
754738
755739 it ( "should complete prompt successfully for Gemini" , async ( ) => {
756- const mockGemini = require ( "@google-cloud/vertexai" )
757- const mockGenerateContent = mockGemini . VertexAI ( ) . getGenerativeModel ( ) . generateContent
758-
759740 handler = new VertexHandler ( {
760741 apiModelId : "gemini-1.5-pro-001" ,
761742 vertexProjectId : "test-project" ,
@@ -764,13 +745,9 @@ describe("VertexHandler", () => {
764745
765746 const result = await handler . completePrompt ( "Test prompt" )
766747 expect ( result ) . toBe ( "Test Gemini response" )
767- expect ( mockGenerateContent ) . toHaveBeenCalled ( )
768- expect ( mockGenerateContent ) . toHaveBeenCalledWith ( {
769- contents : [ { role : "user" , parts : [ { text : "Test prompt" } ] } ] ,
770- generationConfig : {
771- temperature : 0 ,
772- } ,
773- } )
748+
749+ const mockGeminiHandlerInstance = ( GeminiHandler as jest . Mock ) . mock . instances [ 0 ]
750+ expect ( mockGeminiHandlerInstance . completePrompt ) . toHaveBeenCalledWith ( "Test prompt" )
774751 } )
775752
776753 it ( "should handle API errors for Claude" , async ( ) => {
@@ -790,17 +767,17 @@ describe("VertexHandler", () => {
790767 } )
791768
792769 it ( "should handle API errors for Gemini" , async ( ) => {
793- const mockGemini = require ( "@google-cloud/vertexai" )
794- const mockGenerateContent = mockGemini . VertexAI ( ) . getGenerativeModel ( ) . generateContent
795- mockGenerateContent . mockRejectedValue ( new Error ( "Vertex API error" ) )
770+ const mockGeminiHandlerInstance = ( GeminiHandler as jest . Mock ) . mock . instances [ 0 ]
771+ mockGeminiHandlerInstance . completePrompt . mockRejectedValue ( new Error ( "Vertex API error" ) )
772+
796773 handler = new VertexHandler ( {
797774 apiModelId : "gemini-1.5-pro-001" ,
798775 vertexProjectId : "test-project" ,
799776 vertexRegion : "us-central1" ,
800777 } )
801778
802779 await expect ( handler . completePrompt ( "Test prompt" ) ) . rejects . toThrow (
803- "Vertex completion error: Vertex API error" ,
780+ "Vertex API error" , // Expecting the raw error message from the mock
804781 )
805782 } )
806783
@@ -837,19 +814,9 @@ describe("VertexHandler", () => {
837814 } )
838815
839816 it ( "should handle empty response for Gemini" , async ( ) => {
840- const mockGemini = require ( "@google-cloud/vertexai" )
841- const mockGenerateContent = mockGemini . VertexAI ( ) . getGenerativeModel ( ) . generateContent
842- mockGenerateContent . mockResolvedValue ( {
843- response : {
844- candidates : [
845- {
846- content : {
847- parts : [ { text : "" } ] ,
848- } ,
849- } ,
850- ] ,
851- } ,
852- } )
817+ const mockGeminiHandlerInstance = ( GeminiHandler as jest . Mock ) . mock . instances [ 0 ]
818+ mockGeminiHandlerInstance . completePrompt . mockResolvedValue ( "" )
819+
853820 handler = new VertexHandler ( {
854821 apiModelId : "gemini-1.5-pro-001" ,
855822 vertexProjectId : "test-project" ,
0 commit comments