Skip to content
Merged
5 changes: 5 additions & 0 deletions docs/changelog/126805.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 126805
summary: Adding timeout to request for creating inference endpoint
area: Machine Learning
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,8 @@ static TransportVersion def(int id) {
public static final TransportVersion INTRODUCE_FAILURES_LIFECYCLE_BACKPORT_8_19 = def(8_841_0_25);
public static final TransportVersion INTRODUCE_FAILURES_DEFAULT_RETENTION_BACKPORT_8_19 = def(8_841_0_26);
public static final TransportVersion RESCORE_VECTOR_ALLOW_ZERO_BACKPORT_8_19 = def(8_841_0_27);
public static final TransportVersion ESQL_REPORT_SHARD_PARTITIONING_8_19 = def(8_841_0_28);
public static final TransportVersion INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT_8_19 = def(8_841_0_28);
public static final TransportVersion ESQL_REPORT_SHARD_PARTITIONING_8_19 = def(8_841_0_29);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bringing the ESQL_REPORT_SHARD_PARTITIONING_8_19 inline with the value on the main branch: https://github.com/elastic/elasticsearch/blob/main/server/src/main/java/org/elasticsearch/TransportVersions.java#L176

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍


/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

package org.elasticsearch.xpack.core.inference.action;

import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.action.ActionType;
Expand All @@ -15,6 +16,7 @@
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xcontent.ToXContentObject;
Expand All @@ -41,13 +43,15 @@ public static class Request extends AcknowledgedRequest<Request> {
private final String inferenceEntityId;
private final BytesReference content;
private final XContentType contentType;
private final TimeValue timeout;

public Request(TaskType taskType, String inferenceEntityId, BytesReference content, XContentType contentType) {
public Request(TaskType taskType, String inferenceEntityId, BytesReference content, XContentType contentType, TimeValue timeout) {
super(TRAPPY_IMPLICIT_DEFAULT_MASTER_NODE_TIMEOUT, DEFAULT_ACK_TIMEOUT);
this.taskType = taskType;
this.inferenceEntityId = inferenceEntityId;
this.content = content;
this.contentType = contentType;
this.timeout = timeout;
}

public Request(StreamInput in) throws IOException {
Expand All @@ -56,6 +60,12 @@ public Request(StreamInput in) throws IOException {
this.taskType = TaskType.fromStream(in);
this.content = in.readBytesReference();
this.contentType = in.readEnum(XContentType.class);

if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT_8_19)) {
this.timeout = in.readTimeValue();
} else {
this.timeout = InferenceAction.Request.DEFAULT_TIMEOUT;
}
}

public TaskType getTaskType() {
Expand All @@ -74,13 +84,21 @@ public XContentType getContentType() {
return contentType;
}

public TimeValue getTimeout() {
return timeout;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeString(inferenceEntityId);
taskType.writeTo(out);
out.writeBytesReference(content);
XContentHelper.writeTo(out, contentType);

if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT_8_19)) {
out.writeTimeValue(timeout);
}
}

@Override
Expand All @@ -105,12 +123,13 @@ public boolean equals(Object o) {
return taskType == request.taskType
&& Objects.equals(inferenceEntityId, request.inferenceEntityId)
&& Objects.equals(content, request.content)
&& contentType == request.contentType;
&& contentType == request.contentType
&& Objects.equals(timeout, request.timeout);
}

@Override
public int hashCode() {
return Objects.hash(taskType, inferenceEntityId, content, contentType);
return Objects.hash(taskType, inferenceEntityId, content, contentType, timeout);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,26 +34,45 @@ public void setup() throws Exception {

public void testValidate() {
// valid model ID
var request = new PutInferenceModelAction.Request(TASK_TYPE, MODEL_ID + "_-0", BYTES, X_CONTENT_TYPE);
var request = new PutInferenceModelAction.Request(
TASK_TYPE,
MODEL_ID + "_-0",
BYTES,
X_CONTENT_TYPE,
InferenceAction.Request.DEFAULT_TIMEOUT
);
ActionRequestValidationException validationException = request.validate();
assertNull(validationException);

// invalid model IDs

var invalidRequest = new PutInferenceModelAction.Request(TASK_TYPE, "", BYTES, X_CONTENT_TYPE);
var invalidRequest = new PutInferenceModelAction.Request(
TASK_TYPE,
"",
BYTES,
X_CONTENT_TYPE,
InferenceAction.Request.DEFAULT_TIMEOUT
);
validationException = invalidRequest.validate();
assertNotNull(validationException);

var invalidRequest2 = new PutInferenceModelAction.Request(
TASK_TYPE,
randomAlphaOfLengthBetween(1, 10) + randomFrom(MlStringsTests.SOME_INVALID_CHARS),
BYTES,
X_CONTENT_TYPE
X_CONTENT_TYPE,
InferenceAction.Request.DEFAULT_TIMEOUT
);
validationException = invalidRequest2.validate();
assertNotNull(validationException);

var invalidRequest3 = new PutInferenceModelAction.Request(TASK_TYPE, null, BYTES, X_CONTENT_TYPE);
var invalidRequest3 = new PutInferenceModelAction.Request(
TASK_TYPE,
null,
BYTES,
X_CONTENT_TYPE,
InferenceAction.Request.DEFAULT_TIMEOUT
);
validationException = invalidRequest3.validate();
assertNotNull(validationException);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ protected void masterOperation(
return;
}

parseAndStoreModel(service.get(), request.getInferenceEntityId(), resolvedTaskType, requestAsMap, request.ackTimeout(), listener);
parseAndStoreModel(service.get(), request.getInferenceEntityId(), resolvedTaskType, requestAsMap, request.getTimeout(), listener);
}

private void parseAndStoreModel(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.util.List;

import static org.elasticsearch.rest.RestRequest.Method.PUT;
import static org.elasticsearch.xpack.inference.rest.BaseInferenceAction.parseTimeout;
import static org.elasticsearch.xpack.inference.rest.Paths.INFERENCE_ID;
import static org.elasticsearch.xpack.inference.rest.Paths.INFERENCE_ID_PATH;
import static org.elasticsearch.xpack.inference.rest.Paths.TASK_TYPE_INFERENCE_ID_PATH;
Expand Down Expand Up @@ -49,8 +50,15 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient
taskType = TaskType.ANY; // task type must be defined in the body
}

var inferTimeout = parseTimeout(restRequest);
var content = restRequest.requiredContent();
var request = new PutInferenceModelAction.Request(taskType, inferenceEntityId, content, restRequest.getXContentType());
var request = new PutInferenceModelAction.Request(
taskType,
inferenceEntityId,
content,
restRequest.getXContentType(),
inferTimeout
);
return channel -> client.execute(
PutInferenceModelAction.INSTANCE,
request,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,16 @@

package org.elasticsearch.xpack.inference.action;

import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction;
import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;

public class PutInferenceModelRequestTests extends AbstractWireSerializingTestCase<PutInferenceModelAction.Request> {
public class PutInferenceModelRequestTests extends AbstractBWCWireSerializationTestCase<PutInferenceModelAction.Request> {
@Override
protected Writeable.Reader<PutInferenceModelAction.Request> instanceReader() {
return PutInferenceModelAction.Request::new;
Expand All @@ -25,38 +28,40 @@ protected PutInferenceModelAction.Request createTestInstance() {
randomFrom(TaskType.values()),
randomAlphaOfLength(6),
randomBytesReference(50),
randomFrom(XContentType.values())
randomFrom(XContentType.values()),
randomTimeValue()
);
}

@Override
protected PutInferenceModelAction.Request mutateInstance(PutInferenceModelAction.Request instance) {
return switch (randomIntBetween(0, 3)) {
case 0 -> new PutInferenceModelAction.Request(
TaskType.values()[(instance.getTaskType().ordinal() + 1) % TaskType.values().length],
instance.getInferenceEntityId(),
instance.getContent(),
instance.getContentType()
);
case 1 -> new PutInferenceModelAction.Request(
instance.getTaskType(),
instance.getInferenceEntityId() + "foo",
instance.getContent(),
instance.getContentType()
);
case 2 -> new PutInferenceModelAction.Request(
return randomValueOtherThan(instance, this::createTestInstance);
}

@Override
protected PutInferenceModelAction.Request mutateInstanceForVersion(PutInferenceModelAction.Request instance, TransportVersion version) {
if (version.onOrAfter(TransportVersions.INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT_8_19)) {
return instance;
} else if (version.onOrAfter(TransportVersions.V_8_0_0)) {
return new PutInferenceModelAction.Request(
instance.getTaskType(),
instance.getInferenceEntityId(),
randomBytesReference(instance.getContent().length() + 1),
instance.getContentType()
instance.getContent(),
instance.getContentType(),
InferenceAction.Request.DEFAULT_TIMEOUT
);
case 3 -> new PutInferenceModelAction.Request(
} else {
return new PutInferenceModelAction.Request(
instance.getTaskType(),
instance.getInferenceEntityId(),
instance.getContent(),
XContentType.values()[(instance.getContentType().ordinal() + 1) % XContentType.values().length]
/*
* See XContentHelper.java#L733
* for versions prior to 8.0.0, the content type does not have the VND_ instances
*/
XContentType.ofOrdinal(instance.getContentType().canonical().ordinal()),
InferenceAction.Request.DEFAULT_TIMEOUT
);
default -> throw new IllegalStateException();
};
}
}
}