Skip to content

Commit 88453fd

Browse files
chore: [vertexai] pass in immutable object in generateContentStream private method (googleapis#10500)
PiperOrigin-RevId: 613330411 Co-authored-by: Jaycee Li <[email protected]>
1 parent 3166f32 commit 88453fd

File tree

3 files changed

+90
-42
lines changed

3 files changed

+90
-42
lines changed

java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/GenerativeModel.java

Lines changed: 20 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -968,7 +968,7 @@ public ResponseStream<GenerateContentResponse> generateContentStream(
968968
List<Content> contents, GenerationConfig generationConfig, List<SafetySetting> safetySettings)
969969
throws IOException {
970970
GenerateContentRequest.Builder requestBuilder =
971-
GenerateContentRequest.newBuilder().addAllContents(contents);
971+
GenerateContentRequest.newBuilder().setModel(this.resourceName).addAllContents(contents);
972972
if (generationConfig != null) {
973973
requestBuilder.setGenerationConfig(generationConfig);
974974
} else if (this.generationConfig != null) {
@@ -982,7 +982,7 @@ public ResponseStream<GenerateContentResponse> generateContentStream(
982982
if (this.tools != null) {
983983
requestBuilder.addAllTools(this.tools);
984984
}
985-
return generateContentStream(requestBuilder);
985+
return generateContentStream(requestBuilder.build());
986986
}
987987

988988
/**
@@ -1000,7 +1000,7 @@ public ResponseStream<GenerateContentResponse> generateContentStream(
10001000
public ResponseStream<GenerateContentResponse> generateContentStream(
10011001
List<Content> contents, GenerateContentConfig config) throws IOException {
10021002
GenerateContentRequest.Builder requestBuilder =
1003-
GenerateContentRequest.newBuilder().addAllContents(contents);
1003+
GenerateContentRequest.newBuilder().setModel(this.resourceName).addAllContents(contents);
10041004
if (config.getGenerationConfig() != null) {
10051005
requestBuilder.setGenerationConfig(config.getGenerationConfig());
10061006
} else if (this.generationConfig != null) {
@@ -1017,42 +1017,36 @@ public ResponseStream<GenerateContentResponse> generateContentStream(
10171017
requestBuilder.addAllTools(this.tools);
10181018
}
10191019

1020-
return generateContentStream(requestBuilder);
1020+
return generateContentStream(requestBuilder.build());
10211021
}
10221022

10231023
/**
10241024
* A base generateContentStream method that will be used internally.
10251025
*
1026-
* @param requestBuilder a {@link com.google.cloud.vertexai.api.GenerateContentRequest.Builder}
1027-
* instance
1026+
* @param request a {@link com.google.cloud.vertexai.api.GenerateContentRequest} instance
10281027
* @return a {@link ResponseStream} that contains a streaming of {@link
10291028
* com.google.cloud.vertexai.api.GenerateContentResponse}
10301029
* @throws IOException if an I/O error occurs while making the API call
10311030
*/
10321031
private ResponseStream<GenerateContentResponse> generateContentStream(
1033-
GenerateContentRequest.Builder requestBuilder) throws IOException {
1034-
GenerateContentRequest request = requestBuilder.setModel(this.resourceName).build();
1035-
ResponseStream<GenerateContentResponse> responseStream = null;
1032+
GenerateContentRequest request) throws IOException {
10361033
if (this.transport == Transport.REST) {
1037-
responseStream =
1038-
new ResponseStream(
1039-
new ResponseStreamIteratorWithHistory(
1040-
vertexAi
1041-
.getPredictionServiceRestClient()
1042-
.streamGenerateContentCallable()
1043-
.call(request)
1044-
.iterator()));
1034+
return new ResponseStream(
1035+
new ResponseStreamIteratorWithHistory(
1036+
vertexAi
1037+
.getPredictionServiceRestClient()
1038+
.streamGenerateContentCallable()
1039+
.call(request)
1040+
.iterator()));
10451041
} else {
1046-
responseStream =
1047-
new ResponseStream(
1048-
new ResponseStreamIteratorWithHistory(
1049-
vertexAi
1050-
.getPredictionServiceClient()
1051-
.streamGenerateContentCallable()
1052-
.call(request)
1053-
.iterator()));
1042+
return new ResponseStream(
1043+
new ResponseStreamIteratorWithHistory(
1044+
vertexAi
1045+
.getPredictionServiceClient()
1046+
.streamGenerateContentCallable()
1047+
.call(request)
1048+
.iterator()));
10541049
}
1055-
return responseStream;
10561050
}
10571051

10581052
/**

java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/ResponseHandler.java

Lines changed: 35 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,10 @@
2222
import com.google.cloud.vertexai.api.Citation;
2323
import com.google.cloud.vertexai.api.CitationMetadata;
2424
import com.google.cloud.vertexai.api.Content;
25+
import com.google.cloud.vertexai.api.FunctionCall;
2526
import com.google.cloud.vertexai.api.GenerateContentResponse;
2627
import com.google.cloud.vertexai.api.Part;
28+
import com.google.common.collect.ImmutableList;
2729
import java.util.ArrayList;
2830
import java.util.HashMap;
2931
import java.util.List;
@@ -33,20 +35,15 @@
3335
public class ResponseHandler {
3436

3537
/**
36-
* Get the text message in a GenerateContentResponse.
38+
* Gets the text message in a GenerateContentResponse.
3739
*
3840
* @param response a {@link com.google.cloud.vertexai.api.GenerateContentResponse} instance
3941
* @return a String that aggregates all the text parts in the response
4042
* @throws IllegalArgumentException if the response has 0 or more than 1 candidates, or if the
4143
* response is blocked by safety reason or unauthorized citations
4244
*/
4345
public static String getText(GenerateContentResponse response) {
44-
FinishReason finishReason = getFinishReason(response);
45-
if (finishReason == FinishReason.SAFETY) {
46-
throw new IllegalArgumentException("The response is blocked due to safety reason.");
47-
} else if (finishReason == FinishReason.RECITATION) {
48-
throw new IllegalArgumentException("The response is blocked due to unauthorized citations.");
49-
}
46+
checkFinishReason(getFinishReason(response));
5047

5148
String text = "";
5249
List<Part> parts = response.getCandidates(0).getContent().getPartsList();
@@ -58,26 +55,40 @@ public static String getText(GenerateContentResponse response) {
5855
}
5956

6057
/**
61-
* Get the content in a GenerateContentResponse.
58+
* Gets the list of function calls in a GenerateContentResponse.
59+
*
60+
* @param response a {@link com.google.cloud.vertexai.api.GenerateContentResponse} instance
61+
* @return a list of {@link com.google.cloud.vertexai.api.FunctionCall} in the response
62+
* @throws IllegalArgumentException if the response has 0 or more than 1 candidates, or if the
63+
* response is blocked by safety reason or unauthorized citations
64+
*/
65+
public static ImmutableList<FunctionCall> getFunctionCalls(GenerateContentResponse response) {
66+
checkFinishReason(getFinishReason(response));
67+
if (response.getCandidatesCount() == 0) {
68+
return ImmutableList.of();
69+
}
70+
return response.getCandidates(0).getContent().getPartsList().stream()
71+
.filter((part) -> part.hasFunctionCall())
72+
.map((part) -> part.getFunctionCall())
73+
.collect(ImmutableList.toImmutableList());
74+
}
75+
76+
/**
77+
* Gets the content in a GenerateContentResponse.
6278
*
6379
* @param response a {@link com.google.cloud.vertexai.api.GenerateContentResponse} instance
6480
* @return the {@link com.google.cloud.vertexai.api.Content} in the response
6581
* @throws IllegalArgumentException if the response has 0 or more than 1 candidates, or if the
6682
* response is blocked by safety reason or unauthorized citations
6783
*/
6884
public static Content getContent(GenerateContentResponse response) {
69-
FinishReason finishReason = getFinishReason(response);
70-
if (finishReason == FinishReason.SAFETY) {
71-
throw new IllegalArgumentException("The response is blocked due to safety reason.");
72-
} else if (finishReason == FinishReason.RECITATION) {
73-
throw new IllegalArgumentException("The response is blocked due to unauthorized citations.");
74-
}
85+
checkFinishReason(getFinishReason(response));
7586

7687
return response.getCandidates(0).getContent();
7788
}
7889

7990
/**
80-
* Get the finish reason in a GenerateContentResponse.
91+
* Gets the finish reason in a GenerateContentResponse.
8192
*
8293
* @param response a {@link com.google.cloud.vertexai.api.GenerateContentResponse} instance
8394
* @return the {@link com.google.cloud.vertexai.api.FinishReason} in the response
@@ -93,7 +104,7 @@ public static FinishReason getFinishReason(GenerateContentResponse response) {
93104
return response.getCandidates(0).getFinishReason();
94105
}
95106

96-
/** Aggregate a stream of responses into a single GenerateContentResponse. */
107+
/** Aggregates a stream of responses into a single GenerateContentResponse. */
97108
static GenerateContentResponse aggregateStreamIntoResponse(
98109
ResponseStream<GenerateContentResponse> responseStream) {
99110
GenerateContentResponse res = GenerateContentResponse.getDefaultInstance();
@@ -170,4 +181,12 @@ static GenerateContentResponse aggregateStreamIntoResponse(
170181

171182
return res;
172183
}
184+
185+
private static void checkFinishReason(FinishReason finishReason) {
186+
if (finishReason == FinishReason.SAFETY) {
187+
throw new IllegalArgumentException("The response is blocked due to safety reason.");
188+
} else if (finishReason == FinishReason.RECITATION) {
189+
throw new IllegalArgumentException("The response is blocked due to unauthorized citations.");
190+
}
191+
}
173192
}

java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/ResponseHandlerTest.java

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,10 @@
2525
import com.google.cloud.vertexai.api.Citation;
2626
import com.google.cloud.vertexai.api.CitationMetadata;
2727
import com.google.cloud.vertexai.api.Content;
28+
import com.google.cloud.vertexai.api.FunctionCall;
2829
import com.google.cloud.vertexai.api.GenerateContentResponse;
2930
import com.google.cloud.vertexai.api.Part;
31+
import com.google.common.collect.ImmutableList;
3032
import java.util.Arrays;
3133
import java.util.Iterator;
3234
import org.junit.Rule;
@@ -47,6 +49,13 @@ public final class ResponseHandlerTest {
4749
.addParts(Part.newBuilder().setText(TEXT_1))
4850
.addParts(Part.newBuilder().setText(TEXT_2))
4951
.build();
52+
private static final Content CONTENT_WITH_FNCTION_CALL =
53+
Content.newBuilder()
54+
.addParts(Part.newBuilder().setText(TEXT_1))
55+
.addParts(Part.newBuilder().setFunctionCall(FunctionCall.getDefaultInstance()))
56+
.addParts(Part.newBuilder().setText(TEXT_2))
57+
.addParts(Part.newBuilder().setFunctionCall(FunctionCall.getDefaultInstance()))
58+
.build();
5059
private static final Citation CITATION_1 =
5160
Citation.newBuilder().setUri("gs://citation1").setStartIndex(1).setEndIndex(2).build();
5261
private static final Citation CITATION_2 =
@@ -61,10 +70,14 @@ public final class ResponseHandlerTest {
6170
.setContent(CONTENT)
6271
.setCitationMetadata(CitationMetadata.newBuilder().addCitations(CITATION_2))
6372
.build();
73+
private static final Candidate CANDIDATE_3 =
74+
Candidate.newBuilder().setContent(CONTENT_WITH_FNCTION_CALL).build();
6475
private static final GenerateContentResponse RESPONSE_1 =
6576
GenerateContentResponse.newBuilder().addCandidates(CANDIDATE_1).build();
6677
private static final GenerateContentResponse RESPONSE_2 =
6778
GenerateContentResponse.newBuilder().addCandidates(CANDIDATE_2).build();
79+
private static final GenerateContentResponse RESPONSE_3 =
80+
GenerateContentResponse.newBuilder().addCandidates(CANDIDATE_3).build();
6881
private static final GenerateContentResponse INVALID_RESPONSE =
6982
GenerateContentResponse.newBuilder()
7083
.addCandidates(CANDIDATE_1)
@@ -94,6 +107,28 @@ public void testGetTextFromInvalidResponse() {
94107
INVALID_RESPONSE.getCandidatesCount()));
95108
}
96109

110+
@Test
111+
public void testGetFunctionCallsFromResponse() {
112+
ImmutableList<FunctionCall> functionCalls = ResponseHandler.getFunctionCalls(RESPONSE_3);
113+
assertThat(functionCalls.size()).isEqualTo(2);
114+
assertThat(functionCalls.get(0)).isEqualTo(FunctionCall.getDefaultInstance());
115+
assertThat(functionCalls.get(1)).isEqualTo(FunctionCall.getDefaultInstance());
116+
}
117+
118+
@Test
119+
public void testGetFunctionCallsFromInvalidResponse() {
120+
IllegalArgumentException thrown =
121+
assertThrows(
122+
IllegalArgumentException.class,
123+
() -> ResponseHandler.getFunctionCalls(INVALID_RESPONSE));
124+
assertThat(thrown)
125+
.hasMessageThat()
126+
.isEqualTo(
127+
String.format(
128+
"This response should have exactly 1 candidate, but it has %s.",
129+
INVALID_RESPONSE.getCandidatesCount()));
130+
}
131+
97132
@Test
98133
public void testGetContentFromResponse() {
99134
Content content = ResponseHandler.getContent(RESPONSE_1);

0 commit comments

Comments
 (0)