From 653f77b5f4c24f3b280d2ca9524e5551260adcbe Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Mon, 14 Apr 2025 17:27:24 -0400 Subject: [PATCH 1/4] Fixing bug with listener and adding timeout --- .../org/elasticsearch/TransportVersions.java | 2 ++ .../action/PutInferenceModelAction.java | 23 ++++++++++++++++--- .../TransportPutInferenceModelAction.java | 2 +- .../rest/RestPutInferenceModelAction.java | 10 +++++++- .../BaseElasticsearchInternalService.java | 2 +- .../ElasticsearchInternalModel.java | 2 ++ 6 files changed, 35 insertions(+), 6 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index e15377223aaf1..7efb9cbf4d8fc 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -159,6 +159,7 @@ static TransportVersion def(int id) { public static final TransportVersion REMOTE_EXCEPTION_8_19 = def(8_841_0_16); public static final TransportVersion AMAZON_BEDROCK_TASK_SETTINGS_8_19 = def(8_841_0_17); public static final TransportVersion BATCHED_QUERY_PHASE_VERSION_BACKPORT_8_X = def(8_841_0_19); + public static final TransportVersion INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT_8_19 = def(8_841_0_20); public static final TransportVersion INITIAL_ELASTICSEARCH_9_0 = def(9_000_0_00); public static final TransportVersion REMOVE_SNAPSHOT_FAILURES_90 = def(9_000_0_01); public static final TransportVersion TRANSPORT_STATS_HANDLING_TIME_REQUIRED_90 = def(9_000_0_02); @@ -219,6 +220,7 @@ static TransportVersion def(int id) { public static final TransportVersion REPO_ANALYSIS_COPY_BLOB = def(9_048_00_0); public static final TransportVersion AMAZON_BEDROCK_TASK_SETTINGS = def(9_049_00_0); public static final TransportVersion ESQL_REPORT_SHARD_PARTITIONING = def(9_050_00_0); + public static final TransportVersion INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT = def(9_051_00_0); /* * STOP! READ THIS FIRST! 
No, really, diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/PutInferenceModelAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/PutInferenceModelAction.java index a7f65c60a06c4..9c18e57376dc0 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/PutInferenceModelAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/PutInferenceModelAction.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.core.inference.action; +import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.action.ActionResponse; import org.elasticsearch.action.ActionType; @@ -15,6 +16,7 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.TaskType; import org.elasticsearch.xcontent.ToXContentObject; @@ -41,13 +43,15 @@ public static class Request extends AcknowledgedRequest { private final String inferenceEntityId; private final BytesReference content; private final XContentType contentType; + private final TimeValue timeout; - public Request(TaskType taskType, String inferenceEntityId, BytesReference content, XContentType contentType) { + public Request(TaskType taskType, String inferenceEntityId, BytesReference content, XContentType contentType, TimeValue timeout) { super(TRAPPY_IMPLICIT_DEFAULT_MASTER_NODE_TIMEOUT, DEFAULT_ACK_TIMEOUT); this.taskType = taskType; this.inferenceEntityId = inferenceEntityId; this.content = content; this.contentType = contentType; + this.timeout = timeout; } public Request(StreamInput in) throws IOException { @@ -56,6 +60,13 @@ public Request(StreamInput in) throws IOException { 
this.taskType = TaskType.fromStream(in); this.content = in.readBytesReference(); this.contentType = in.readEnum(XContentType.class); + + if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT) + || in.getTransportVersion().isPatchFrom(TransportVersions.INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT_8_19)) { + this.timeout = in.readTimeValue(); + } else { + this.timeout = InferenceAction.Request.DEFAULT_TIMEOUT; + } } public TaskType getTaskType() { @@ -74,6 +85,10 @@ public XContentType getContentType() { return contentType; } + public TimeValue getTimeout() { + return timeout; + } + @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); @@ -81,6 +96,7 @@ public void writeTo(StreamOutput out) throws IOException { taskType.writeTo(out); out.writeBytesReference(content); XContentHelper.writeTo(out, contentType); + out.writeTimeValue(timeout); } @Override @@ -105,12 +121,13 @@ public boolean equals(Object o) { return taskType == request.taskType && Objects.equals(inferenceEntityId, request.inferenceEntityId) && Objects.equals(content, request.content) - && contentType == request.contentType; + && contentType == request.contentType + && Objects.equals(timeout, request.timeout); } @Override public int hashCode() { - return Objects.hash(taskType, inferenceEntityId, content, contentType); + return Objects.hash(taskType, inferenceEntityId, content, contentType, timeout); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java index 4357fa619954c..e065fc4453a9e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java @@ -177,7 
+177,7 @@ protected void masterOperation( return; } - parseAndStoreModel(service.get(), request.getInferenceEntityId(), resolvedTaskType, requestAsMap, request.ackTimeout(), listener); + parseAndStoreModel(service.get(), request.getInferenceEntityId(), resolvedTaskType, requestAsMap, request.getTimeout(), listener); } private void parseAndStoreModel( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestPutInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestPutInferenceModelAction.java index 655e11996d522..838e6512d805f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestPutInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestPutInferenceModelAction.java @@ -20,6 +20,7 @@ import java.util.List; import static org.elasticsearch.rest.RestRequest.Method.PUT; +import static org.elasticsearch.xpack.inference.rest.BaseInferenceAction.parseTimeout; import static org.elasticsearch.xpack.inference.rest.Paths.INFERENCE_ID; import static org.elasticsearch.xpack.inference.rest.Paths.INFERENCE_ID_PATH; import static org.elasticsearch.xpack.inference.rest.Paths.TASK_TYPE_INFERENCE_ID_PATH; @@ -49,8 +50,15 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient taskType = TaskType.ANY; // task type must be defined in the body } + var inferTimeout = parseTimeout(restRequest); var content = restRequest.requiredContent(); - var request = new PutInferenceModelAction.Request(taskType, inferenceEntityId, content, restRequest.getXContentType()); + var request = new PutInferenceModelAction.Request( + taskType, + inferenceEntityId, + content, + restRequest.getXContentType(), + inferTimeout + ); return channel -> client.execute( PutInferenceModelAction.INSTANCE, request, diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java index 84259d1c0be66..a6823d65da107 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java @@ -106,7 +106,7 @@ public void start(Model model, TimeValue timeout, ActionListener finalL }) .andThen((l2, modelDidPut) -> { var startRequest = esModel.getStartTrainedModelDeploymentActionRequest(timeout); - var responseListener = esModel.getCreateTrainedModelAssignmentActionListener(model, finalListener); + var responseListener = esModel.getCreateTrainedModelAssignmentActionListener(model, l2); client.execute(StartTrainedModelDeploymentAction.INSTANCE, startRequest, responseListener); }) .addListener(finalListener); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalModel.java index aa12bf0c645c3..f1011efd3b12c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalModel.java @@ -105,6 +105,8 @@ public void onFailure(Exception e) { && statusException.getRootCause() instanceof ResourceAlreadyExistsException) { // Deployment is already started listener.onResponse(Boolean.TRUE); + } else { + listener.onFailure(e); } return; } From def4c5140551be80852bd3c968f74ffb2f3738f2 Mon Sep 
17 00:00:00 2001 From: Jonathan Buttner <56361221+jonathan-buttner@users.noreply.github.com> Date: Mon, 14 Apr 2025 19:02:24 -0400 Subject: [PATCH 2/4] Update docs/changelog/126805.yaml --- docs/changelog/126805.yaml | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 docs/changelog/126805.yaml diff --git a/docs/changelog/126805.yaml b/docs/changelog/126805.yaml new file mode 100644 index 0000000000000..ee9a4be7e4fd5 --- /dev/null +++ b/docs/changelog/126805.yaml @@ -0,0 +1,6 @@ +pr: 126805 +summary: Fixing bug with `TransportPutInferenceModelAction` listener and adding timeout to + request +area: Machine Learning +type: bug +issues: [] From 460b619353d0c591f3bc449bc0ecff4357912051 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Tue, 15 Apr 2025 14:41:39 -0400 Subject: [PATCH 3/4] Fixing tests --- .../action/PutInferenceModelActionTests.java | 27 +++++++++-- .../action/PutInferenceModelRequestTests.java | 46 ++++++++----------- 2 files changed, 43 insertions(+), 30 deletions(-) diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/PutInferenceModelActionTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/PutInferenceModelActionTests.java index e0b04c6fe8769..f9f67167a12b1 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/PutInferenceModelActionTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/PutInferenceModelActionTests.java @@ -34,13 +34,25 @@ public void setup() throws Exception { public void testValidate() { // valid model ID - var request = new PutInferenceModelAction.Request(TASK_TYPE, MODEL_ID + "_-0", BYTES, X_CONTENT_TYPE); + var request = new PutInferenceModelAction.Request( + TASK_TYPE, + MODEL_ID + "_-0", + BYTES, + X_CONTENT_TYPE, + InferenceAction.Request.DEFAULT_TIMEOUT + ); ActionRequestValidationException validationException = request.validate(); assertNull(validationException); 
// invalid model IDs - var invalidRequest = new PutInferenceModelAction.Request(TASK_TYPE, "", BYTES, X_CONTENT_TYPE); + var invalidRequest = new PutInferenceModelAction.Request( + TASK_TYPE, + "", + BYTES, + X_CONTENT_TYPE, + InferenceAction.Request.DEFAULT_TIMEOUT + ); validationException = invalidRequest.validate(); assertNotNull(validationException); @@ -48,12 +60,19 @@ public void testValidate() { TASK_TYPE, randomAlphaOfLengthBetween(1, 10) + randomFrom(MlStringsTests.SOME_INVALID_CHARS), BYTES, - X_CONTENT_TYPE + X_CONTENT_TYPE, + InferenceAction.Request.DEFAULT_TIMEOUT ); validationException = invalidRequest2.validate(); assertNotNull(validationException); - var invalidRequest3 = new PutInferenceModelAction.Request(TASK_TYPE, null, BYTES, X_CONTENT_TYPE); + var invalidRequest3 = new PutInferenceModelAction.Request( + TASK_TYPE, + null, + BYTES, + X_CONTENT_TYPE, + InferenceAction.Request.DEFAULT_TIMEOUT + ); validationException = invalidRequest3.validate(); assertNotNull(validationException); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/PutInferenceModelRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/PutInferenceModelRequestTests.java index f61398fcacacf..e514867780669 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/PutInferenceModelRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/PutInferenceModelRequestTests.java @@ -7,13 +7,16 @@ package org.elasticsearch.xpack.inference.action; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.inference.TaskType; -import org.elasticsearch.test.AbstractWireSerializingTestCase; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; import 
org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; -public class PutInferenceModelRequestTests extends AbstractWireSerializingTestCase { +public class PutInferenceModelRequestTests extends AbstractBWCWireSerializationTestCase { @Override protected Writeable.Reader instanceReader() { return PutInferenceModelAction.Request::new; @@ -25,38 +28,29 @@ protected PutInferenceModelAction.Request createTestInstance() { randomFrom(TaskType.values()), randomAlphaOfLength(6), randomBytesReference(50), - randomFrom(XContentType.values()) + randomFrom(XContentType.values()), + randomTimeValue() ); } @Override protected PutInferenceModelAction.Request mutateInstance(PutInferenceModelAction.Request instance) { - return switch (randomIntBetween(0, 3)) { - case 0 -> new PutInferenceModelAction.Request( - TaskType.values()[(instance.getTaskType().ordinal() + 1) % TaskType.values().length], - instance.getInferenceEntityId(), - instance.getContent(), - instance.getContentType() - ); - case 1 -> new PutInferenceModelAction.Request( - instance.getTaskType(), - instance.getInferenceEntityId() + "foo", - instance.getContent(), - instance.getContentType() - ); - case 2 -> new PutInferenceModelAction.Request( - instance.getTaskType(), - instance.getInferenceEntityId(), - randomBytesReference(instance.getContent().length() + 1), - instance.getContentType() - ); - case 3 -> new PutInferenceModelAction.Request( + return randomValueOtherThan(instance, this::createTestInstance); + } + + @Override + protected PutInferenceModelAction.Request mutateInstanceForVersion(PutInferenceModelAction.Request instance, TransportVersion version) { + if (version.onOrAfter(TransportVersions.INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT) + || version.isPatchFrom(TransportVersions.INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT_8_19)) { + return instance; + } else { + return new PutInferenceModelAction.Request( instance.getTaskType(), 
instance.getInferenceEntityId(), instance.getContent(), - XContentType.values()[(instance.getContentType().ordinal() + 1) % XContentType.values().length] + instance.getContentType(), + InferenceAction.Request.DEFAULT_TIMEOUT ); - default -> throw new IllegalStateException(); - }; + } } } From 979fb51db43daa45928735c6b6a42ad6e4e4f2d0 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Tue, 15 Apr 2025 16:19:47 -0400 Subject: [PATCH 4/4] Fixing writeTo --- .../core/inference/action/PutInferenceModelAction.java | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/PutInferenceModelAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/PutInferenceModelAction.java index 9c18e57376dc0..cded88c36388c 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/PutInferenceModelAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/PutInferenceModelAction.java @@ -96,7 +96,11 @@ public void writeTo(StreamOutput out) throws IOException { taskType.writeTo(out); out.writeBytesReference(content); XContentHelper.writeTo(out, contentType); - out.writeTimeValue(timeout); + + if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT) + || out.getTransportVersion().isPatchFrom(TransportVersions.INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT_8_19)) { + out.writeTimeValue(timeout); + } } @Override