Skip to content

Commit acc864f

Browse files
committed
Refactored to constructorArg for mandatory args in GoogleVertexAiUnifiedStreamingProcessor
1 parent c05655f commit acc864f

File tree

2 files changed

+32
-15
lines changed

2 files changed

+32
-15
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedStreamingProcessor.java

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -192,17 +192,17 @@ private static StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice
192192

193193
static {
194194
PARSER.declareObjectArray(
195-
ConstructingObjectParser.optionalConstructorArg(),
195+
ConstructingObjectParser.constructorArg(),
196196
(p, c) -> CandidateParser.parse(p),
197197
new ParseField(CANDIDATES_FIELD)
198198
);
199199
PARSER.declareObject(
200-
ConstructingObjectParser.optionalConstructorArg(),
200+
ConstructingObjectParser.constructorArg(),
201201
(p, c) -> UsageMetadataParser.parse(p),
202202
new ParseField(USAGE_METADATA_FIELD)
203203
);
204-
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(MODEL_VERSION_FIELD));
205-
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(RESPONSE_ID_FIELD));
204+
PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField(MODEL_VERSION_FIELD));
205+
PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField(RESPONSE_ID_FIELD));
206206
}
207207

208208
public static StreamingUnifiedChatCompletionResults.ChatCompletionChunk parse(XContentParser parser) throws IOException {
@@ -224,7 +224,7 @@ private static class CandidateParser {
224224

225225
static {
226226
PARSER.declareObject(
227-
ConstructingObjectParser.optionalConstructorArg(),
227+
ConstructingObjectParser.constructorArg(),
228228
(p, c) -> ContentParser.parse(p),
229229
new ParseField(CONTENT_FIELD)
230230
);
@@ -248,9 +248,9 @@ private static class ContentParser {
248248
);
249249

250250
static {
251-
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(ROLE_FIELD));
251+
PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField(ROLE_FIELD));
252252
PARSER.declareObjectArray(
253-
ConstructingObjectParser.optionalConstructorArg(),
253+
ConstructingObjectParser.constructorArg(),
254254
(p, c) -> PartParser.parse(p),
255255
new ParseField(PARTS_FIELD)
256256
);

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedStreamingProcessorTests.java

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ public void testJsonLiteral() {
8181
}
8282
}
8383

84-
public void testJsonLiteral_optionalTopLevelFieldsMissing() {
84+
public void testJsonLiteral_usageMetadataTokenCountMissing() {
8585
String json = """
8686
{
8787
"candidates" : [ {
@@ -90,7 +90,12 @@ public void testJsonLiteral_optionalTopLevelFieldsMissing() {
9090
"parts" : [ { "text" : "Hello" } ]
9191
},
9292
"finishReason": "STOP"
93-
} ]
93+
} ],
94+
"usageMetadata" : {
95+
"trafficType" : "ON_DEMAND"
96+
},
97+
"modelVersion": "gemini-2.0-flash-001",
98+
"responseId": "responseId"
9499
}
95100
""";
96101

@@ -101,12 +106,11 @@ public void testJsonLiteral_optionalTopLevelFieldsMissing() {
101106
try (XContentParser parser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, json)) {
102107
var chunk = GoogleVertexAiUnifiedStreamingProcessor.GoogleVertexAiChatCompletionChunkParser.parse(parser);
103108

104-
assertNull(chunk.id());
109+
assertEquals("responseId", chunk.id());
105110
assertEquals(1, chunk.choices().size());
106111
var choice = chunk.choices().getFirst();
107112
assertEquals("Hello", choice.delta().content());
108113
assertEquals("model", choice.delta().role());
109-
assertNull(chunk.model());
110114
assertEquals("STOP", choice.finishReason());
111115
assertEquals(0, choice.index());
112116
assertNull(choice.delta().toolCalls());
@@ -131,7 +135,14 @@ public void testJsonLiteral_functionCallArgsMissing() {
131135
]
132136
}
133137
} ],
134-
"responseId" : "resId789"
138+
"responseId" : "resId789",
139+
"modelVersion": "gemini-2.0-flash-00",
140+
"usageMetadata" : {
141+
"promptTokenCount": 10,
142+
"candidatesTokenCount": 20,
143+
"totalTokenCount": 30,
144+
"trafficType" : "ON_DEMAND"
145+
}
135146
}
136147
""";
137148
XContentParserConfiguration parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(
@@ -171,7 +182,14 @@ public void testJsonLiteral_multipleTextParts() {
171182
},
172183
"finishReason": "STOP"
173184
} ],
174-
"responseId" : "multiTextId"
185+
"responseId" : "multiTextId",
186+
"usageMetadata" : {
187+
"promptTokenCount": 10,
188+
"candidatesTokenCount": 20,
189+
"totalTokenCount": 30,
190+
"trafficType" : "ON_DEMAND"
191+
},
192+
"modelVersion": "gemini-2.0-flash-001"
175193
}
176194
""";
177195

@@ -192,8 +210,7 @@ public void testJsonLiteral_multipleTextParts() {
192210
assertEquals("STOP", choice.finishReason());
193211
assertEquals(0, choice.index());
194212
assertNull(choice.delta().toolCalls());
195-
assertNull(chunk.model());
196-
assertNull(chunk.usage());
213+
assertEquals("gemini-2.0-flash-001", chunk.model());
197214
} catch (IOException e) {
198215
fail("IOException during test: " + e.getMessage());
199216
}

0 commit comments

Comments
 (0)