Skip to content

Commit 99d202f

Browse files
Refactoring stop to be a list of strings
1 parent 10ac1ae commit 99d202f

File tree

3 files changed

+70
-73
lines changed

3 files changed

+70
-73
lines changed

server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java

Lines changed: 64 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ public record UnifiedCompletionRequest(
3333
List<Message> messages,
3434
@Nullable String model,
3535
@Nullable Long maxCompletionTokens,
36-
@Nullable Stop stop,
36+
@Nullable List<String> stop,
3737
@Nullable Float temperature,
3838
@Nullable ToolChoice toolChoice,
3939
@Nullable List<Tool> tools,
@@ -49,7 +49,7 @@ public sealed interface Content extends NamedWriteable permits ContentObjects, C
4949
(List<Message>) args[0],
5050
(String) args[1],
5151
(Long) args[2],
52-
(Stop) args[3],
52+
(List<String>) args[3],
5353
(Float) args[4],
5454
(ToolChoice) args[5],
5555
(List<Tool>) args[6],
@@ -61,7 +61,9 @@ public sealed interface Content extends NamedWriteable permits ContentObjects, C
6161
PARSER.declareObjectArray(constructorArg(), Message.PARSER::apply, new ParseField("messages"));
6262
PARSER.declareString(optionalConstructorArg(), new ParseField("model"));
6363
PARSER.declareLong(optionalConstructorArg(), new ParseField("max_completion_tokens"));
64-
PARSER.declareField(optionalConstructorArg(), (p, c) -> parseStop(p), new ParseField("stop"), ObjectParser.ValueType.VALUE_ARRAY);
64+
// PARSER.declareField(optionalConstructorArg(), (p, c) -> parseStop(p), new ParseField("stop"),
65+
// ObjectParser.ValueType.VALUE_ARRAY);
66+
PARSER.declareStringArray(optionalConstructorArg(), new ParseField("stop"));
6567
PARSER.declareFloat(optionalConstructorArg(), new ParseField("temperature"));
6668
PARSER.declareField(
6769
optionalConstructorArg(),
@@ -78,9 +80,9 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
7880
new NamedWriteableRegistry.Entry(Content.class, ContentObjects.NAME, ContentObjects::new),
7981
new NamedWriteableRegistry.Entry(Content.class, ContentString.NAME, ContentString::new),
8082
new NamedWriteableRegistry.Entry(ToolChoice.class, ToolChoiceObject.NAME, ToolChoiceObject::new),
81-
new NamedWriteableRegistry.Entry(ToolChoice.class, ToolChoiceString.NAME, ToolChoiceString::new),
82-
new NamedWriteableRegistry.Entry(Stop.class, StopValues.NAME, StopValues::new),
83-
new NamedWriteableRegistry.Entry(Stop.class, StopString.NAME, StopString::new)
83+
new NamedWriteableRegistry.Entry(ToolChoice.class, ToolChoiceString.NAME, ToolChoiceString::new)
84+
// new NamedWriteableRegistry.Entry(Stop.class, StopValues.NAME, StopValues::new),
85+
// new NamedWriteableRegistry.Entry(Stop.class, StopString.NAME, StopString::new)
8486
);
8587
}
8688

