Skip to content

Commit 6cf0c0b

Browse files
committed
Added unit tests
1 parent ce6d45f commit 6cf0c0b

File tree

2 files changed

+131
-0
lines changed

2 files changed

+131
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.googlevertexai;
9+
10+
import org.elasticsearch.common.bytes.BytesReference;
11+
import org.elasticsearch.common.xcontent.ChunkedToXContent;
12+
import org.elasticsearch.common.xcontent.XContentHelper;
13+
import org.elasticsearch.core.Strings;
14+
import org.elasticsearch.test.ESTestCase;
15+
import org.elasticsearch.xcontent.XContentFactory;
16+
import org.elasticsearch.xcontent.XContentParseException;
17+
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent;
18+
19+
import java.io.IOException;
20+
import java.util.ArrayDeque;
21+
22+
import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS;
23+
import static org.elasticsearch.xpack.inference.common.DelegatingProcessorTests.onError;
24+
import static org.elasticsearch.xpack.inference.common.DelegatingProcessorTests.onNext;
25+
import static org.hamcrest.Matchers.equalTo;
26+
import static org.hamcrest.Matchers.instanceOf;
27+
28+
public class GoogleVertexAiStreamingProcessorTests extends ESTestCase {
29+
30+
public void testParseVertexAiResponse() throws IOException {
31+
var item = new ArrayDeque<ServerSentEvent>();
32+
item.offer(new ServerSentEvent(vertexAiJsonResponse("test", true)));
33+
34+
var response = onNext(new GoogleVertexAiStreamingProcessor(), item);
35+
var json = toJsonString(response);
36+
37+
assertThat(json, equalTo("""
38+
{"completion":[{"delta":"test"}]}"""));
39+
}
40+
41+
public void testParseVertexAiResponseMultiple() throws IOException {
42+
var item = new ArrayDeque<ServerSentEvent>();
43+
item.offer(new ServerSentEvent(vertexAiJsonResponse("hello", false)));
44+
45+
item.offer(new ServerSentEvent(vertexAiJsonResponse("world", true)));
46+
47+
var response = onNext(new GoogleVertexAiStreamingProcessor(), item);
48+
var json = toJsonString(response);
49+
50+
assertThat(json, equalTo("""
51+
{"completion":[{"delta":"hello"},{"delta":"world"}]}"""));
52+
}
53+
54+
public void testParseErrorCallsOnError() {
55+
var item = new ArrayDeque<ServerSentEvent>();
56+
item.offer(new ServerSentEvent("not json"));
57+
58+
var exception = onError(new GoogleVertexAiStreamingProcessor(), item);
59+
assertThat(exception, instanceOf(XContentParseException.class));
60+
}
61+
62+
private String vertexAiJsonResponse(String content, boolean includeFinishReason) {
63+
String finishReason = includeFinishReason ? "\"finishReason\": \"STOP\"," : "";
64+
65+
return Strings.format("""
66+
{
67+
"candidates": [
68+
{
69+
"content": {
70+
"role": "model",
71+
"parts": [
72+
{
73+
"text": "%s"
74+
}
75+
]
76+
},
77+
%s
78+
"avgLogprobs": -0.19326641248620074
79+
}
80+
],
81+
"usageMetadata": {
82+
"promptTokenCount": 71,
83+
"candidatesTokenCount": 23,
84+
"totalTokenCount": 94,
85+
"trafficType": "ON_DEMAND",
86+
"promptTokensDetails": [
87+
{
88+
"modality": "TEXT",
89+
"tokenCount": 71
90+
}
91+
],
92+
"candidatesTokensDetails": [
93+
{
94+
"modality": "TEXT",
95+
"tokenCount": 23
96+
}
97+
]
98+
},
99+
"modelVersion": "gemini-2.0-flash-001",
100+
"createTime": "2025-05-28T15:08:20.049493Z",
101+
"responseId": "5CY3aNWCA6mm4_UPr-zduAE"
102+
}
103+
""", content, finishReason);
104+
}
105+
106+
private String toJsonString(ChunkedToXContent chunkedToXContent) throws IOException {
107+
try (var builder = XContentFactory.jsonBuilder()) {
108+
chunkedToXContent.toXContentChunked(EMPTY_PARAMS).forEachRemaining(xContent -> {
109+
try {
110+
xContent.toXContent(builder, EMPTY_PARAMS);
111+
} catch (IOException e) {
112+
logger.error(e.getMessage(), e);
113+
fail(e.getMessage());
114+
}
115+
});
116+
return XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType());
117+
}
118+
}
119+
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiCompletionModelTests.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717
import java.util.HashMap;
1818
import java.util.Map;
1919

20+
import static org.elasticsearch.xpack.inference.services.googlevertexai.request.GoogleVertexAiUtils.GENERATE_CONTENT;
21+
import static org.elasticsearch.xpack.inference.services.googlevertexai.request.GoogleVertexAiUtils.STREAM_GENERATE_CONTENT;
22+
import static org.hamcrest.Matchers.containsString;
2023
import static org.hamcrest.Matchers.is;
2124

2225
public class GoogleVertexAiCompletionModelTests extends ESTestCase {
@@ -39,6 +42,15 @@ public void testCreateModel() throws URISyntaxException {
3942
assertThat(model.uri(), is(expectedUri));
4043
}
4144

45+
public void testUpdateUri() throws URISyntaxException {
46+
var model = createCompletionModel(DEFAULT_PROJECT_ID, DEFAULT_LOCATION, DEFAULT_MODEL_ID);
47+
assertThat(model.uri().toString(), containsString(GENERATE_CONTENT));
48+
model.updateUri(true);
49+
assertThat(model.uri().toString(), containsString(STREAM_GENERATE_CONTENT));
50+
model.updateUri(false);
51+
assertThat(model.uri().toString(), containsString(GENERATE_CONTENT));
52+
}
53+
4254
private static GoogleVertexAiCompletionModel createCompletionModel(String projectId, String location, String modelId) {
4355
return new GoogleVertexAiCompletionModel(
4456
"google-vertex-ai-chat-test-id",

0 commit comments

Comments
 (0)