18
18
19
19
import static com .google .common .truth .Truth .assertThat ;
20
20
import static com .google .common .truth .Truth .assertWithMessage ;
21
-
21
+ import static org .mockito .ArgumentMatchers .any ;
22
+ import static org .mockito .ArgumentMatchers .anyString ;
23
+ import static org .mockito .Mockito .RETURNS_SELF ;
24
+ import static org .mockito .Mockito .mock ;
25
+ import static org .mockito .Mockito .mockStatic ;
26
+ import static org .mockito .Mockito .times ;
27
+ import static org .mockito .Mockito .verify ;
28
+ import static org .mockito .Mockito .when ;
29
+
30
+ import com .google .genai .Client ;
31
+ import com .google .genai .Models ;
32
+ import com .google .genai .types .GenerateContentConfig ;
33
+ import com .google .genai .types .GenerateContentResponse ;
22
34
import java .io .ByteArrayOutputStream ;
23
35
import java .io .IOException ;
24
36
import java .io .PrintStream ;
37
+ import java .lang .reflect .Field ;
25
38
import org .junit .After ;
26
39
import org .junit .Before ;
27
40
import org .junit .BeforeClass ;
28
41
import org .junit .Test ;
29
42
import org .junit .runner .RunWith ;
30
43
import org .junit .runners .JUnit4 ;
44
+ import org .mockito .MockedStatic ;
45
+
31
46
32
47
@ RunWith (JUnit4 .class )
33
48
public class ToolsIT {
@@ -105,4 +120,42 @@ public void testToolsGoogleSearchWithText() {
105
120
assertThat (response ).isNotEmpty ();
106
121
}
107
122
123
+ @ Test
124
+ public void testToolsVaisWithText () throws NoSuchFieldException , IllegalAccessException {
125
+ String response = "The process for making an appointment to renew your driver's license"
126
+ + " varies depending on your location." ;
127
+
128
+ String datastore =
129
+ String .format (
130
+ "projects/%s/locations/global/collections/default_collection/"
131
+ + "dataStores/grounding-test-datastore" ,
132
+ PROJECT_ID );
133
+
134
+ Client .Builder mockedBuilder = mock (Client .Builder .class , RETURNS_SELF );
135
+ Client mockedClient = mock (Client .class );
136
+ Models mockedModels = mock (Models .class );
137
+ GenerateContentResponse mockedResponse = mock (GenerateContentResponse .class );
138
+
139
+ try (MockedStatic <Client > mockedStatic = mockStatic (Client .class )) {
140
+ mockedStatic .when (Client ::builder ).thenReturn (mockedBuilder );
141
+ when (mockedBuilder .build ()).thenReturn (mockedClient );
142
+
143
+ // Using reflection because 'models' is a final field and cannot be mockable directly
144
+ Field field = Client .class .getDeclaredField ("models" );
145
+ field .setAccessible (true );
146
+ field .set (mockedClient , mockedModels );
147
+
148
+ when (mockedClient .models .generateContent (
149
+ anyString (), anyString (), any (GenerateContentConfig .class )))
150
+ .thenReturn (mockedResponse );
151
+ when (mockedResponse .text ()).thenReturn (response );
152
+
153
+ String generatedResponse = ToolsVaisWithText .generateContent (GEMINI_FLASH , datastore );
154
+
155
+ verify (mockedClient .models , times (1 ))
156
+ .generateContent (anyString (), anyString (), any (GenerateContentConfig .class ));
157
+ assertThat (generatedResponse ).isNotEmpty ();
158
+ assertThat (response ).isEqualTo (generatedResponse );
159
+ }
160
+ }
108
161
}
0 commit comments