Skip to content

Commit eba5fce

Browse files
Custom service fixes
1 parent 14a5383 commit eba5fce

File tree

4 files changed

+50
-59
lines changed

4 files changed

+50
-59
lines changed

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,7 @@ static TransportVersion def(int id) {
174174
public static final TransportVersion RESCORE_VECTOR_ALLOW_ZERO_BACKPORT_8_19 = def(8_841_0_27);
175175
public static final TransportVersion INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT_8_19 = def(8_841_0_28);
176176
public static final TransportVersion ESQL_REPORT_SHARD_PARTITIONING_8_19 = def(8_841_0_29);
177+
public static final TransportVersion ADD_INFERENCE_CUSTOM_MODEL_8_19 = def(8_841_0_30);
177178
public static final TransportVersion V_9_0_0 = def(9_000_0_09);
178179
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10);
179180
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11);
@@ -252,7 +253,7 @@ static TransportVersion def(int id) {
252253
public static final TransportVersion FIELD_CAPS_ADD_CLUSTER_ALIAS = def(9_073_0_00);
253254
public static final TransportVersion INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT = def(9_074_0_00);
254255
public static final TransportVersion ESQL_FIELD_ATTRIBUTE_DROP_TYPE = def(9_075_0_00);
255-
256+
public static final TransportVersion ADD_INFERENCE_CUSTOM_MODEL = def(9_076_0_00);
256257
/*
257258
* STOP! READ THIS FIRST! No, really,
258259
* ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _

x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java

Lines changed: 19 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import org.elasticsearch.xcontent.ToXContent;
3535
import org.elasticsearch.xcontent.ToXContentObject;
3636
import org.elasticsearch.xcontent.XContentBuilder;
37+
import org.elasticsearch.xpack.core.inference.DequeUtils;
3738
import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
3839
import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults;
3940
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
@@ -256,37 +257,24 @@ public void cancel() {}
256257
"object": "chat.completion.chunk"
257258
}
258259
*/
259-
private InferenceServiceResults.Result unifiedCompletionChunk(String delta) {
260-
return new InferenceServiceResults.Result() {
261-
@Override
262-
public String getWriteableName() {
263-
return "test_unifiedCompletionChunk";
264-
}
265-
266-
@Override
267-
public void writeTo(StreamOutput out) throws IOException {
268-
out.writeString(delta);
269-
}
270-
271-
@Override
272-
public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
273-
return ChunkedToXContentHelper.chunk(
274-
(b, p) -> b.startObject()
275-
.field("id", "id")
276-
.startArray("choices")
277-
.startObject()
278-
.startObject("delta")
279-
.field("content", delta)
280-
.endObject()
281-
.field("index", 0)
282-
.endObject()
283-
.endArray()
284-
.field("model", "gpt-4o-2024-08-06")
285-
.field("object", "chat.completion.chunk")
286-
.endObject()
287-
);
288-
}
289-
};
260+
private StreamingUnifiedChatCompletionResults.Results unifiedCompletionChunk(String delta) {
261+
return new StreamingUnifiedChatCompletionResults.Results(
262+
DequeUtils.of(
263+
new StreamingUnifiedChatCompletionResults.ChatCompletionChunk(
264+
"id",
265+
List.of(
266+
new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice(
267+
new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta(delta, null, null, null),
268+
null,
269+
0
270+
)
271+
),
272+
"gpt-4o-2024-08-06",
273+
"chat.completion.chunk",
274+
null
275+
)
276+
)
277+
);
290278
}
291279

292280
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestPutInferenceModelAction.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import java.util.List;
2121

2222
import static org.elasticsearch.rest.RestRequest.Method.PUT;
23+
import static org.elasticsearch.xpack.inference.rest.BaseInferenceAction.parseTimeout;
2324
import static org.elasticsearch.xpack.inference.rest.Paths.INFERENCE_ID;
2425
import static org.elasticsearch.xpack.inference.rest.Paths.INFERENCE_ID_PATH;
2526
import static org.elasticsearch.xpack.inference.rest.Paths.TASK_TYPE_INFERENCE_ID_PATH;
@@ -49,8 +50,15 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient
4950
taskType = TaskType.ANY; // task type must be defined in the body
5051
}
5152

