Skip to content

Commit cb440e1

Browse files
Fixing parsing logic
1 parent 81a05b7 commit cb440e1

File tree

14 files changed

+84
-76
lines changed

14 files changed

+84
-76
lines changed

server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
package org.elasticsearch.inference;
99

1010
import org.elasticsearch.common.io.stream.NamedWriteable;
11+
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
1112
import org.elasticsearch.common.io.stream.StreamInput;
1213
import org.elasticsearch.common.io.stream.StreamOutput;
1314
import org.elasticsearch.common.io.stream.Writeable;
@@ -52,11 +53,11 @@ public sealed interface Content extends NamedWriteable permits ContentObjects, C
5253
(Long) args[2],
5354
(Integer) args[3],
5455
(Stop) args[4],
55-
(Float) args[6],
56-
(ToolChoice) args[7],
57-
(List<Tool>) args[8],
58-
(Float) args[9],
59-
(String) args[10]
56+
(Float) args[5],
57+
(ToolChoice) args[6],
58+
(List<Tool>) args[7],
59+
(Float) args[8],
60+
(String) args[9]
6061
)
6162
);
6263

@@ -78,6 +79,17 @@ public sealed interface Content extends NamedWriteable permits ContentObjects, C
7879
PARSER.declareString(optionalConstructorArg(), new ParseField("user"));
7980
}
8081

82+
public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
83+
return List.of(
84+
new NamedWriteableRegistry.Entry(Content.class, ContentObjects.NAME, ContentObjects::new),
85+
new NamedWriteableRegistry.Entry(Content.class, ContentString.NAME, ContentString::new),
86+
new NamedWriteableRegistry.Entry(ToolChoice.class, ToolChoiceObject.NAME, ToolChoiceObject::new),
87+
new NamedWriteableRegistry.Entry(ToolChoice.class, ToolChoiceString.NAME, ToolChoiceString::new),
88+
new NamedWriteableRegistry.Entry(Stop.class, StopValues.NAME, StopValues::new),
89+
new NamedWriteableRegistry.Entry(Stop.class, StopString.NAME, StopString::new)
90+
);
91+
}
92+
8193
public UnifiedCompletionRequest(StreamInput in) throws IOException {
8294
this(
8395
in.readCollectionAsImmutableList(Message::new),
@@ -157,7 +169,7 @@ public void writeTo(StreamOutput out) throws IOException {
157169
}
158170
}
159171

160-
public record ContentObjects(List<ContentObject> contentObjects) implements Content, Writeable {
172+
public record ContentObjects(List<ContentObject> contentObjects) implements Content, NamedWriteable {
161173

162174
public static final String NAME = "content_objects";
163175

test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1213,10 +1213,18 @@ public static Long randomNullOrLong() {
12131213
return randomBoolean() ? null : randomLong();
12141214
}
12151215

1216+
public static Long randomNullOrPositiveLong() {
1217+
return randomBoolean() ? null : randomLongBetween(0L, Long.MAX_VALUE);
1218+
}
1219+
12161220
public static Integer randomNullOrInt() {
12171221
return randomBoolean() ? null : randomInt();
12181222
}
12191223

1224+
public static Integer randomNullOrPositiveInt() {
1225+
return randomBoolean() ? null : randomIntBetween(0, Integer.MAX_VALUE);
1226+
}
1227+
12201228
public static Float randomNullOrFloat() {
12211229
return randomBoolean() ? null : randomFloat();
12221230
}

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@ protected InferenceAction.Request createTestInstance() {
4747
randomMap(0, 3, () -> new Tuple<>(randomAlphaOfLength(4), randomAlphaOfLength(4))),
4848
randomFrom(InputType.values()),
4949
TimeValue.timeValueMillis(randomLongBetween(1, 2048)),
50-
false,
5150
false
5251
);
5352
}
@@ -83,7 +82,6 @@ public void testValidation_TextEmbedding() {
8382
null,
8483
null,
8584
null,
86-
false,
8785
false
8886
);
8987
ActionRequestValidationException e = request.validate();
@@ -99,7 +97,6 @@ public void testValidation_Rerank() {
9997
null,
10098
null,
10199
null,
102-
false,
103100
false
104101
);
105102
ActionRequestValidationException e = request.validate();
@@ -115,7 +112,6 @@ public void testValidation_TextEmbedding_Null() {
115112
null,
116113
null,
117114
null,
118-
false,
119115
false
120116
);
121117
ActionRequestValidationException inputNullError = inputNullRequest.validate();
@@ -132,7 +128,6 @@ public void testValidation_TextEmbedding_Empty() {
132128
null,
133129
null,
134130
null,
135-
false,
136131
false
137132
);
138133
ActionRequestValidationException inputEmptyError = inputEmptyRequest.validate();
@@ -149,7 +144,6 @@ public void testValidation_Rerank_Null() {
149144
null,
150145
null,
151146
null,
152-
false,
153147
false
154148
);
155149
ActionRequestValidationException queryNullError = queryNullRequest.validate();
@@ -166,7 +160,6 @@ public void testValidation_Rerank_Empty() {
166160
null,
167161
null,
168162
null,
169-
false,
170163
false
171164
);
172165
ActionRequestValidationException queryEmptyError = queryEmptyRequest.validate();
@@ -200,7 +193,6 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc
200193
instance.getTaskSettings(),
201194
instance.getInputType(),
202195
instance.getInferenceTimeout(),
203-
false,
204196
false
205197
);
206198
}
@@ -212,7 +204,6 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc
212204
instance.getTaskSettings(),
213205
instance.getInputType(),
214206
instance.getInferenceTimeout(),
215-
false,
216207
false
217208
);
218209
case 2 -> {
@@ -226,7 +217,6 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc
226217
instance.getTaskSettings(),
227218
instance.getInputType(),
228219
instance.getInferenceTimeout(),
229-
false,
230220
false
231221
);
232222
}
@@ -246,7 +236,6 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc
246236
taskSettings,
247237
instance.getInputType(),
248238
instance.getInferenceTimeout(),
249-
false,
250239
false
251240
);
252241
}
@@ -260,7 +249,6 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc
260249
instance.getTaskSettings(),
261250
nextInputType,
262251
instance.getInferenceTimeout(),
263-
false,
264252
false
265253
);
266254
}
@@ -272,7 +260,6 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc
272260
instance.getTaskSettings(),
273261
instance.getInputType(),
274262
instance.getInferenceTimeout(),
275-
false,
276263
false
277264
);
278265
case 6 -> {
@@ -289,7 +276,6 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc
289276
instance.getTaskSettings(),
290277
instance.getInputType(),
291278
TimeValue.timeValueMillis(newDuration.plus(additionalTime).toMillis()),
292-
false,
293279
false
294280
);
295281
}
@@ -308,7 +294,6 @@ protected InferenceAction.Request mutateInstanceForVersion(InferenceAction.Reque
308294
instance.getTaskSettings(),
309295
InputType.UNSPECIFIED,
310296
InferenceAction.Request.DEFAULT_TIMEOUT,
311-
false,
312297
false
313298
);
314299
} else if (version.before(TransportVersions.V_8_13_0)) {
@@ -320,7 +305,6 @@ protected InferenceAction.Request mutateInstanceForVersion(InferenceAction.Reque
320305
instance.getTaskSettings(),
321306
InputType.UNSPECIFIED,
322307
InferenceAction.Request.DEFAULT_TIMEOUT,
323-
false,
324308
false
325309
);
326310
} else if (version.before(TransportVersions.V_8_13_0)
@@ -335,7 +319,6 @@ protected InferenceAction.Request mutateInstanceForVersion(InferenceAction.Reque
335319
instance.getTaskSettings(),
336320
InputType.INGEST,
337321
InferenceAction.Request.DEFAULT_TIMEOUT,
338-
false,
339322
false
340323
);
341324
} else if (version.before(TransportVersions.V_8_13_0)
@@ -348,7 +331,6 @@ protected InferenceAction.Request mutateInstanceForVersion(InferenceAction.Reque
348331
instance.getTaskSettings(),
349332
InputType.UNSPECIFIED,
350333
InferenceAction.Request.DEFAULT_TIMEOUT,
351-
false,
352334
false
353335
);
354336
} else if (version.before(TransportVersions.V_8_14_0)) {
@@ -360,7 +342,6 @@ protected InferenceAction.Request mutateInstanceForVersion(InferenceAction.Reque
360342
instance.getTaskSettings(),
361343
instance.getInputType(),
362344
InferenceAction.Request.DEFAULT_TIMEOUT,
363-
false,
364345
false
365346
);
366347
}
@@ -378,7 +359,6 @@ public void testWriteTo_WhenVersionIsOnAfterUnspecifiedAdded() throws IOExceptio
378359
Map.of(),
379360
InputType.UNSPECIFIED,
380361
InferenceAction.Request.DEFAULT_TIMEOUT,
381-
false,
382362
false
383363
),
384364
TransportVersions.V_8_13_0
@@ -394,7 +374,6 @@ public void testWriteTo_WhenVersionIsBeforeInputTypeAdded_ShouldSetInputTypeToUn
394374
Map.of(),
395375
InputType.INGEST,
396376
InferenceAction.Request.DEFAULT_TIMEOUT,
397-
false,
398377
false
399378
);
400379

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequestTests.java

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
package org.elasticsearch.xpack.core.inference.action;
99

1010
import org.elasticsearch.TransportVersion;
11+
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
1112
import org.elasticsearch.common.io.stream.Writeable;
1213
import org.elasticsearch.inference.UnifiedCompletionRequest;
1314
import org.elasticsearch.xcontent.json.JsonXContent;
@@ -51,7 +52,6 @@ public void testParseAllFields() throws IOException {
5152
"max_completion_tokens": 100,
5253
"n": 1,
5354
"stop": ["stop"],
54-
"stream": true,
5555
"temperature": 0.1,
5656
"tools": [
5757
{
@@ -192,8 +192,8 @@ public static UnifiedCompletionRequest randomUnifiedCompletionRequest() {
192192
return new UnifiedCompletionRequest(
193193
randomList(5, UnifiedCompletionRequestTests::randomMessage),
194194
randomNullOrAlphaOfLength(10),
195-
randomNullOrLong(),
196-
randomNullOrInt(),
195+
randomNullOrPositiveLong(),
196+
randomNullOrPositiveInt(),
197197
randomNullOrStop(),
198198
randomNullOrFloat(),
199199
randomNullOrToolChoice(),
@@ -287,4 +287,12 @@ protected UnifiedCompletionRequest createTestInstance() {
287287
protected UnifiedCompletionRequest mutateInstance(UnifiedCompletionRequest instance) throws IOException {
288288
return randomValueOtherThan(instance, this::createTestInstance);
289289
}
290+
291+
@Override
292+
protected NamedWriteableRegistry getNamedWriteableRegistry() {
293+
// List<NamedWriteableRegistry.Entry> entries = new ArrayList<>();
294+
// entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables());
295+
// entries.addAll(InferenceNamedWriteablesProvider.getNamedWriteables());
296+
return new NamedWriteableRegistry(UnifiedCompletionRequest.getNamedWriteables());
297+
}
290298
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@
8383
import org.elasticsearch.xpack.inference.rest.RestInferenceAction;
8484
import org.elasticsearch.xpack.inference.rest.RestPutInferenceModelAction;
8585
import org.elasticsearch.xpack.inference.rest.RestStreamInferenceAction;
86+
import org.elasticsearch.xpack.inference.rest.RestUnifiedCompletionInferenceAction;
8687
import org.elasticsearch.xpack.inference.rest.RestUpdateInferenceModelAction;
8788
import org.elasticsearch.xpack.inference.services.ServiceComponents;
8889
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchService;
@@ -152,9 +153,9 @@ public InferencePlugin(Settings settings) {
152153

153154
@Override
154155
public List<ActionHandler<? extends ActionRequest, ? extends ActionResponse>> getActions() {
155-
return List.of(
156+
var availableActions = List.of(
156157
new ActionHandler<>(InferenceAction.INSTANCE, TransportInferenceAction.class),
157-
new ActionHandler<>(UnifiedCompletionAction.INSTANCE, TransportUnifiedCompletionInferenceAction.class),
158+
158159
new ActionHandler<>(GetInferenceModelAction.INSTANCE, TransportGetInferenceModelAction.class),
159160
new ActionHandler<>(PutInferenceModelAction.INSTANCE, TransportPutInferenceModelAction.class),
160161
new ActionHandler<>(UpdateInferenceModelAction.INSTANCE, TransportUpdateInferenceModelAction.class),
@@ -163,6 +164,13 @@ public InferencePlugin(Settings settings) {
163164
new ActionHandler<>(GetInferenceDiagnosticsAction.INSTANCE, TransportGetInferenceDiagnosticsAction.class),
164165
new ActionHandler<>(GetInferenceServicesAction.INSTANCE, TransportGetInferenceServicesAction.class)
165166
);
167+
168+
List<ActionHandler<? extends ActionRequest, ? extends ActionResponse>> conditionalActions =
169+
UnifiedCompletionFeature.UNIFIED_COMPLETION_FEATURE_FLAG.isEnabled()
170+
? List.of(new ActionHandler<>(UnifiedCompletionAction.INSTANCE, TransportUnifiedCompletionInferenceAction.class))
171+
: List.of();
172+
173+
return Stream.concat(availableActions.stream(), conditionalActions.stream()).toList();
166174
}
167175

168176
@Override
@@ -177,7 +185,7 @@ public List<RestHandler> getRestHandlers(
177185
Supplier<DiscoveryNodes> nodesInCluster,
178186
Predicate<NodeFeature> clusterSupportsFeature
179187
) {
180-
return List.of(
188+
var availableRestActions = List.of(
181189
new RestInferenceAction(),
182190
new RestStreamInferenceAction(),
183191
new RestGetInferenceModelAction(),
@@ -187,6 +195,11 @@ public List<RestHandler> getRestHandlers(
187195
new RestGetInferenceDiagnosticsAction(),
188196
new RestGetInferenceServicesAction()
189197
);
198+
List<RestHandler> conditionalRestActions = UnifiedCompletionFeature.UNIFIED_COMPLETION_FEATURE_FLAG.isEnabled()
199+
? List.of(new RestUnifiedCompletionInferenceAction())
200+
: List.of();
201+
202+
return Stream.concat(availableRestActions.stream(), conditionalRestActions.stream()).toList();
190203
}
191204

192205
@Override
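x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/UnifiedCompletionFeature.java

Lines changed: 3 additions & 3 deletions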
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,16 @@
55
* 2.0.
66
*/
77

8-
package org.elasticsearch.xpack.inference.rest;
8+
package org.elasticsearch.xpack.inference;
99

1010
import org.elasticsearch.common.util.FeatureFlag;
1111

1212
/**
1313
* Unified Completion feature flag. When the feature is complete, this flag will be removed.
14-
* Enable feature via JVM option: `-Des.unified_feature_flag_enabled=true`.
14+
* Enable feature via JVM option: `-Des.inference_unified_feature_flag_enabled=true`.
1515
*/
1616
public class UnifiedCompletionFeature {
17-
public static final FeatureFlag UNIFIED_COMPLETION_FEATURE_FLAG = new FeatureFlag("unified");
17+
public static final FeatureFlag UNIFIED_COMPLETION_FEATURE_FLAG = new FeatureFlag("inference_unified");
1818

1919
private UnifiedCompletionFeature() {}
2020
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/UnifiedRequest.java

Lines changed: 0 additions & 27 deletions
This file was deleted.

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiChatCompletionRequestEntity.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import java.util.List;
1616
import java.util.Objects;
1717

18+
// TODO remove this
1819
public class OpenAiChatCompletionRequestEntity implements ToXContentObject {
1920

2021
private static final String MESSAGES_FIELD = "messages";

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ public HttpRequest createHttpRequest() {
4444
HttpPost httpPost = new HttpPost(account.uri());
4545

4646
ByteArrayEntity byteEntity = new ByteArrayEntity(
47-
Strings.toString(new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput)).getBytes(StandardCharsets.UTF_8)
47+
Strings.toString(new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model)).getBytes(StandardCharsets.UTF_8)
4848
);
4949
httpPost.setEntity(byteEntity);
5050

0 commit comments

Comments
 (0)