Skip to content
6 changes: 6 additions & 0 deletions docs/changelog/126805.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 126805
summary: Fixing bug with `TransportPutModelAction` listener and adding timeout to
request
area: Machine Learning
type: bug
issues: []
2 changes: 2 additions & 0 deletions server/src/main/java/org/elasticsearch/TransportVersions.java
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ static TransportVersion def(int id) {
public static final TransportVersion REMOTE_EXCEPTION_8_19 = def(8_841_0_16);
public static final TransportVersion AMAZON_BEDROCK_TASK_SETTINGS_8_19 = def(8_841_0_17);
public static final TransportVersion BATCHED_QUERY_PHASE_VERSION_BACKPORT_8_X = def(8_841_0_19);
public static final TransportVersion INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT_8_19 = def(8_841_0_20);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0 = def(9_000_0_00);
public static final TransportVersion REMOVE_SNAPSHOT_FAILURES_90 = def(9_000_0_01);
public static final TransportVersion TRANSPORT_STATS_HANDLING_TIME_REQUIRED_90 = def(9_000_0_02);
Expand Down Expand Up @@ -223,6 +224,7 @@ static TransportVersion def(int id) {
public static final TransportVersion ESQL_REPORT_SHARD_PARTITIONING = def(9_050_0_00);
public static final TransportVersion ESQL_QUERY_PLANNING_DURATION = def(9_051_0_00);
public static final TransportVersion ESQL_DOCUMENTS_FOUND_AND_VALUES_LOADED = def(9_052_0_00);
public static final TransportVersion INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT = def(9_053_00_0);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

package org.elasticsearch.xpack.core.inference.action;

import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.action.ActionType;
Expand All @@ -15,6 +16,7 @@
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xcontent.ToXContentObject;
Expand All @@ -41,13 +43,15 @@ public static class Request extends AcknowledgedRequest<Request> {
private final String inferenceEntityId;
private final BytesReference content;
private final XContentType contentType;
private final TimeValue timeout;

public Request(TaskType taskType, String inferenceEntityId, BytesReference content, XContentType contentType) {
public Request(TaskType taskType, String inferenceEntityId, BytesReference content, XContentType contentType, TimeValue timeout) {
super(TRAPPY_IMPLICIT_DEFAULT_MASTER_NODE_TIMEOUT, DEFAULT_ACK_TIMEOUT);
this.taskType = taskType;
this.inferenceEntityId = inferenceEntityId;
this.content = content;
this.contentType = contentType;
this.timeout = timeout;
}

public Request(StreamInput in) throws IOException {
Expand All @@ -56,6 +60,13 @@ public Request(StreamInput in) throws IOException {
this.taskType = TaskType.fromStream(in);
this.content = in.readBytesReference();
this.contentType = in.readEnum(XContentType.class);

if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT)
|| in.getTransportVersion().isPatchFrom(TransportVersions.INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT_8_19)) {
this.timeout = in.readTimeValue();
} else {
this.timeout = InferenceAction.Request.DEFAULT_TIMEOUT;
}
}

public TaskType getTaskType() {
Expand All @@ -74,13 +85,22 @@ public XContentType getContentType() {
return contentType;
}

public TimeValue getTimeout() {
return timeout;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeString(inferenceEntityId);
taskType.writeTo(out);
out.writeBytesReference(content);
XContentHelper.writeTo(out, contentType);

if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT)
|| out.getTransportVersion().isPatchFrom(TransportVersions.INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT_8_19)) {
out.writeTimeValue(timeout);
}
}

@Override
Expand All @@ -105,12 +125,13 @@ public boolean equals(Object o) {
return taskType == request.taskType
&& Objects.equals(inferenceEntityId, request.inferenceEntityId)
&& Objects.equals(content, request.content)
&& contentType == request.contentType;
&& contentType == request.contentType
&& Objects.equals(timeout, request.timeout);
}

@Override
public int hashCode() {
return Objects.hash(taskType, inferenceEntityId, content, contentType);
return Objects.hash(taskType, inferenceEntityId, content, contentType, timeout);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,26 +34,45 @@ public void setup() throws Exception {

public void testValidate() {
// valid model ID
var request = new PutInferenceModelAction.Request(TASK_TYPE, MODEL_ID + "_-0", BYTES, X_CONTENT_TYPE);
var request = new PutInferenceModelAction.Request(
TASK_TYPE,
MODEL_ID + "_-0",
BYTES,
X_CONTENT_TYPE,
InferenceAction.Request.DEFAULT_TIMEOUT
);
ActionRequestValidationException validationException = request.validate();
assertNull(validationException);

// invalid model IDs

var invalidRequest = new PutInferenceModelAction.Request(TASK_TYPE, "", BYTES, X_CONTENT_TYPE);
var invalidRequest = new PutInferenceModelAction.Request(
TASK_TYPE,
"",
BYTES,
X_CONTENT_TYPE,
InferenceAction.Request.DEFAULT_TIMEOUT
);
validationException = invalidRequest.validate();
assertNotNull(validationException);

var invalidRequest2 = new PutInferenceModelAction.Request(
TASK_TYPE,
randomAlphaOfLengthBetween(1, 10) + randomFrom(MlStringsTests.SOME_INVALID_CHARS),
BYTES,
X_CONTENT_TYPE
X_CONTENT_TYPE,
InferenceAction.Request.DEFAULT_TIMEOUT
);
validationException = invalidRequest2.validate();
assertNotNull(validationException);

var invalidRequest3 = new PutInferenceModelAction.Request(TASK_TYPE, null, BYTES, X_CONTENT_TYPE);
var invalidRequest3 = new PutInferenceModelAction.Request(
TASK_TYPE,
null,
BYTES,
X_CONTENT_TYPE,
InferenceAction.Request.DEFAULT_TIMEOUT
);
validationException = invalidRequest3.validate();
assertNotNull(validationException);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ protected void masterOperation(
return;
}

parseAndStoreModel(service.get(), request.getInferenceEntityId(), resolvedTaskType, requestAsMap, request.ackTimeout(), listener);
parseAndStoreModel(service.get(), request.getInferenceEntityId(), resolvedTaskType, requestAsMap, request.getTimeout(), listener);
}

private void parseAndStoreModel(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.util.List;

import static org.elasticsearch.rest.RestRequest.Method.PUT;
import static org.elasticsearch.xpack.inference.rest.BaseInferenceAction.parseTimeout;
import static org.elasticsearch.xpack.inference.rest.Paths.INFERENCE_ID;
import static org.elasticsearch.xpack.inference.rest.Paths.INFERENCE_ID_PATH;
import static org.elasticsearch.xpack.inference.rest.Paths.TASK_TYPE_INFERENCE_ID_PATH;
Expand Down Expand Up @@ -49,8 +50,15 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient
taskType = TaskType.ANY; // task type must be defined in the body
}

var inferTimeout = parseTimeout(restRequest);
var content = restRequest.requiredContent();
var request = new PutInferenceModelAction.Request(taskType, inferenceEntityId, content, restRequest.getXContentType());
var request = new PutInferenceModelAction.Request(
taskType,
inferenceEntityId,
content,
restRequest.getXContentType(),
inferTimeout
);
return channel -> client.execute(
PutInferenceModelAction.INSTANCE,
request,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,16 @@

package org.elasticsearch.xpack.inference.action;

import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction;
import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;

public class PutInferenceModelRequestTests extends AbstractWireSerializingTestCase<PutInferenceModelAction.Request> {
public class PutInferenceModelRequestTests extends AbstractBWCWireSerializationTestCase<PutInferenceModelAction.Request> {
@Override
protected Writeable.Reader<PutInferenceModelAction.Request> instanceReader() {
return PutInferenceModelAction.Request::new;
Expand All @@ -25,38 +28,29 @@ protected PutInferenceModelAction.Request createTestInstance() {
randomFrom(TaskType.values()),
randomAlphaOfLength(6),
randomBytesReference(50),
randomFrom(XContentType.values())
randomFrom(XContentType.values()),
randomTimeValue()
);
}

@Override
protected PutInferenceModelAction.Request mutateInstance(PutInferenceModelAction.Request instance) {
return switch (randomIntBetween(0, 3)) {
case 0 -> new PutInferenceModelAction.Request(
TaskType.values()[(instance.getTaskType().ordinal() + 1) % TaskType.values().length],
instance.getInferenceEntityId(),
instance.getContent(),
instance.getContentType()
);
case 1 -> new PutInferenceModelAction.Request(
instance.getTaskType(),
instance.getInferenceEntityId() + "foo",
instance.getContent(),
instance.getContentType()
);
case 2 -> new PutInferenceModelAction.Request(
instance.getTaskType(),
instance.getInferenceEntityId(),
randomBytesReference(instance.getContent().length() + 1),
instance.getContentType()
);
case 3 -> new PutInferenceModelAction.Request(
return randomValueOtherThan(instance, this::createTestInstance);
}

@Override
protected PutInferenceModelAction.Request mutateInstanceForVersion(PutInferenceModelAction.Request instance, TransportVersion version) {
if (version.onOrAfter(TransportVersions.INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT)
|| version.isPatchFrom(TransportVersions.INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT_8_19)) {
return instance;
} else {
return new PutInferenceModelAction.Request(
instance.getTaskType(),
instance.getInferenceEntityId(),
instance.getContent(),
XContentType.values()[(instance.getContentType().ordinal() + 1) % XContentType.values().length]
instance.getContentType(),
InferenceAction.Request.DEFAULT_TIMEOUT
);
default -> throw new IllegalStateException();
};
}
}
}
Loading