53+
var inferTimeout = parseTimeout(restRequest);
5254
var content = restRequest.requiredContent();
53-
var request = new PutInferenceModelAction.Request(taskType, inferenceEntityId, content, restRequest.getXContentType());
55+
var request = new PutInferenceModelAction.Request(
56+
taskType,
57+
inferenceEntityId,
58+
content,
59+
restRequest.getXContentType(),
60+
inferTimeout
61+
);
5462
return channel -> client.execute(
5563
PutInferenceModelAction.INSTANCE,
5664
request,

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/PutInferenceModelRequestTests.java

Lines changed: 20 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,16 @@
77

88
package org.elasticsearch.xpack.inference.action;
99

10+
import org.elasticsearch.TransportVersion;
11+
import org.elasticsearch.TransportVersions;
1012
import org.elasticsearch.common.io.stream.Writeable;
1113
import org.elasticsearch.inference.TaskType;
12-
import org.elasticsearch.test.AbstractWireSerializingTestCase;
1314
import org.elasticsearch.xcontent.XContentType;
15+
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
1416
import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction;
17+
import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
1518

16-
public class PutInferenceModelRequestTests extends AbstractWireSerializingTestCase<PutInferenceModelAction.Request> {
19+
public class PutInferenceModelRequestTests extends AbstractBWCWireSerializationTestCase<PutInferenceModelAction.Request> {
1720
@Override
1821
protected Writeable.Reader<PutInferenceModelAction.Request> instanceReader() {
1922
return PutInferenceModelAction.Request::new;
@@ -25,38 +28,29 @@ protected PutInferenceModelAction.Request createTestInstance() {
2528
randomFrom(TaskType.values()),
2629
randomAlphaOfLength(6),
2730
randomBytesReference(50),
28-
randomFrom(XContentType.values())
31+
randomFrom(XContentType.values()),
32+
randomTimeValue()
2933
);
3034
}
3135

3236
@Override
3337
protected PutInferenceModelAction.Request mutateInstance(PutInferenceModelAction.Request instance) {
34-
return switch (randomIntBetween(0, 3)) {
35-
case 0 -> new PutInferenceModelAction.Request(
36-
TaskType.values()[(instance.getTaskType().ordinal() + 1) % TaskType.values().length],
37-
instance.getInferenceEntityId(),
38-
instance.getContent(),
39-
instance.getContentType()
40-
);
41-
case 1 -> new PutInferenceModelAction.Request(
42-
instance.getTaskType(),
43-
instance.getInferenceEntityId() + "foo",
44-
instance.getContent(),
45-
instance.getContentType()
46-
);
47-
case 2 -> new PutInferenceModelAction.Request(
48-
instance.getTaskType(),
49-
instance.getInferenceEntityId(),
50-
randomBytesReference(instance.getContent().length() + 1),
51-
instance.getContentType()
52-
);
53-
case 3 -> new PutInferenceModelAction.Request(
38+
return randomValueOtherThan(instance, this::createTestInstance);
39+
}
40+
41+
@Override
42+
protected PutInferenceModelAction.Request mutateInstanceForVersion(PutInferenceModelAction.Request instance, TransportVersion version) {
43+
if (version.onOrAfter(TransportVersions.INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT)
44+
|| version.isPatchFrom(TransportVersions.INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT_8_19)) {
45+
return instance;
46+
} else {
47+
return new PutInferenceModelAction.Request(
5448
instance.getTaskType(),
5549
instance.getInferenceEntityId(),
5650
instance.getContent(),
57-
XContentType.values()[(instance.getContentType().ordinal() + 1) % XContentType.values().length]
51+
instance.getContentType(),
52+
InferenceAction.Request.DEFAULT_TIMEOUT
5853
);
59-
default -> throw new IllegalStateException();
60-
};
54+
}
6155
}
6256
}

0 commit comments

Comments
 (0)