Skip to content

Commit 35cd749

Browse files
jonathan-buttnerafoucret
authored andcommitted
[ML] Fixing bug with TransportPutModelAction listener and adding timeout to request (elastic#126805)
* Fixing bug with listener and adding timeout * Update docs/changelog/126805.yaml * Fixing tests * Fixing writeTo
1 parent 0bd5633 commit 35cd749

File tree

7 files changed

+84
-35
lines changed

7 files changed

+84
-35
lines changed

docs/changelog/126805.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 126805
2+
summary: Adding timeout to request for creating inference endpoint
3+
area: Machine Learning
4+
type: bug
5+
issues: []

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,7 @@ static TransportVersion def(int id) {
172172
public static final TransportVersion INTRODUCE_FAILURES_LIFECYCLE_BACKPORT_8_19 = def(8_841_0_25);
173173
public static final TransportVersion INTRODUCE_FAILURES_DEFAULT_RETENTION_BACKPORT_8_19 = def(8_841_0_26);
174174
public static final TransportVersion RESCORE_VECTOR_ALLOW_ZERO_BACKPORT_8_19 = def(8_841_0_27);
175+
public static final TransportVersion INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT_8_19 = def(8_841_0_28);
175176
public static final TransportVersion V_9_0_0 = def(9_000_0_09);
176177
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10);
177178
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11);
@@ -248,6 +249,7 @@ static TransportVersion def(int id) {
248249
public static final TransportVersion INTRODUCE_FAILURES_DEFAULT_RETENTION = def(9_071_0_00);
249250
public static final TransportVersion FILE_SETTINGS_HEALTH_INFO = def(9_072_0_00);
250251
public static final TransportVersion FIELD_CAPS_ADD_CLUSTER_ALIAS = def(9_073_0_00);
252+
public static final TransportVersion INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT = def(9_074_00_0);
251253

252254
/*
253255
* STOP! READ THIS FIRST! No, really,

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/PutInferenceModelAction.java

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
package org.elasticsearch.xpack.core.inference.action;
99

10+
import org.elasticsearch.TransportVersions;
1011
import org.elasticsearch.action.ActionRequestValidationException;
1112
import org.elasticsearch.action.ActionResponse;
1213
import org.elasticsearch.action.ActionType;
@@ -15,6 +16,7 @@
1516
import org.elasticsearch.common.io.stream.StreamInput;
1617
import org.elasticsearch.common.io.stream.StreamOutput;
1718
import org.elasticsearch.common.xcontent.XContentHelper;
19+
import org.elasticsearch.core.TimeValue;
1820
import org.elasticsearch.inference.ModelConfigurations;
1921
import org.elasticsearch.inference.TaskType;
2022
import org.elasticsearch.xcontent.ToXContentObject;
@@ -41,13 +43,15 @@ public static class Request extends AcknowledgedRequest<Request> {
4143
private final String inferenceEntityId;
4244
private final BytesReference content;
4345
private final XContentType contentType;
46+
private final TimeValue timeout;
4447

45-
public Request(TaskType taskType, String inferenceEntityId, BytesReference content, XContentType contentType) {
48+
public Request(TaskType taskType, String inferenceEntityId, BytesReference content, XContentType contentType, TimeValue timeout) {
4649
super(TRAPPY_IMPLICIT_DEFAULT_MASTER_NODE_TIMEOUT, DEFAULT_ACK_TIMEOUT);
4750
this.taskType = taskType;
4851
this.inferenceEntityId = inferenceEntityId;
4952
this.content = content;
5053
this.contentType = contentType;
54+
this.timeout = timeout;
5155
}
5256

5357
public Request(StreamInput in) throws IOException {
@@ -56,6 +60,13 @@ public Request(StreamInput in) throws IOException {
5660
this.taskType = TaskType.fromStream(in);
5761
this.content = in.readBytesReference();
5862
this.contentType = in.readEnum(XContentType.class);
63+
64+
if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT)
65+
|| in.getTransportVersion().isPatchFrom(TransportVersions.INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT_8_19)) {
66+
this.timeout = in.readTimeValue();
67+
} else {
68+
this.timeout = InferenceAction.Request.DEFAULT_TIMEOUT;
69+
}
5970
}
6071

6172
public TaskType getTaskType() {
@@ -74,13 +85,22 @@ public XContentType getContentType() {
7485
return contentType;
7586
}
7687

88+
public TimeValue getTimeout() {
89+
return timeout;
90+
}
91+
7792
@Override
7893
public void writeTo(StreamOutput out) throws IOException {
7994
super.writeTo(out);
8095
out.writeString(inferenceEntityId);
8196
taskType.writeTo(out);
8297
out.writeBytesReference(content);
8398
XContentHelper.writeTo(out, contentType);
99+
100+
if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT)
101+
|| out.getTransportVersion().isPatchFrom(TransportVersions.INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT_8_19)) {
102+
out.writeTimeValue(timeout);
103+
}
84104
}
85105

86106
@Override
@@ -105,12 +125,13 @@ public boolean equals(Object o) {
105125
return taskType == request.taskType
106126
&& Objects.equals(inferenceEntityId, request.inferenceEntityId)
107127
&& Objects.equals(content, request.content)
108-
&& contentType == request.contentType;
128+
&& contentType == request.contentType
129+
&& Objects.equals(timeout, request.timeout);
109130
}
110131

111132
@Override
112133
public int hashCode() {
113-
return Objects.hash(taskType, inferenceEntityId, content, contentType);
134+
return Objects.hash(taskType, inferenceEntityId, content, contentType, timeout);
114135
}
115136
}
116137

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/PutInferenceModelActionTests.java

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,26 +34,45 @@ public void setup() throws Exception {
3434

3535
public void testValidate() {
3636
// valid model ID
37-
var request = new PutInferenceModelAction.Request(TASK_TYPE, MODEL_ID + "_-0", BYTES, X_CONTENT_TYPE);
37+
var request = new PutInferenceModelAction.Request(
38+
TASK_TYPE,
39+
MODEL_ID + "_-0",
40+
BYTES,
41+
X_CONTENT_TYPE,
42+
InferenceAction.Request.DEFAULT_TIMEOUT
43+
);
3844
ActionRequestValidationException validationException = request.validate();
3945
assertNull(validationException);
4046

4147
// invalid model IDs
4248

43-
var invalidRequest = new PutInferenceModelAction.Request(TASK_TYPE, "", BYTES, X_CONTENT_TYPE);
49+
var invalidRequest = new PutInferenceModelAction.Request(
50+
TASK_TYPE,
51+
"",
52+
BYTES,
53+
X_CONTENT_TYPE,
54+
InferenceAction.Request.DEFAULT_TIMEOUT
55+
);
4456
validationException = invalidRequest.validate();
4557
assertNotNull(validationException);
4658

4759
var invalidRequest2 = new PutInferenceModelAction.Request(
4860
TASK_TYPE,
4961
randomAlphaOfLengthBetween(1, 10) + randomFrom(MlStringsTests.SOME_INVALID_CHARS),
5062
BYTES,
51-
X_CONTENT_TYPE
63+
X_CONTENT_TYPE,
64+
InferenceAction.Request.DEFAULT_TIMEOUT
5265
);
5366
validationException = invalidRequest2.validate();
5467
assertNotNull(validationException);
5568

56-
var invalidRequest3 = new PutInferenceModelAction.Request(TASK_TYPE, null, BYTES, X_CONTENT_TYPE);
69+
var invalidRequest3 = new PutInferenceModelAction.Request(
70+
TASK_TYPE,
71+
null,
72+
BYTES,
73+
X_CONTENT_TYPE,
74+
InferenceAction.Request.DEFAULT_TIMEOUT
75+
);
5776
validationException = invalidRequest3.validate();
5877
assertNotNull(validationException);
5978
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ protected void masterOperation(
177177
return;
178178
}
179179

180-
parseAndStoreModel(service.get(), request.getInferenceEntityId(), resolvedTaskType, requestAsMap, request.ackTimeout(), listener);
180+
parseAndStoreModel(service.get(), request.getInferenceEntityId(), resolvedTaskType, requestAsMap, request.getTimeout(), listener);
181181
}
182182

183183
private void parseAndStoreModel(

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestPutInferenceModelAction.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import java.util.List;
2121

2222
import static org.elasticsearch.rest.RestRequest.Method.PUT;
23+
import static org.elasticsearch.xpack.inference.rest.BaseInferenceAction.parseTimeout;
2324
import static org.elasticsearch.xpack.inference.rest.Paths.INFERENCE_ID;
2425
import static org.elasticsearch.xpack.inference.rest.Paths.INFERENCE_ID_PATH;
2526
import static org.elasticsearch.xpack.inference.rest.Paths.TASK_TYPE_INFERENCE_ID_PATH;
@@ -49,8 +50,15 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient
4950
taskType = TaskType.ANY; // task type must be defined in the body
5051
}
5152

53+
var inferTimeout = parseTimeout(restRequest);
5254
var content = restRequest.requiredContent();
53-
var request = new PutInferenceModelAction.Request(taskType, inferenceEntityId, content, restRequest.getXContentType());
55+
var request = new PutInferenceModelAction.Request(
56+
taskType,
57+
inferenceEntityId,
58+
content,
59+
restRequest.getXContentType(),
60+
inferTimeout
61+
);
5462
return channel -> client.execute(
5563
PutInferenceModelAction.INSTANCE,
5664
request,

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/PutInferenceModelRequestTests.java

Lines changed: 20 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,16 @@
77

88
package org.elasticsearch.xpack.inference.action;
99

10+
import org.elasticsearch.TransportVersion;
11+
import org.elasticsearch.TransportVersions;
1012
import org.elasticsearch.common.io.stream.Writeable;
1113
import org.elasticsearch.inference.TaskType;
12-
import org.elasticsearch.test.AbstractWireSerializingTestCase;
1314
import org.elasticsearch.xcontent.XContentType;
15+
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
1416
import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction;
17+
import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
1518

16-
public class PutInferenceModelRequestTests extends AbstractWireSerializingTestCase<PutInferenceModelAction.Request> {
19+
public class PutInferenceModelRequestTests extends AbstractBWCWireSerializationTestCase<PutInferenceModelAction.Request> {
1720
@Override
1821
protected Writeable.Reader<PutInferenceModelAction.Request> instanceReader() {
1922
return PutInferenceModelAction.Request::new;
@@ -25,38 +28,29 @@ protected PutInferenceModelAction.Request createTestInstance() {
2528
randomFrom(TaskType.values()),
2629
randomAlphaOfLength(6),
2730
randomBytesReference(50),
28-
randomFrom(XContentType.values())
31+
randomFrom(XContentType.values()),
32+
randomTimeValue()
2933
);
3034
}
3135

3236
@Override
3337
protected PutInferenceModelAction.Request mutateInstance(PutInferenceModelAction.Request instance) {
34-
return switch (randomIntBetween(0, 3)) {
35-
case 0 -> new PutInferenceModelAction.Request(
36-
TaskType.values()[(instance.getTaskType().ordinal() + 1) % TaskType.values().length],
37-
instance.getInferenceEntityId(),
38-
instance.getContent(),
39-
instance.getContentType()
40-
);
41-
case 1 -> new PutInferenceModelAction.Request(
42-
instance.getTaskType(),
43-
instance.getInferenceEntityId() + "foo",
44-
instance.getContent(),
45-
instance.getContentType()
46-
);
47-
case 2 -> new PutInferenceModelAction.Request(
48-
instance.getTaskType(),
49-
instance.getInferenceEntityId(),
50-
randomBytesReference(instance.getContent().length() + 1),
51-
instance.getContentType()
52-
);
53-
case 3 -> new PutInferenceModelAction.Request(
38+
return randomValueOtherThan(instance, this::createTestInstance);
39+
}
40+
41+
@Override
42+
protected PutInferenceModelAction.Request mutateInstanceForVersion(PutInferenceModelAction.Request instance, TransportVersion version) {
43+
if (version.onOrAfter(TransportVersions.INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT)
44+
|| version.isPatchFrom(TransportVersions.INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT_8_19)) {
45+
return instance;
46+
} else {
47+
return new PutInferenceModelAction.Request(
5448
instance.getTaskType(),
5549
instance.getInferenceEntityId(),
5650
instance.getContent(),
57-
XContentType.values()[(instance.getContentType().ordinal() + 1) % XContentType.values().length]
51+
instance.getContentType(),
52+
InferenceAction.Request.DEFAULT_TIMEOUT
5853
);
59-
default -> throw new IllegalStateException();
60-
};
54+
}
6155
}
6256
}

0 commit comments

Comments
 (0)