2121import com .google .firebase .concurrent .FirebaseExecutors ;
2222import com .google .firebase .vertexai .FirebaseVertexAI ;
2323import com .google .firebase .vertexai .GenerativeModel ;
24+ import com .google .firebase .vertexai .LiveGenerativeModel ;
2425import com .google .firebase .vertexai .java .ChatFutures ;
2526import com .google .firebase .vertexai .java .GenerativeModelFutures ;
27+ import com .google .firebase .vertexai .java .LiveModelFutures ;
28+ import com .google .firebase .vertexai .java .LiveSessionFutures ;
2629import com .google .firebase .vertexai .type .BlockReason ;
2730import com .google .firebase .vertexai .type .Candidate ;
2831import com .google .firebase .vertexai .type .Citation ;
3336import com .google .firebase .vertexai .type .FileDataPart ;
3437import com .google .firebase .vertexai .type .FinishReason ;
3538import com .google .firebase .vertexai .type .FunctionCallPart ;
39+ import com .google .firebase .vertexai .type .FunctionResponsePart ;
3640import com .google .firebase .vertexai .type .GenerateContentResponse ;
41+ import com .google .firebase .vertexai .type .GenerationConfig ;
3742import com .google .firebase .vertexai .type .HarmCategory ;
3843import com .google .firebase .vertexai .type .HarmProbability ;
3944import com .google .firebase .vertexai .type .HarmSeverity ;
4045import com .google .firebase .vertexai .type .ImagePart ;
4146import com .google .firebase .vertexai .type .InlineDataPart ;
47+ import com .google .firebase .vertexai .type .LiveContentResponse ;
48+ import com .google .firebase .vertexai .type .LiveGenerationConfig ;
49+ import com .google .firebase .vertexai .type .MediaData ;
4250import com .google .firebase .vertexai .type .ModalityTokenCount ;
4351import com .google .firebase .vertexai .type .Part ;
4452import com .google .firebase .vertexai .type .PromptFeedback ;
53+ import com .google .firebase .vertexai .type .ResponseModality ;
4554import com .google .firebase .vertexai .type .SafetyRating ;
55+ import com .google .firebase .vertexai .type .SpeechConfig ;
4656import com .google .firebase .vertexai .type .TextPart ;
4757import com .google .firebase .vertexai .type .UsageMetadata ;
58+ import com .google .firebase .vertexai .type .Voices ;
4859import java .util .Calendar ;
4960import java .util .List ;
5061import java .util .Map ;
5162import java .util .concurrent .Executor ;
5263import kotlinx .serialization .json .JsonElement ;
5364import kotlinx .serialization .json .JsonNull ;
65+ import kotlinx .serialization .json .JsonObject ;
5466import org .junit .Assert ;
5567import org .reactivestreams .Publisher ;
5668import org .reactivestreams .Subscriber ;
@@ -63,9 +75,31 @@ public class JavaCompileTests {
6375
6476 public void initializeJava () throws Exception {
6577 FirebaseVertexAI vertex = FirebaseVertexAI .getInstance ();
66- GenerativeModel model = vertex .generativeModel ("fake-model-name" );
78+ GenerativeModel model = vertex .generativeModel ("fake-model-name" , getConfig ());
79+ LiveGenerativeModel live = vertex .liveModel ("fake-model-name" , getLiveConfig ());
6780 GenerativeModelFutures futures = GenerativeModelFutures .from (model );
81+ LiveModelFutures liveFutures = LiveModelFutures .from (live );
6882 testFutures (futures );
83+ testLiveFutures (liveFutures );
84+ }
85+
86+ private GenerationConfig getConfig () {
87+ return new GenerationConfig .Builder ().build ();
88+ // TODO b/406558430 GenerationConfig.Builder.setParts returns void
89+ }
90+
91+ private LiveGenerationConfig getLiveConfig () {
92+ return new LiveGenerationConfig .Builder ()
93+ .setTopK (10 )
94+ .setTopP (11.0F )
95+ .setTemperature (32.0F )
96+ .setCandidateCount (1 )
97+ .setMaxOutputTokens (0xCAFEBABE )
98+ .setFrequencyPenalty (1.0F )
99+ .setPresencePenalty (2.0F )
100+ .setResponseModality (ResponseModality .AUDIO )
101+ .setSpeechConfig (new SpeechConfig (Voices .AOEDE ))
102+ .build ();
69103 }
70104
71105 private void testFutures (GenerativeModelFutures futures ) throws Exception {
@@ -236,4 +270,62 @@ public void validateUsageMetadata(UsageMetadata metadata) {
236270 }
237271 }
238272 }
273+
274+ private void testLiveFutures (LiveModelFutures futures ) throws Exception {
275+ LiveSessionFutures session = futures .connect ().get ();
276+ session
277+ .receive ()
278+ .subscribe (
279+ new Subscriber <LiveContentResponse >() {
280+ @ Override
281+ public void onSubscribe (Subscription s ) {
282+ s .request (Long .MAX_VALUE );
283+ }
284+
285+ @ Override
286+ public void onNext (LiveContentResponse response ) {
287+ validateLiveContentResponse (response );
288+ }
289+
290+ @ Override
291+ public void onError (Throwable t ) {
292+ // Ignore
293+ }
294+
295+ @ Override
296+ public void onComplete () {
297+ // Also ignore
298+ }
299+ });
300+
301+ session .send ("Fake message" );
302+ session .send (new Content .Builder ().addText ("Fake message" ).build ());
303+
304+ byte [] bytes = new byte [] {(byte ) 0xCA , (byte ) 0xFE , (byte ) 0xBA , (byte ) 0xBE };
305+ session .sendMediaStream (List .of (new MediaData (bytes , "image/jxl" )));
306+
307+ FunctionResponsePart functionResponse =
308+ new FunctionResponsePart ("myFunction" , new JsonObject (Map .of ()));
309+ session .sendFunctionResponse (List .of (functionResponse , functionResponse ));
310+
311+ session .startAudioConversation (part -> functionResponse );
312+ session .startAudioConversation ();
313+ session .stopAudioConversation ();
314+ session .stopReceiving ();
315+ session .close ();
316+ }
317+
318+ private void validateLiveContentResponse (LiveContentResponse response ) {
319+ //int status = response.getStatus();
320+ //Assert.assertEquals(status, LiveContentResponse.Status.Companion.getNORMAL());
321+ //Assert.assertNotEquals(status, LiveContentResponse.Status.Companion.getINTERRUPTED());
322+ //Assert.assertNotEquals(status, LiveContentResponse.Status.Companion.getTURN_COMPLETE());
323+ // TODO b/412743328 LiveContentResponse.Status inaccessible for Java users
324+ Content data = response .getData ();
325+ if (data != null ) {
326+ validateContent (data );
327+ }
328+ String text = response .getText ();
329+ validateFunctionCalls (response .getFunctionCalls ());
330+ }
239331}
0 commit comments