@@ -93,7 +95,7 @@ public UnifiedCompletionRequest(StreamInput in) throws IOException {
9395
in.readCollectionAsImmutableList(Message::new),
9496
in.readOptionalString(),
9597
in.readOptionalVLong(),
96-
in.readOptionalNamedWriteable(Stop.class),
98+
in.readOptionalStringCollectionAsList(),
9799
in.readOptionalFloat(),
98100
in.readOptionalNamedWriteable(ToolChoice.class),
99101
in.readOptionalCollectionAsList(Tool::new),
@@ -106,7 +108,7 @@ public void writeTo(StreamOutput out) throws IOException {
106108
out.writeCollection(messages);
107109
out.writeOptionalString(model);
108110
out.writeOptionalVLong(maxCompletionTokens);
109-
out.writeOptionalNamedWriteable(stop);
111+
out.writeOptionalStringCollection(stop);
110112
out.writeOptionalFloat(temperature);
111113
out.writeOptionalNamedWriteable(toolChoice);
112114
out.writeOptionalCollection(tools);
@@ -279,60 +281,60 @@ public void writeTo(StreamOutput out) throws IOException {
279281
}
280282
}
281283

282-
private static Stop parseStop(XContentParser parser) throws IOException {
283-
var token = parser.currentToken();
284-
if (token == XContentParser.Token.START_ARRAY) {
285-
var parsedStopValues = XContentParserUtils.parseList(parser, XContentParser::text);
286-
return new StopValues(parsedStopValues);
287-
} else if (token == XContentParser.Token.VALUE_STRING) {
288-
return StopString.of(parser);
289-
}
290-
291-
throw new XContentParseException("Unsupported token [" + token + "]");
292-
}
293-
294-
public sealed interface Stop extends NamedWriteable permits StopString, StopValues {}
295-
296-
public record StopString(String value) implements Stop, NamedWriteable {
297-
public static final String NAME = "stop_string";
298-
299-
public static StopString of(XContentParser parser) throws IOException {
300-
var content = parser.text();
301-
return new StopString(content);
302-
}
303-
304-
public StopString(StreamInput in) throws IOException {
305-
this(in.readString());
306-
}
307-
308-
@Override
309-
public void writeTo(StreamOutput out) throws IOException {
310-
out.writeString(value);
311-
}
312-
313-
@Override
314-
public String getWriteableName() {
315-
return NAME;
316-
}
317-
}
318-
319-
public record StopValues(List<String> values) implements Stop, NamedWriteable {
320-
public static final String NAME = "stop_values";
321-
322-
public StopValues(StreamInput in) throws IOException {
323-
this(in.readStringCollectionAsImmutableList());
324-
}
325-
326-
@Override
327-
public void writeTo(StreamOutput out) throws IOException {
328-
out.writeStringCollection(values);
329-
}
330-
331-
@Override
332-
public String getWriteableName() {
333-
return NAME;
334-
}
335-
}
284+
// private static Stop parseStop(XContentParser parser) throws IOException {
285+
// var token = parser.currentToken();
286+
// if (token == XContentParser.Token.START_ARRAY) {
287+
// var parsedStopValues = XContentParserUtils.parseList(parser, XContentParser::text);
288+
// return new StopValues(parsedStopValues);
289+
// } else if (token == XContentParser.Token.VALUE_STRING) {
290+
// return StopString.of(parser);
291+
// }
292+
//
293+
// throw new XContentParseException("Unsupported token [" + token + "]");
294+
// }
295+
296+
// public sealed interface Stop extends NamedWriteable permits StopString, StopValues {}
297+
//
298+
// public record StopString(String value) implements Stop, NamedWriteable {
299+
// public static final String NAME = "stop_string";
300+
//
301+
// public static StopString of(XContentParser parser) throws IOException {
302+
// var content = parser.text();
303+
// return new StopString(content);
304+
// }
305+
//
306+
// public StopString(StreamInput in) throws IOException {
307+
// this(in.readString());
308+
// }
309+
//
310+
// @Override
311+
// public void writeTo(StreamOutput out) throws IOException {
312+
// out.writeString(value);
313+
// }
314+
//
315+
// @Override
316+
// public String getWriteableName() {
317+
// return NAME;
318+
// }
319+
// }
320+
//
321+
// public record StopValues(List<String> values) implements Stop, NamedWriteable {
322+
// public static final String NAME = "stop_values";
323+
//
324+
// public StopValues(StreamInput in) throws IOException {
325+
// this(in.readStringCollectionAsImmutableList());
326+
// }
327+
//
328+
// @Override
329+
// public void writeTo(StreamOutput out) throws IOException {
330+
// out.writeStringCollection(values);
331+
// }
332+
//
333+
// @Override
334+
// public String getWriteableName() {
335+
// return NAME;
336+
// }
337+
// }
336338

337339
private static ToolChoice parseToolChoice(XContentParser parser) throws IOException {
338340
var token = parser.currentToken();

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequestTests.java

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ public void testParseAllFields() throws IOException {
9696
),
9797
"gpt-4o",
9898
100L,
99-
new UnifiedCompletionRequest.StopValues(List.of("stop")),
99+
List.of("stop"),
100100
0.1F,
101101
new UnifiedCompletionRequest.ToolChoiceObject(
102102
"function",
@@ -161,7 +161,7 @@ public void testParsing() throws IOException {
161161
),
162162
"gpt-4o",
163163
null,
164-
new UnifiedCompletionRequest.StopString("none"),
164+
List.of("none"),
165165
null,
166166
new UnifiedCompletionRequest.ToolChoiceString("auto"),
167167
List.of(
@@ -227,14 +227,12 @@ public static UnifiedCompletionRequest.ToolCall.FunctionField randomToolCallFunc
227227
return new UnifiedCompletionRequest.ToolCall.FunctionField(randomAlphaOfLength(10), randomAlphaOfLength(10));
228228
}
229229

230-
public static UnifiedCompletionRequest.Stop randomStopOrNull() {
230+
public static List<String> randomStopOrNull() {
231231
return randomBoolean() ? randomStop() : null;
232232
}
233233

234-
public static UnifiedCompletionRequest.Stop randomStop() {
235-
return randomBoolean()
236-
? new UnifiedCompletionRequest.StopString(randomAlphaOfLength(10))
237-
: new UnifiedCompletionRequest.StopValues(randomList(5, () -> randomAlphaOfLength(10)));
234+
public static List<String> randomStop() {
235+
return randomList(5, () -> randomAlphaOfLength(10));
238236
}
239237

240238
public static UnifiedCompletionRequest.ToolChoice randomToolChoiceOrNull() {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -115,10 +115,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
115115
builder.field(NUMBER_OF_RETURNED_CHOICES_FIELD, 1);
116116

117117
if (unifiedRequest.stop() != null) {
118-
switch (unifiedRequest.stop()) {
119-
case UnifiedCompletionRequest.StopString stopString -> builder.field(STOP_FIELD, stopString.value());
120-
case UnifiedCompletionRequest.StopValues stopValues -> builder.field(STOP_FIELD, stopValues.values());
121-
}
118+
builder.field(STOP_FIELD, unifiedRequest.stop());
122119
}
123120
if (unifiedRequest.temperature() != null) {
124121
builder.field(TEMPERATURE_FIELD, unifiedRequest.temperature());

0 commit comments

Comments
 (0)