diff --git a/docs/changelog/124025.yaml b/docs/changelog/124025.yaml new file mode 100644 index 0000000000000..8ec9a0fd1c537 --- /dev/null +++ b/docs/changelog/124025.yaml @@ -0,0 +1,5 @@ +pr: 124025 +summary: "[Inference API] Propagate product use case http header to EIS" +area: Machine Learning +type: enhancement +issues: [] diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 574cc7d224a04..1a3d9f755ad76 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -192,6 +192,7 @@ static TransportVersion def(int id) { public static final TransportVersion VOYAGE_AI_INTEGRATION_ADDED_BACKPORT_8_X = def(8_841_0_05); public static final TransportVersion JINA_AI_EMBEDDING_TYPE_SUPPORT_ADDED_BACKPORT_8_19 = def(8_841_0_06); public static final TransportVersion RETRY_ILM_ASYNC_ACTION_REQUIRE_ERROR_8_19 = def(8_841_0_07); + public static final TransportVersion INFERENCE_CONTEXT_8_X = def(8_841_0_08); /* * STOP! READ THIS FIRST! No, really, diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/InferenceContext.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/InferenceContext.java new file mode 100644 index 0000000000000..a77d87f5fcd80 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/InferenceContext.java @@ -0,0 +1,67 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.inference; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Objects; + +/** + * Record for storing context alongside an inference request, typically used for metadata. + * This is mainly used to pass along inference context on the transport layer without relying on + * {@link org.elasticsearch.common.util.concurrent.ThreadContext}, which depending on the internal + * {@link org.elasticsearch.client.internal.Client} throws away parts of the context, when passed along the transport layer. + * + * @param productUseCase - for now mainly used by Elastic Inference Service + */ +public record InferenceContext(String productUseCase) implements Writeable, ToXContent { + + public static final InferenceContext EMPTY_INSTANCE = new InferenceContext(""); + + public InferenceContext { + Objects.requireNonNull(productUseCase); + } + + public InferenceContext(StreamInput in) throws IOException { + this(in.readString()); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(productUseCase); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + builder.field("product_use_case", productUseCase); + + builder.endObject(); + + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + InferenceContext that = (InferenceContext) o; + return Objects.equals(productUseCase, that.productUseCase); + } + + @Override + public int hashCode() { + return Objects.hashCode(productUseCase); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/BaseInferenceActionRequest.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/BaseInferenceActionRequest.java index 855b0bdebb417..c74298bcd9346 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/BaseInferenceActionRequest.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/BaseInferenceActionRequest.java @@ -12,8 +12,10 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.core.inference.InferenceContext; import java.io.IOException; +import java.util.Objects; /** * Base class for inference action requests. Tracks request routing state to prevent potential routing loops @@ -23,8 +25,11 @@ public abstract class BaseInferenceActionRequest extends ActionRequest { private boolean hasBeenRerouted; - public BaseInferenceActionRequest() { + private final InferenceContext context; + + public BaseInferenceActionRequest(InferenceContext context) { super(); + this.context = context; } public BaseInferenceActionRequest(StreamInput in) throws IOException { @@ -36,6 +41,12 @@ public BaseInferenceActionRequest(StreamInput in) throws IOException { // a version pre-node-local-rate-limiting as already rerouted to maintain pre-node-local-rate-limiting behavior. this.hasBeenRerouted = true; } + + if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_CONTEXT_8_X)) { + this.context = new InferenceContext(in); + } else { + this.context = InferenceContext.EMPTY_INSTANCE; + } } public abstract boolean isStreaming(); @@ -52,11 +63,32 @@ public boolean hasBeenRerouted() { return hasBeenRerouted; } + public InferenceContext getContext() { + return context; + } + @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_REQUEST_ADAPTIVE_RATE_LIMITING)) { out.writeBoolean(hasBeenRerouted); } + + if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_CONTEXT_8_X)) { + context.writeTo(out); + } + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + BaseInferenceActionRequest that = (BaseInferenceActionRequest) o; + return hasBeenRerouted == that.hasBeenRerouted && Objects.equals(context, that.context); + } + + @Override + public int hashCode() { + return Objects.hash(hasBeenRerouted, context); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java index dc177795af76a..6b449ff5a324f 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java @@ -28,6 +28,7 @@ import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xpack.core.inference.InferenceContext; import org.elasticsearch.xpack.core.inference.results.LegacyTextEmbeddingResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; @@ -74,12 +75,14 @@ public static class Request extends BaseInferenceActionRequest { InputType.UNSPECIFIED ); - public static Builder parseRequest(String inferenceEntityId, TaskType taskType, XContentParser parser) throws IOException { + public static Builder parseRequest(String inferenceEntityId, TaskType taskType, InferenceContext context, XContentParser parser) + throws IOException { Request.Builder builder = PARSER.apply(parser, null); builder.setInferenceEntityId(inferenceEntityId); builder.setTaskType(taskType); // For rest requests we won't know what the input type is builder.setInputType(InputType.UNSPECIFIED); + builder.setContext(context); return builder; } @@ -102,6 +105,31 @@ public Request( TimeValue inferenceTimeout, boolean stream ) { + this( + taskType, + inferenceEntityId, + query, + input, + taskSettings, + inputType, + inferenceTimeout, + stream, + InferenceContext.EMPTY_INSTANCE + ); + } + + public Request( + TaskType taskType, + String inferenceEntityId, + String query, + List input, + Map taskSettings, + InputType inputType, + TimeValue inferenceTimeout, + boolean stream, + InferenceContext context + ) { + super(context); this.taskType = taskType; this.inferenceEntityId = inferenceEntityId; this.query = query; @@ -241,19 +269,31 @@ static InputType getInputTypeToWrite(InputType inputType, TransportVersion versi public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; + if (super.equals(o) == false) return false; Request request = (Request) o; - return taskType == request.taskType + return stream == request.stream + && taskType == request.taskType && Objects.equals(inferenceEntityId, request.inferenceEntityId) + && Objects.equals(query, request.query) && Objects.equals(input, request.input) && Objects.equals(taskSettings, request.taskSettings) - && Objects.equals(inputType, request.inputType) - && Objects.equals(query, request.query) + && inputType == request.inputType && Objects.equals(inferenceTimeout, request.inferenceTimeout); } @Override public int hashCode() { - return Objects.hash(taskType, inferenceEntityId, input, taskSettings, inputType, query, inferenceTimeout); + return Objects.hash( + super.hashCode(), + taskType, + inferenceEntityId, + query, + input, + taskSettings, + inputType, + inferenceTimeout, + stream + ); } public static class Builder { @@ -266,6 +306,7 @@ public static class Builder { private String query; private TimeValue timeout = DEFAULT_TIMEOUT; private boolean stream = false; + private InferenceContext context; private Builder() {} @@ -313,8 +354,13 @@ public Builder setStream(boolean stream) { return this; } + public Builder setContext(InferenceContext context) { + this.context = context; + return this; + } + public Request build() { - return new Request(taskType, inferenceEntityId, query, input, taskSettings, inputType, timeout, stream); + return new Request(taskType, inferenceEntityId, query, input, taskSettings, inputType, timeout, stream, context); } } @@ -333,6 +379,8 @@ public String toString() { + this.getInputType() + ", timeout=" + this.getInferenceTimeout() + + ", context=" + + this.getContext() + ")"; } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceActionProxy.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceActionProxy.java index 68cd39f26b456..303a7568fe680 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceActionProxy.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceActionProxy.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.core.inference.action; +import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionRequest; import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.action.ActionType; @@ -17,6 +18,7 @@ import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.TaskType; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.InferenceContext; import java.io.IOException; import java.util.Objects; @@ -44,6 +46,7 @@ public static class Request extends ActionRequest { private final XContentType contentType; private final TimeValue timeout; private final boolean stream; + private final InferenceContext context; public Request( TaskType taskType, @@ -51,7 +54,8 @@ public Request( BytesReference content, XContentType contentType, TimeValue timeout, - boolean stream + boolean stream, + InferenceContext context ) { this.taskType = taskType; this.inferenceEntityId = inferenceEntityId; @@ -59,6 +63,7 @@ public Request( this.contentType = contentType; this.timeout = timeout; this.stream = stream; + this.context = context; } public Request(StreamInput in) throws IOException { @@ -71,6 +76,12 @@ public Request(StreamInput in) throws IOException { // streaming is not supported yet for transport traffic this.stream = false; + + if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_CONTEXT_8_X)) { + this.context = new InferenceContext(in); + } else { + this.context = InferenceContext.EMPTY_INSTANCE; + } } public TaskType getTaskType() { @@ -97,6 +108,10 @@ public boolean isStreaming() { return stream; } + public InferenceContext getContext() { + return context; + } + @Override public ActionRequestValidationException validate() { return null; @@ -110,6 +125,10 @@ public void writeTo(StreamOutput out) throws IOException { out.writeBytesReference(content); XContentHelper.writeTo(out, contentType); out.writeTimeValue(timeout); + + if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_CONTEXT_8_X)) { + context.writeTo(out); + } } @Override @@ -122,12 +141,13 @@ public boolean equals(Object o) { && Objects.equals(content, request.content) && contentType == request.contentType && timeout == request.timeout - && stream == request.stream; + && stream == request.stream + && context == request.context; } @Override public int hashCode() { - return Objects.hash(taskType, inferenceEntityId, content, contentType, timeout, stream); + return Objects.hash(taskType, inferenceEntityId, content, contentType, timeout, stream, context); } } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionAction.java index 43c84ad914c2a..8c27667316059 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionAction.java @@ -15,6 +15,7 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xpack.core.inference.InferenceContext; import java.io.IOException; import java.util.Objects; @@ -28,10 +29,15 @@ public UnifiedCompletionAction() { } public static class Request extends BaseInferenceActionRequest { - public static Request parseRequest(String inferenceEntityId, TaskType taskType, TimeValue timeout, XContentParser parser) - throws IOException { + public static Request parseRequest( + String inferenceEntityId, + TaskType taskType, + TimeValue timeout, + InferenceContext context, + XContentParser parser + ) throws IOException { var unifiedRequest = UnifiedCompletionRequest.PARSER.apply(parser, null); - return new Request(inferenceEntityId, taskType, unifiedRequest, timeout); + return new Request(inferenceEntityId, taskType, unifiedRequest, context, timeout); } private final String inferenceEntityId; @@ -40,6 +46,17 @@ public static Request parseRequest(String inferenceEntityId, TaskType taskType, private final TimeValue timeout; public Request(String inferenceEntityId, TaskType taskType, UnifiedCompletionRequest unifiedCompletionRequest, TimeValue timeout) { + this(inferenceEntityId, taskType, unifiedCompletionRequest, InferenceContext.EMPTY_INSTANCE, timeout); + } + + public Request( + String inferenceEntityId, + TaskType taskType, + UnifiedCompletionRequest unifiedCompletionRequest, + InferenceContext context, + TimeValue timeout + ) { + super(context); this.inferenceEntityId = Objects.requireNonNull(inferenceEntityId); this.taskType = Objects.requireNonNull(taskType); this.unifiedCompletionRequest = Objects.requireNonNull(unifiedCompletionRequest); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/InferenceContextTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/InferenceContextTests.java new file mode 100644 index 0000000000000..4aea4e07171dc --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/InferenceContextTests.java @@ -0,0 +1,30 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.inference; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; + +import java.io.IOException; + +public class InferenceContextTests extends AbstractWireSerializingTestCase { + @Override + protected Writeable.Reader instanceReader() { + return InferenceContext::new; + } + + @Override + protected InferenceContext createTestInstance() { + return new InferenceContext(randomAlphaOfLength(10)); + } + + @Override + protected InferenceContext mutateInstance(InferenceContext instance) throws IOException { + return randomValueOtherThan(instance, this::createTestInstance); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java index e9f4df7a523ad..537a7544f0d94 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java @@ -16,6 +16,7 @@ import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.TaskType; import org.elasticsearch.xcontent.json.JsonXContent; +import org.elasticsearch.xpack.core.inference.InferenceContext; import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; import java.io.IOException; @@ -26,6 +27,7 @@ import java.util.Map; import static org.elasticsearch.xpack.core.inference.action.InferenceAction.Request.getInputTypeToWrite; +import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; import static org.hamcrest.collection.IsIterableContainingInOrder.contains; @@ -46,7 +48,8 @@ protected InferenceAction.Request createTestInstance() { randomMap(0, 3, () -> new Tuple<>(randomAlphaOfLength(4), randomAlphaOfLength(4))), randomFrom(InputType.values()), TimeValue.timeValueMillis(randomLongBetween(1, 2048)), - false + false, + new InferenceContext(randomAlphanumericOfLength(10)) ); } @@ -57,7 +60,12 @@ public void testParsing() throws IOException { } """; try (var parser = createParser(JsonXContent.jsonXContent, singleInputRequest)) { - var request = InferenceAction.Request.parseRequest("model_id", TaskType.SPARSE_EMBEDDING, parser).build(); + var request = InferenceAction.Request.parseRequest( + "model_id", + TaskType.SPARSE_EMBEDDING, + InferenceContext.EMPTY_INSTANCE, + parser + ).build(); assertThat(request.getInput(), contains("single text input")); } @@ -67,7 +75,7 @@ public void testParsing() throws IOException { } """; try (var parser = createParser(JsonXContent.jsonXContent, multiInputRequest)) { - var request = InferenceAction.Request.parseRequest("model_id", TaskType.ANY, parser).build(); + var request = InferenceAction.Request.parseRequest("model_id", TaskType.ANY, InferenceContext.EMPTY_INSTANCE, parser).build(); assertThat(request.getInput(), contains("an array", "of", "inputs")); } } @@ -173,14 +181,19 @@ public void testParseRequest_DefaultsInputTypeToIngest() throws IOException { } """; try (var parser = createParser(JsonXContent.jsonXContent, singleInputRequest)) { - var request = InferenceAction.Request.parseRequest("model_id", TaskType.SPARSE_EMBEDDING, parser).build(); + var request = InferenceAction.Request.parseRequest( + "model_id", + TaskType.SPARSE_EMBEDDING, + InferenceContext.EMPTY_INSTANCE, + parser + ).build(); assertThat(request.getInputType(), is(InputType.UNSPECIFIED)); } } @Override protected InferenceAction.Request mutateInstance(InferenceAction.Request instance) throws IOException { - int select = randomIntBetween(0, 6); + int select = randomIntBetween(0, 7); return switch (select) { case 0 -> { var nextTask = TaskType.values()[(instance.getTaskType().ordinal() + 1) % TaskType.values().length]; @@ -192,7 +205,8 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc instance.getTaskSettings(), instance.getInputType(), instance.getInferenceTimeout(), - false + false, + instance.getContext() ); } case 1 -> new InferenceAction.Request( @@ -203,7 +217,8 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc instance.getTaskSettings(), instance.getInputType(), instance.getInferenceTimeout(), - false + false, + instance.getContext() ); case 2 -> { var changedInputs = new ArrayList(instance.getInput()); @@ -216,7 +231,8 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc instance.getTaskSettings(), instance.getInputType(), instance.getInferenceTimeout(), - false + false, + instance.getContext() ); } case 3 -> { @@ -235,7 +251,8 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc taskSettings, instance.getInputType(), instance.getInferenceTimeout(), - false + false, + instance.getContext() ); } case 4 -> { @@ -248,7 +265,8 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc instance.getTaskSettings(), nextInputType, instance.getInferenceTimeout(), - false + false, + instance.getContext() ); } case 5 -> new InferenceAction.Request( @@ -259,7 +277,8 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc instance.getTaskSettings(), instance.getInputType(), instance.getInferenceTimeout(), - false + false, + instance.getContext() ); case 6 -> { var newDuration = Duration.of( @@ -275,7 +294,22 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc instance.getTaskSettings(), instance.getInputType(), TimeValue.timeValueMillis(newDuration.plus(additionalTime).toMillis()), - false + false, + instance.getContext() + ); + } + case 7 -> { + var newContext = new InferenceContext(instance.getContext().productUseCase() + randomAlphaOfLength(5)); + yield new InferenceAction.Request( + instance.getTaskType(), + instance.getInferenceEntityId(), + instance.getQuery(), + instance.getInput(), + instance.getTaskSettings(), + instance.getInputType(), + instance.getInferenceTimeout(), + instance.isStreaming(), + newContext ); } default -> throw new UnsupportedOperationException(); @@ -284,8 +318,10 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc @Override protected InferenceAction.Request mutateInstanceForVersion(InferenceAction.Request instance, TransportVersion version) { + InferenceAction.Request mutated; + if (version.before(TransportVersions.V_8_12_0)) { - return new InferenceAction.Request( + mutated = new InferenceAction.Request( instance.getTaskType(), instance.getInferenceEntityId(), null, @@ -296,7 +332,7 @@ protected InferenceAction.Request mutateInstanceForVersion(InferenceAction.Reque false ); } else if (version.before(TransportVersions.V_8_13_0)) { - return new InferenceAction.Request( + mutated = new InferenceAction.Request( instance.getTaskType(), instance.getInferenceEntityId(), null, @@ -310,7 +346,7 @@ protected InferenceAction.Request mutateInstanceForVersion(InferenceAction.Reque && (instance.getInputType() == InputType.UNSPECIFIED || instance.getInputType() == InputType.CLASSIFICATION || instance.getInputType() == InputType.CLUSTERING)) { - return new InferenceAction.Request( + mutated = new InferenceAction.Request( instance.getTaskType(), instance.getInferenceEntityId(), null, @@ -322,7 +358,7 @@ protected InferenceAction.Request mutateInstanceForVersion(InferenceAction.Reque ); } else if (version.before(TransportVersions.V_8_13_0) && (instance.getInputType() == InputType.CLUSTERING || instance.getInputType() == InputType.CLASSIFICATION)) { - return new InferenceAction.Request( + mutated = new InferenceAction.Request( instance.getTaskType(), instance.getInferenceEntityId(), null, @@ -333,7 +369,7 @@ protected InferenceAction.Request mutateInstanceForVersion(InferenceAction.Reque false ); } else if (version.before(TransportVersions.V_8_14_0)) { - return new InferenceAction.Request( + mutated = new InferenceAction.Request( instance.getTaskType(), instance.getInferenceEntityId(), null, @@ -343,25 +379,52 @@ protected InferenceAction.Request mutateInstanceForVersion(InferenceAction.Reque InferenceAction.Request.DEFAULT_TIMEOUT, false ); + } else if (version.before(TransportVersions.INFERENCE_CONTEXT_8_X)) { + mutated = new InferenceAction.Request( + instance.getTaskType(), + instance.getInferenceEntityId(), + instance.getQuery(), + instance.getInput(), + instance.getTaskSettings(), + instance.getInputType(), + instance.getInferenceTimeout(), + false, + InferenceContext.EMPTY_INSTANCE + ); + } else { + mutated = instance; } - return instance; + // We always assume that a request has been rerouted, if it came from a node before adaptive rate limiting + if (version.before(TransportVersions.INFERENCE_REQUEST_ADAPTIVE_RATE_LIMITING)) { + mutated.setHasBeenRerouted(true); + } else { + mutated.setHasBeenRerouted(instance.hasBeenRerouted()); + } + + return mutated; } public void testWriteTo_WhenVersionIsOnAfterUnspecifiedAdded() throws IOException { - assertBwcSerialization( - new InferenceAction.Request( - TaskType.TEXT_EMBEDDING, - "model", - null, - List.of(), - Map.of(), - InputType.UNSPECIFIED, - InferenceAction.Request.DEFAULT_TIMEOUT, - false - ), + InferenceAction.Request instance = new InferenceAction.Request( + TaskType.TEXT_EMBEDDING, + "model", + null, + List.of(), + Map.of(), + InputType.UNSPECIFIED, + InferenceAction.Request.DEFAULT_TIMEOUT, + false + ); + + InferenceAction.Request deserializedInstance = copyWriteable( + instance, + getNamedWriteableRegistry(), + instanceReader(), TransportVersions.V_8_13_0 ); + + assertThat(deserializedInstance.getInputType(), is(InputType.UNSPECIFIED)); } public void testWriteTo_WhenVersionIsBeforeInputTypeAdded_ShouldSetInputTypeToUnspecified() throws IOException { @@ -409,6 +472,30 @@ public void testWriteTo_WhenVersionIsBeforeAdaptiveRateLimiting_ShouldSetHasBeen assertTrue(deserializedInstance.hasBeenRerouted()); } + public void testWriteTo_WhenVersionIsBeforeInferenceContext_ShouldSetContextToEmptyContext() throws IOException { + var instance = new InferenceAction.Request( + TaskType.TEXT_EMBEDDING, + "model", + null, + List.of("input"), + Map.of(), + InputType.UNSPECIFIED, + InferenceAction.Request.DEFAULT_TIMEOUT, + false, + new InferenceContext(randomAlphaOfLength(10)) + ); + + InferenceAction.Request deserializedInstance = copyWriteable( + instance, + getNamedWriteableRegistry(), + instanceReader(), + TransportVersions.V_8_15_0 + ); + + // Verify that context is empty after deserializing a request coming from an older transport version + assertThat(deserializedInstance.getContext(), equalTo(InferenceContext.EMPTY_INSTANCE)); + } + public void testGetInputTypeToWrite_ReturnsIngest_WhenInputTypeIsUnspecified_VersionBeforeUnspecifiedIntroduced() { assertThat(getInputTypeToWrite(InputType.UNSPECIFIED, TransportVersions.V_8_12_1), is(InputType.INGEST)); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionActionRequestTests.java index ceb7c9853a0f4..c76e187f2fbb1 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionActionRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionActionRequestTests.java @@ -14,11 +14,13 @@ import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.xpack.core.inference.InferenceContext; import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; import java.io.IOException; import java.util.List; +import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; public class UnifiedCompletionActionRequestTests extends AbstractBWCWireSerializationTestCase { @@ -85,8 +87,36 @@ public void testWriteTo_WhenVersionIsBeforeAdaptiveRateLimiting_ShouldSetHasBeen assertTrue(deserializedInstance.hasBeenRerouted()); } + public void testWriteTo_WhenVersionIsBeforeInferenceContext_ShouldSetContextToEmptyContext() throws IOException { + var instance = new UnifiedCompletionAction.Request( + "model", + TaskType.ANY, + UnifiedCompletionRequest.of(List.of(UnifiedCompletionRequestTests.randomMessage())), + InferenceContext.EMPTY_INSTANCE, + TimeValue.timeValueSeconds(10) + ); + + UnifiedCompletionAction.Request deserializedInstance = copyWriteable( + instance, + getNamedWriteableRegistry(), + instanceReader(), + TransportVersions.ELASTIC_INFERENCE_SERVICE_UNIFIED_CHAT_COMPLETIONS_INTEGRATION + ); + assertThat(deserializedInstance.getContext(), equalTo(InferenceContext.EMPTY_INSTANCE)); + } + @Override protected UnifiedCompletionAction.Request mutateInstanceForVersion(UnifiedCompletionAction.Request instance, TransportVersion version) { + if (version.before(TransportVersions.INFERENCE_CONTEXT_8_X)) { + return new UnifiedCompletionAction.Request( + instance.getInferenceEntityId(), + instance.getTaskType(), + instance.getUnifiedCompletionRequest(), + InferenceContext.EMPTY_INSTANCE, + instance.getTimeout() + ); + } + return instance; } @@ -101,6 +131,7 @@ protected UnifiedCompletionAction.Request createTestInstance() { randomAlphaOfLength(10), randomFrom(TaskType.values()), UnifiedCompletionRequestTests.randomUnifiedCompletionRequest(), + InferenceContext.EMPTY_INSTANCE, TimeValue.timeValueMillis(randomLongBetween(1, 2048)) ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 1a0a37f1fa2aa..de83af90f20f9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -44,6 +44,7 @@ import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor; import org.elasticsearch.rest.RestController; import org.elasticsearch.rest.RestHandler; +import org.elasticsearch.rest.RestHeaderDefinition; import org.elasticsearch.search.fetch.subphase.highlight.Highlighter; import org.elasticsearch.search.rank.RankBuilder; import org.elasticsearch.search.rank.RankDoc; @@ -135,6 +136,7 @@ import java.util.Collection; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.function.Predicate; import java.util.function.Supplier; @@ -173,6 +175,8 @@ public class InferencePlugin extends Plugin License.OperationMode.ENTERPRISE ); + public static final String X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER = "X-elastic-product-use-case"; + public static final String NAME = "inference"; public static final String UTILITY_THREAD_POOL_NAME = "inference_utility"; @@ -526,6 +530,16 @@ public void onNodeStarted() { } } + @Override + public Collection getRestHeaders() { + return Set.of(new RestHeaderDefinition(X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER, false)); + } + + @Override + public Collection getTaskHeaders() { + return Set.of(X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER); + } + protected SSLService getSslService() { return XPackPlugin.getSharedSslService(); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java index 2417561cc4497..fb3c0729816c8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java @@ -47,6 +47,7 @@ import org.elasticsearch.xpack.inference.telemetry.InferenceTimer; import java.io.IOException; +import java.util.Objects; import java.util.Random; import java.util.concurrent.Executor; import java.util.concurrent.Flow; @@ -143,6 +144,13 @@ protected void doExecute(Task task, Request request, ActionListener listener) { - if (request.isStreaming() == false || service.canStream(model.getTaskType())) { + if (request.isStreaming() == false || service.canStream(request.getTaskType())) { doInference(model, request, service, listener); } else { listener.onFailure(unsupportedStreamingTaskException(request, service)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionProxy.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionProxy.java index 6d46f834d4873..88a9927db2d9f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionProxy.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionProxy.java @@ -98,6 +98,7 @@ private void sendUnifiedCompletionRequest(InferenceActionProxy.Request request, request.getInferenceEntityId(), request.getTaskType(), request.getTimeout(), + request.getContext(), parser ); } @@ -115,6 +116,7 @@ private void sendInferenceActionRequest(InferenceActionProxy.Request request, Ac inferenceActionRequestBuilder = InferenceAction.Request.parseRequest( request.getInferenceEntityId(), request.getTaskType(), + request.getContext(), parser ); inferenceActionRequestBuilder.setInferenceTimeout(request.getTimeout()).setStream(request.isStreaming()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ElasticInferenceServiceRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ElasticInferenceServiceRequestManager.java index c857a481f8f04..7cc6c3def9637 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ElasticInferenceServiceRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ElasticInferenceServiceRequestManager.java @@ -8,14 +8,24 @@ package org.elasticsearch.xpack.inference.external.http.sender; import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.external.request.elastic.ElasticInferenceServiceRequestMetadata; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceModel; import java.util.Objects; +import static org.elasticsearch.xpack.inference.external.request.elastic.ElasticInferenceServiceRequest.extractRequestMetadataFromThreadContext; + public abstract class ElasticInferenceServiceRequestManager extends BaseRequestManager { + private final ElasticInferenceServiceRequestMetadata requestMetadata; + protected ElasticInferenceServiceRequestManager(ThreadPool threadPool, ElasticInferenceServiceModel model) { super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitServiceSettings().rateLimitSettings()); + this.requestMetadata = extractRequestMetadataFromThreadContext(threadPool.getThreadContext()); + } + + public ElasticInferenceServiceRequestMetadata requestMetadata() { + return requestMetadata; } record RateLimitGrouping(int modelIdHash) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ElasticInferenceServiceSparseEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ElasticInferenceServiceSparseEmbeddingsRequestManager.java index c647b3aea4771..edd4b651d40bd 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ElasticInferenceServiceSparseEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ElasticInferenceServiceSparseEmbeddingsRequestManager.java @@ -12,7 +12,6 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; -import org.elasticsearch.tasks.Task; import org.elasticsearch.xpack.inference.common.Truncator; import org.elasticsearch.xpack.inference.external.elastic.ElasticInferenceServiceResponseHandler; import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; @@ -44,8 +43,6 @@ public class ElasticInferenceServiceSparseEmbeddingsRequestManager extends Elast private final InputType inputType; - private final String productOrigin; - private static ResponseHandler createSparseEmbeddingsHandler() { return new ElasticInferenceServiceResponseHandler( String.format(Locale.ROOT, "%s sparse embeddings", ELASTIC_INFERENCE_SERVICE_IDENTIFIER), @@ -63,7 +60,6 @@ public ElasticInferenceServiceSparseEmbeddingsRequestManager( this.model = model; this.truncator = serviceComponents.truncator(); this.traceContext = traceContext; - this.productOrigin = serviceComponents.threadPool().getThreadContext().getHeader(Task.X_ELASTIC_PRODUCT_ORIGIN_HTTP_HEADER); this.inputType = inputType; } @@ -82,9 +78,10 @@ public void execute( truncatedInput, model, traceContext, - productOrigin, + requestMetadata(), inputType ); + execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ElasticInferenceServiceUnifiedCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ElasticInferenceServiceUnifiedCompletionRequestManager.java index 65e7e1704e37b..6e33008f22ea1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ElasticInferenceServiceUnifiedCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ElasticInferenceServiceUnifiedCompletionRequestManager.java @@ -11,7 +11,6 @@ import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; import org.elasticsearch.inference.InferenceServiceResults; -import org.elasticsearch.tasks.Task; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.inference.external.elastic.ElasticInferenceServiceUnifiedChatCompletionResponseHandler; import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; @@ -44,7 +43,6 @@ public static ElasticInferenceServiceUnifiedCompletionRequestManager of( private final ElasticInferenceServiceCompletionModel model; private final TraceContext traceContext; - private final String productOrigin; private ElasticInferenceServiceUnifiedCompletionRequestManager( ElasticInferenceServiceCompletionModel model, @@ -54,7 +52,6 @@ private ElasticInferenceServiceUnifiedCompletionRequestManager( super(threadPool, model); this.model = model; this.traceContext = traceContext; - this.productOrigin = threadPool.getThreadContext().getHeader(Task.X_ELASTIC_PRODUCT_ORIGIN_HTTP_HEADER); } @Override @@ -69,7 +66,7 @@ public void execute( inferenceInputs.castTo(UnifiedChatInput.class), model, traceContext, - productOrigin + requestMetadata() ); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceAuthorizationRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceAuthorizationRequest.java index f8dfbd1587b2e..5654c16764620 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceAuthorizationRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceAuthorizationRequest.java @@ -25,8 +25,12 @@ public class ElasticInferenceServiceAuthorizationRequest extends ElasticInferenc private final URI uri; private final TraceContextHandler traceContextHandler; - public ElasticInferenceServiceAuthorizationRequest(String url, TraceContext traceContext, String productOrigin) { - super(productOrigin); + public ElasticInferenceServiceAuthorizationRequest( + String url, + TraceContext traceContext, + ElasticInferenceServiceRequestMetadata requestMetadata + ) { + super(requestMetadata); this.uri = createUri(Objects.requireNonNull(url)); this.traceContextHandler = new TraceContextHandler(traceContext); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceRequest.java index cd152751499b4..9c85ebeb3bdd6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceRequest.java @@ -8,29 +8,54 @@ package org.elasticsearch.xpack.inference.external.request.elastic; import org.apache.http.client.methods.HttpRequestBase; +import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.tasks.Task; import org.elasticsearch.xpack.inference.external.request.HttpRequest; import org.elasticsearch.xpack.inference.external.request.Request; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.InferencePlugin.X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER; + public abstract class ElasticInferenceServiceRequest implements Request { - private final String productOrigin; + private final ElasticInferenceServiceRequestMetadata metadata; - public ElasticInferenceServiceRequest(String productOrigin) { - this.productOrigin = productOrigin; + public ElasticInferenceServiceRequest(ElasticInferenceServiceRequestMetadata metadata) { + this.metadata = metadata; } - public String getProductOrigin() { - return productOrigin; + public ElasticInferenceServiceRequestMetadata getMetadata() { + return metadata; } @Override public final HttpRequest createHttpRequest() { HttpRequestBase request = createHttpRequestBase(); // TODO: consider moving tracing here, too - request.setHeader(Task.X_ELASTIC_PRODUCT_ORIGIN_HTTP_HEADER, productOrigin); + + var productOrigin = metadata.productOrigin(); + var productUseCase = metadata.productUseCase(); + + if (Objects.nonNull(productOrigin) && productOrigin.isEmpty() == false) { + request.setHeader(Task.X_ELASTIC_PRODUCT_ORIGIN_HTTP_HEADER, metadata.productOrigin()); + } + + if (Objects.nonNull(productUseCase) && productUseCase.isEmpty() == false) { + request.setHeader(X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER, metadata.productUseCase()); + } + return new HttpRequest(request, getInferenceEntityId()); } protected abstract HttpRequestBase createHttpRequestBase(); + + public static ElasticInferenceServiceRequestMetadata extractRequestMetadataFromThreadContext(ThreadContext context) { + // 'X-Elastic-Product-Origin' is an Elastic wide header and therefore present in the ES-wide generic Task class. + // 'X-Elastic-Product-Use-Case' is Elastic Inference Service specific and is therefore not propagated through the ES-wide Task. + return new ElasticInferenceServiceRequestMetadata( + context.getHeader(Task.X_ELASTIC_PRODUCT_ORIGIN_HTTP_HEADER), + context.getHeader(X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER) + ); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceRequestMetadata.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceRequestMetadata.java new file mode 100644 index 0000000000000..3e6efaa9e4653 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceRequestMetadata.java @@ -0,0 +1,15 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.elastic; + +/** + * Record encapsulating arbitrary metadata, which is usually propagated through HTTP headers. + * @param productOrigin - product origin of the inference request (usually a whole system like "kibana", "logstash" etc.) + * @param productUseCase - product use case of the inference request (more granular view on a user flow like "security ai assistant" etc.) + */ +public record ElasticInferenceServiceRequestMetadata(String productOrigin, String productUseCase) {} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceSparseEmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceSparseEmbeddingsRequest.java index af44a6379f961..fb44f8667b4ec 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceSparseEmbeddingsRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceSparseEmbeddingsRequest.java @@ -40,10 +40,10 @@ public ElasticInferenceServiceSparseEmbeddingsRequest( Truncator.TruncationResult truncationResult, ElasticInferenceServiceSparseEmbeddingsModel model, TraceContext traceContext, - String productOrigin, + ElasticInferenceServiceRequestMetadata metadata, InputType inputType ) { - super(productOrigin); + super(metadata); this.truncator = truncator; this.truncationResult = truncationResult; this.model = Objects.requireNonNull(model); @@ -95,7 +95,7 @@ public Request truncate() { truncatedInput, model, traceContextHandler.traceContext(), - getProductOrigin(), + getMetadata(), inputType ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceUnifiedChatCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceUnifiedChatCompletionRequest.java index 6610b1f38a4dc..e6e900d7cd515 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceUnifiedChatCompletionRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceUnifiedChatCompletionRequest.java @@ -34,9 +34,9 @@ public ElasticInferenceServiceUnifiedChatCompletionRequest( UnifiedChatInput unifiedChatInput, ElasticInferenceServiceCompletionModel model, TraceContext traceContext, - String productOrigin + ElasticInferenceServiceRequestMetadata requestMetadata ) { - super(productOrigin); + super(requestMetadata); this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput); this.model = Objects.requireNonNull(model); this.traceContextHandler = new TraceContextHandler(traceContext); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/BaseInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/BaseInferenceAction.java index 06a0849b91d4e..29e656d4a8d95 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/BaseInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/BaseInferenceAction.java @@ -14,10 +14,13 @@ import org.elasticsearch.rest.BaseRestHandler; import org.elasticsearch.rest.RestChannel; import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.xpack.core.inference.InferenceContext; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.action.InferenceActionProxy; +import org.elasticsearch.xpack.inference.InferencePlugin; import java.io.IOException; +import java.util.Objects; import static org.elasticsearch.xpack.inference.rest.Paths.INFERENCE_ID; import static org.elasticsearch.xpack.inference.rest.Paths.TASK_TYPE_OR_INFERENCE_ID; @@ -44,6 +47,8 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient var params = parseParams(restRequest); var content = restRequest.requiredContent(); var inferTimeout = parseTimeout(restRequest); + var productUseCase = extractProductUseCase(restRequest); + var context = new InferenceContext(productUseCase); var request = new InferenceActionProxy.Request( params.taskType(), @@ -51,7 +56,8 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient content, restRequest.getXContentType(), inferTimeout, - shouldStream() + shouldStream(), + context ); return channel -> client.execute(InferenceActionProxy.INSTANCE, request, ActionListener.withRef(listener(channel), content)); @@ -60,4 +66,21 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient protected abstract boolean shouldStream(); protected abstract ActionListener listener(RestChannel channel); + + private String extractProductUseCase(RestRequest restRequest) { + var headers = restRequest.getHeaders(); + + if (Objects.isNull(headers) || headers.isEmpty()) { + return ""; + } + + var productUseCaseHeaders = headers.get(InferencePlugin.X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER); + + if (Objects.isNull(productUseCaseHeaders) || productUseCaseHeaders.isEmpty()) { + return ""; + } + + // We always get the first value as the header doesn't allow multiple values + return productUseCaseHeaders.get(0); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandler.java index eb92d1b48f8a7..5fe511fcb846c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandler.java @@ -29,6 +29,7 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import static org.elasticsearch.xpack.inference.external.request.elastic.ElasticInferenceServiceRequest.extractRequestMetadataFromThreadContext; import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.ELASTIC_INFERENCE_SERVICE_IDENTIFIER; /** @@ -108,8 +109,8 @@ public void getAuthorization(ActionListener action; + private ThreadPool threadPool; protected static final String serviceId = "serviceId"; protected final TaskType taskType; @@ -76,7 +80,7 @@ public BaseTransportInferenceActionTestCase(TaskType taskType) { public void setUp() throws Exception { super.setUp(); ActionFilters actionFilters = mock(); - ThreadPool threadPool = mock(); + threadPool = mock(); nodeClient = mock(); transportService = mock(); inferenceServiceRateLimitCalculator = mock(); @@ -332,6 +336,38 @@ public void onComplete() { })); } + public void testProductUseCaseHeaderPresentInThreadContextIfPresent() { + String productUseCase = "product-use-case"; + + // We need to use real instances instead of mocks as these are final classes + InferenceContext context = new InferenceContext(productUseCase); + ThreadContext threadContext = new ThreadContext(Settings.EMPTY); + + when(threadPool.getThreadContext()).thenReturn(threadContext); + + mockModelRegistry(taskType); + mockService(listener -> listener.onResponse(mock())); + + Request request = createRequest(); + when(request.getContext()).thenReturn(context); + when(request.getInferenceEntityId()).thenReturn(inferenceId); + when(request.getTaskType()).thenReturn(taskType); + when(request.isStreaming()).thenReturn(false); + + ActionListener listener = spy(new ActionListener<>() { + @Override + public void onResponse(InferenceAction.Response o) {} + + @Override + public void onFailure(Exception e) {} + }); + + action.doExecute(mock(), request, listener); + + // Verify the product use case header was set in the thread context + assertThat(threadContext.getHeader(InferencePlugin.X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER), is(productUseCase)); + } + protected Flow.Publisher mockStreamResponse(Consumer> action) { mockService(true, Set.of(), listener -> { Flow.Processor taskProcessor = mock(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionProxyTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionProxyTests.java index a9e6ec55a6224..1137027e8d362 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionProxyTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionProxyTests.java @@ -21,6 +21,7 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.InferenceContext; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.action.InferenceActionProxy; import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction; @@ -87,7 +88,8 @@ public void testExecutesAUnifiedCompletionRequest_WhenTaskTypeIsChatCompletion_I new BytesArray(requestJson), XContentType.JSON, TimeValue.ONE_MINUTE, - true + true, + InferenceContext.EMPTY_INSTANCE ); action.doExecute(mock(Task.class), request, listener); @@ -129,7 +131,8 @@ public void testExecutesAUnifiedCompletionRequest_WhenTaskTypeIsChatCompletion_F new BytesArray(requestJson), XContentType.JSON, TimeValue.ONE_MINUTE, - true + true, + InferenceContext.EMPTY_INSTANCE ); action.doExecute(mock(Task.class), request, listener); @@ -152,7 +155,8 @@ public void testExecutesAnInferenceAction_WhenTaskTypeIsCompletion_InRequest() { new BytesArray(requestJson), XContentType.JSON, TimeValue.ONE_MINUTE, - true + true, + InferenceContext.EMPTY_INSTANCE ); action.doExecute(mock(Task.class), request, listener); @@ -181,7 +185,8 @@ public void testExecutesAnInferenceAction_WhenTaskTypeIsCompletion_FromStorage() new BytesArray(requestJson), XContentType.JSON, TimeValue.ONE_MINUTE, - true + true, + InferenceContext.EMPTY_INSTANCE ); action.doExecute(mock(Task.class), request, listener); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionTests.java index 3129f0865a249..144781678fe33 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionTests.java @@ -70,7 +70,7 @@ protected BaseTransportInferenceAction createAction( @Override protected InferenceAction.Request createRequest() { - return mock(); + return mock(InferenceAction.Request.class); } public void testNoRerouting_WhenTaskTypeNotSupported() { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionActionTests.java index 7dac6a1015aae..f26d0675487a5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionActionTests.java @@ -69,7 +69,7 @@ protected BaseTransportInferenceAction createAc @Override protected UnifiedCompletionAction.Request createRequest() { - return mock(); + return mock(UnifiedCompletionAction.Request.class); } public void testThrows_IncompatibleTaskTypeException_WhenUsingATextEmbeddingInferenceEndpoint() { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java index 5a8c66da1dfa3..0c296fc5729bd 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java @@ -49,6 +49,7 @@ import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.external.openai.OpenAiUtils.ORGANIZATION_HEADER; +import static org.elasticsearch.xpack.inference.external.request.elastic.ElasticInferenceServiceRequestTests.randomElasticInferenceServiceRequestMetadata; import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.ELASTIC_INFERENCE_SERVICE_IDENTIFIER; import static org.hamcrest.Matchers.equalTo; @@ -161,7 +162,7 @@ public void testSendWithoutQueuing_SendsRequestAndReceivesResponse() throws Exce var request = new ElasticInferenceServiceAuthorizationRequest( getUrl(webServer), new TraceContext("", ""), - randomAlphaOfLength(10) + randomElasticInferenceServiceRequestMetadata() ); var responseHandler = new ElasticInferenceServiceResponseHandler( String.format(Locale.ROOT, "%s sparse embeddings", ELASTIC_INFERENCE_SERVICE_IDENTIFIER), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceAuthorizationRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceAuthorizationRequestTests.java index bc79881505167..fcb3882bf8771 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceAuthorizationRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceAuthorizationRequestTests.java @@ -13,6 +13,7 @@ import org.elasticsearch.xpack.inference.telemetry.TraceContext; import org.junit.Before; +import static org.elasticsearch.xpack.inference.external.request.elastic.ElasticInferenceServiceRequestTests.randomElasticInferenceServiceRequestMetadata; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.is; @@ -30,7 +31,7 @@ public void testCreateUriThrowsForInvalidBaseUrl() { ElasticsearchStatusException exception = assertThrows( ElasticsearchStatusException.class, - () -> new ElasticInferenceServiceAuthorizationRequest(invalidUrl, traceContext, randomAlphaOfLength(10)) + () -> new ElasticInferenceServiceAuthorizationRequest(invalidUrl, traceContext, randomElasticInferenceServiceRequestMetadata()) ); assertThat(exception.status(), is(RestStatus.BAD_REQUEST)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceRequestTests.java index c5dd19a045390..5fd8083ca75bb 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceRequestTests.java @@ -15,23 +15,41 @@ import java.net.URI; +import static org.elasticsearch.xpack.inference.InferencePlugin.X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER; import static org.hamcrest.Matchers.equalTo; public class ElasticInferenceServiceRequestTests extends ESTestCase { public void testElasticInferenceServiceRequestSubclasses_Decorate_HttpRequest_WithProductOrigin() { var productOrigin = "elastic"; - var elasticInferenceServiceRequestWrapper = getDummyElasticInferenceServiceRequest(productOrigin); + var elasticInferenceServiceRequestWrapper = getDummyElasticInferenceServiceRequest( + new ElasticInferenceServiceRequestMetadata(productOrigin, null) + ); var httpRequest = elasticInferenceServiceRequestWrapper.createHttpRequest(); var productOriginHeader = httpRequest.httpRequestBase().getFirstHeader(Task.X_ELASTIC_PRODUCT_ORIGIN_HTTP_HEADER); - // Make sure this header only exists once + // Make sure the product origin header only exists once assertThat(httpRequest.httpRequestBase().getHeaders(Task.X_ELASTIC_PRODUCT_ORIGIN_HTTP_HEADER).length, equalTo(1)); assertThat(productOriginHeader.getValue(), equalTo(productOrigin)); } - private static ElasticInferenceServiceRequest getDummyElasticInferenceServiceRequest(String productOrigin) { - return new ElasticInferenceServiceRequest(productOrigin) { + public void testElasticInferenceServiceRequestSubclasses_Decorate_HttpRequest_WithProductUseCase() { + var productUseCase = "ai assistant"; + var elasticInferenceServiceRequestWrapper = getDummyElasticInferenceServiceRequest( + new ElasticInferenceServiceRequestMetadata(null, productUseCase) + ); + var httpRequest = elasticInferenceServiceRequestWrapper.createHttpRequest(); + var productUseCaseHeader = httpRequest.httpRequestBase().getFirstHeader(X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER); + + // Make sure the product use case header only exists once + assertThat(httpRequest.httpRequestBase().getHeaders(X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER).length, equalTo(1)); + assertThat(productUseCaseHeader.getValue(), equalTo(productUseCase)); + } + + private static ElasticInferenceServiceRequest getDummyElasticInferenceServiceRequest( + ElasticInferenceServiceRequestMetadata requestMetadata + ) { + return new ElasticInferenceServiceRequest(requestMetadata) { @Override protected HttpRequestBase createHttpRequestBase() { return new HttpGet("http://localhost:8080"); @@ -58,4 +76,11 @@ public String getInferenceEntityId() { } }; } + + public static ElasticInferenceServiceRequestMetadata randomElasticInferenceServiceRequestMetadata() { + return new ElasticInferenceServiceRequestMetadata( + randomFrom(new String[] { null, randomAlphaOfLength(10) }), + randomFrom(new String[] { null, randomAlphaOfLength(10) }) + ); + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceSparseEmbeddingsRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceSparseEmbeddingsRequestTests.java index 794c05f6f8d50..d5137ec5d0709 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceSparseEmbeddingsRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceSparseEmbeddingsRequestTests.java @@ -23,6 +23,7 @@ import java.util.List; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.elasticsearch.xpack.inference.external.request.elastic.ElasticInferenceServiceRequestTests.randomElasticInferenceServiceRequestMetadata; import static org.elasticsearch.xpack.inference.external.request.elastic.ElasticInferenceServiceSparseEmbeddingsRequest.inputTypeToUsageContext; import static org.hamcrest.Matchers.aMapWithSize; import static org.hamcrest.Matchers.equalTo; @@ -124,7 +125,7 @@ public ElasticInferenceServiceSparseEmbeddingsRequest createRequest(String url, new Truncator.TruncationResult(List.of(input), new boolean[] { false }), embeddingsModel, new TraceContext(randomAlphaOfLength(10), randomAlphaOfLength(10)), - randomAlphaOfLength(10), + randomElasticInferenceServiceRequestMetadata(), inputType ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/BaseInferenceActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/BaseInferenceActionTests.java index 91eb4a5b9af99..4e60b09530684 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/BaseInferenceActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/BaseInferenceActionTests.java @@ -20,9 +20,11 @@ import org.elasticsearch.test.rest.FakeRestRequest; import org.elasticsearch.test.rest.RestActionTestCase; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.InferenceContext; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.action.InferenceActionProxy; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; +import org.elasticsearch.xpack.inference.InferencePlugin; import org.junit.Before; import java.util.HashMap; @@ -146,6 +148,90 @@ public void testUses3SecondTimeoutFromParams() { assertThat(executeCalled.get(), equalTo(true)); } + public void testExtractProductUseCase() { + SetOnce executeCalled = new SetOnce<>(); + String productUseCase = "product-use-case"; + + verifyingClient.setExecuteVerifier(((actionType, actionRequest) -> { + assertThat(actionRequest, instanceOf(InferenceActionProxy.Request.class)); + + var request = (InferenceActionProxy.Request) actionRequest; + InferenceContext context = request.getContext(); + assertNotNull(context); + assertThat(context.productUseCase(), equalTo(productUseCase)); + + executeCalled.set(true); + return createResponse(); + })); + + // Create a request with the product use case header + Map> headers = new HashMap<>(); + headers.put(InferencePlugin.X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER, List.of(productUseCase)); + + RestRequest inferenceRequest = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + .withPath(route("test")) + .withHeaders(headers) + .withContent(new BytesArray("{}"), XContentType.JSON) + .build(); + + dispatchRequest(inferenceRequest); + assertThat(executeCalled.get(), equalTo(true)); + } + + public void testExtractProductUseCase_EmptyWhenHeaderMissing() { + SetOnce executeCalled = new SetOnce<>(); + + verifyingClient.setExecuteVerifier(((actionType, actionRequest) -> { + assertThat(actionRequest, instanceOf(InferenceActionProxy.Request.class)); + + var request = (InferenceActionProxy.Request) actionRequest; + InferenceContext context = request.getContext(); + assertNotNull(context); + assertThat(context.productUseCase(), equalTo("")); + + executeCalled.set(true); + return createResponse(); + })); + + // Create a request without the product use case header + RestRequest inferenceRequest = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + .withPath(route("test")) + .withContent(new BytesArray("{}"), XContentType.JSON) + .build(); + + dispatchRequest(inferenceRequest); + assertThat(executeCalled.get(), equalTo(true)); + } + + public void testExtractProductUseCase_EmptyWhenHeaderValueEmpty() { + SetOnce executeCalled = new SetOnce<>(); + + verifyingClient.setExecuteVerifier(((actionType, actionRequest) -> { + assertThat(actionRequest, instanceOf(InferenceActionProxy.Request.class)); + + var request = (InferenceActionProxy.Request) actionRequest; + InferenceContext context = request.getContext(); + assertNotNull(context); + assertThat(context.productUseCase(), equalTo("")); + + executeCalled.set(true); + return createResponse(); + })); + + // Create a request with an empty product use case header value + Map> headers = new HashMap<>(); + headers.put(InferencePlugin.X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER, List.of("")); + + RestRequest inferenceRequest = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + .withPath(route("test")) + .withHeaders(headers) + .withContent(new BytesArray("{}"), XContentType.JSON) + .build(); + + dispatchRequest(inferenceRequest); + assertThat(executeCalled.get(), equalTo(true)); + } + static InferenceAction.Response createResponse() { return new InferenceAction.Response( new TextEmbeddingByteResults(List.of(new TextEmbeddingByteResults.Embedding(new byte[] { (byte) -1 }))) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index 9cd2c50fabfc6..ccd4d65a87363 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -40,6 +40,7 @@ import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; import org.elasticsearch.xpack.core.ml.search.WeightedToken; +import org.elasticsearch.xpack.inference.InferencePlugin; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; @@ -478,9 +479,9 @@ public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid_ChatCompletion() throws public void testInfer_SendsEmbeddingsRequest() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - var eisGatewayUrl = getUrl(webServer); + var elasticInferenceServiceURL = getUrl(webServer); - try (var service = createService(senderFactory, eisGatewayUrl)) { + try (var service = createService(senderFactory, elasticInferenceServiceURL)) { String responseJson = """ { "data": [ @@ -494,7 +495,7 @@ public void testInfer_SendsEmbeddingsRequest() throws IOException { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(eisGatewayUrl, "my-model-id"); + var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(elasticInferenceServiceURL, "my-model-id"); PlainActionFuture listener = new PlainActionFuture<>(); service.infer( model, @@ -527,11 +528,213 @@ public void testInfer_SendsEmbeddingsRequest() throws IOException { } } + public void testInfer_PropagatesProductUseCaseHeader() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + var elasticInferenceServiceURL = getUrl(webServer); + + try (var service = createService(senderFactory, elasticInferenceServiceURL)) { + String responseJson = """ + { + "data": [ + { + "hello": 2.1259406, + "greet": 1.7073475 + } + ] + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + // Set up the product use case in the thread context + String productUseCase = "test-product-use-case"; + threadPool.getThreadContext().putHeader(InferencePlugin.X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER, productUseCase); + + var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(elasticInferenceServiceURL, "my-model-id"); + PlainActionFuture listener = new PlainActionFuture<>(); + + try { + service.infer( + model, + null, + List.of("input text"), + false, + new HashMap<>(), + InputType.SEARCH, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + var result = listener.actionGet(TIMEOUT); + + // Verify the response was processed correctly + assertThat( + result.asMap(), + Matchers.is( + SparseEmbeddingResultsTests.buildExpectationSparseEmbeddings( + List.of( + new SparseEmbeddingResultsTests.EmbeddingExpectation( + Map.of("hello", 2.1259406f, "greet", 1.7073475f), + false + ) + ) + ) + ) + ); + + // Verify the header was sent in the request + var request = webServer.requests().get(0); + assertNull(request.getUri().getQuery()); + assertThat(request.getHeader(HttpHeaders.CONTENT_TYPE), Matchers.equalTo(XContentType.JSON.mediaType())); + + // Check that the product use case header was set correctly + assertThat(request.getHeader(InferencePlugin.X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER), is(productUseCase)); + + // Verify request body + var requestMap = entityAsMap(request.getBody()); + assertThat(requestMap, is(Map.of("input", List.of("input text"), "model", "my-model-id", "usage_context", "search"))); + } finally { + // Clean up the thread context + threadPool.getThreadContext().stashContext(); + } + } + } + + public void testChunkedInfer_PropagatesProductUseCaseHeader() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + var elasticInferenceServiceURL = getUrl(webServer); + + try (var service = createService(senderFactory, elasticInferenceServiceURL)) { + String responseJson = """ + { + "data": [ + { + "hello": 2.1259406, + "greet": 1.7073475 + } + ] + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + // Set up the product use case in the thread context + String productUseCase = "test-product-use-case"; + threadPool.getThreadContext().putHeader(InferencePlugin.X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER, productUseCase); + + var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(elasticInferenceServiceURL, "my-model-id"); + PlainActionFuture> listener = new PlainActionFuture<>(); + + try { + service.chunkedInfer( + model, + null, + List.of("input text"), + new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var results = listener.actionGet(TIMEOUT); + + // Verify the response was processed correctly + ChunkedInference inferenceResult = results.get(0); + assertThat(inferenceResult, instanceOf(ChunkedInferenceEmbedding.class)); + var sparseResult = (ChunkedInferenceEmbedding) inferenceResult; + assertThat( + sparseResult.chunks(), + is( + List.of( + new SparseEmbeddingResults.Chunk( + List.of(new WeightedToken("hello", 2.1259406f), new WeightedToken("greet", 1.7073475f)), + new ChunkedInference.TextOffset(0, "input text".length()) + ) + ) + ) + ); + + // Verify the request was sent and contains expected headers + MatcherAssert.assertThat(webServer.requests(), hasSize(1)); + var request = webServer.requests().get(0); + assertNull(request.getUri().getQuery()); + MatcherAssert.assertThat(request.getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); + + // Check that the product use case header was set correctly + assertThat(request.getHeader(InferencePlugin.X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER), is(productUseCase)); + + // Verify request body + var requestMap = entityAsMap(request.getBody()); + assertThat(requestMap, is(Map.of("input", List.of("input text"), "model", "my-model-id", "usage_context", "ingest"))); + } finally { + // Clean up the thread context + threadPool.getThreadContext().stashContext(); + } + } + } + + public void testUnifiedCompletionInfer_PropagatesProductUseCaseHeader() throws IOException { + var elasticInferenceServiceURL = getUrl(webServer); + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = createService(senderFactory, elasticInferenceServiceURL)) { + // Mock a successful streaming response + String responseJson = """ + data: {"id":"1","object":"completion","created":1677858242,"model":"my-model-id", + "choices":[{"finish_reason":null,"index":0,"delta":{"role":"assistant","content":"Hello"}}]} + + data: {"id":"2","object":"completion","created":1677858242,"model":"my-model-id", + "choices":[{"finish_reason":"stop","index":0,"delta":{"content":" world!"}}]} + + data: [DONE] + + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + String productUseCase = "test-product-use-case"; + threadPool.getThreadContext().putHeader(InferencePlugin.X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER, productUseCase); + + // Create completion model + var model = new ElasticInferenceServiceCompletionModel( + "id", + TaskType.CHAT_COMPLETION, + "elastic", + new ElasticInferenceServiceCompletionServiceSettings("my-model-id", new RateLimitSettings(100)), + EmptyTaskSettings.INSTANCE, + EmptySecretSettings.INSTANCE, + ElasticInferenceServiceComponents.of(elasticInferenceServiceURL) + ); + + var request = UnifiedCompletionRequest.of( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("Hello"), "user", null, null)) + ); + + PlainActionFuture listener = new PlainActionFuture<>(); + + try { + service.unifiedCompletionInfer(model, request, InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + // We don't need to check the actual response as we're only testing header propagation + listener.actionGet(TIMEOUT); + + // Verify the request was sent + assertThat(webServer.requests(), hasSize(1)); + var httpRequest = webServer.requests().get(0); + + // Check that the product use case header was set correctly + assertThat(httpRequest.getHeader(InferencePlugin.X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER), is(productUseCase)); + } finally { + // Clean up the thread context + threadPool.getThreadContext().stashContext(); + } + } + } + public void testChunkedInfer_PassesThrough() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - var eisGatewayUrl = getUrl(webServer); + var elasticInferenceServiceURL = getUrl(webServer); - try (var service = createService(senderFactory, eisGatewayUrl)) { + try (var service = createService(senderFactory, elasticInferenceServiceURL)) { String responseJson = """ { "data": [ @@ -545,7 +748,7 @@ public void testChunkedInfer_PassesThrough() throws IOException { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(eisGatewayUrl, "my-model-id"); + var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(elasticInferenceServiceURL, "my-model-id"); PlainActionFuture> listener = new PlainActionFuture<>(); service.chunkedInfer( model, @@ -1070,9 +1273,9 @@ private void testUnifiedStreamError(int responseCode, String responseJson, Strin } private InferenceEventsAssertion testUnifiedStream(int responseCode, String responseJson) throws Exception { - var eisGatewayUrl = getUrl(webServer); + var elasticInferenceServiceURL = getUrl(webServer); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = createService(senderFactory, eisGatewayUrl)) { + try (var service = createService(senderFactory, elasticInferenceServiceURL)) { webServer.enqueue(new MockResponse().setResponseCode(responseCode).setBody(responseJson)); var model = new ElasticInferenceServiceCompletionModel( "id", @@ -1081,7 +1284,7 @@ private InferenceEventsAssertion testUnifiedStream(int responseCode, String resp new ElasticInferenceServiceCompletionServiceSettings("model_id", new RateLimitSettings(100)), EmptyTaskSettings.INSTANCE, EmptySecretSettings.INSTANCE, - ElasticInferenceServiceComponents.of(eisGatewayUrl) + ElasticInferenceServiceComponents.of(elasticInferenceServiceURL) ); PlainActionFuture listener = new PlainActionFuture<>(); service.unifiedCompletionInfer( @@ -1131,14 +1334,14 @@ private ElasticInferenceService createService(HttpRequestSender.Factory senderFa return createService(senderFactory, ElasticInferenceServiceAuthorizationModelTests.createEnabledAuth(), null); } - private ElasticInferenceService createService(HttpRequestSender.Factory senderFactory, String gatewayUrl) { - return createService(senderFactory, ElasticInferenceServiceAuthorizationModelTests.createEnabledAuth(), gatewayUrl); + private ElasticInferenceService createService(HttpRequestSender.Factory senderFactory, String elasticInferenceServiceURL) { + return createService(senderFactory, ElasticInferenceServiceAuthorizationModelTests.createEnabledAuth(), elasticInferenceServiceURL); } private ElasticInferenceService createService( HttpRequestSender.Factory senderFactory, ElasticInferenceServiceAuthorizationModel auth, - String gatewayUrl + String elasticInferenceServiceURL ) { var mockAuthHandler = mock(ElasticInferenceServiceAuthorizationRequestHandler.class); doAnswer(invocation -> { @@ -1150,33 +1353,36 @@ private ElasticInferenceService createService( return new ElasticInferenceService( senderFactory, createWithEmptySettings(threadPool), - ElasticInferenceServiceSettingsTests.create(gatewayUrl), + ElasticInferenceServiceSettingsTests.create(elasticInferenceServiceURL), mockModelRegistry(), mockAuthHandler ); } - private ElasticInferenceService createServiceWithAuthHandler(HttpRequestSender.Factory senderFactory, String eisGatewayUrl) { + private ElasticInferenceService createServiceWithAuthHandler( + HttpRequestSender.Factory senderFactory, + String elasticInferenceServiceURL + ) { return new ElasticInferenceService( senderFactory, createWithEmptySettings(threadPool), - ElasticInferenceServiceSettingsTests.create(eisGatewayUrl), + ElasticInferenceServiceSettingsTests.create(elasticInferenceServiceURL), mockModelRegistry(), - new ElasticInferenceServiceAuthorizationRequestHandler(eisGatewayUrl, threadPool) + new ElasticInferenceServiceAuthorizationRequestHandler(elasticInferenceServiceURL, threadPool) ); } public static ElasticInferenceService createServiceWithAuthHandler( HttpRequestSender.Factory senderFactory, - String eisGatewayUrl, + String elasticInferenceServiceURL, ThreadPool threadPool ) { return new ElasticInferenceService( senderFactory, createWithEmptySettings(threadPool), - ElasticInferenceServiceSettingsTests.create(eisGatewayUrl), + ElasticInferenceServiceSettingsTests.create(elasticInferenceServiceURL), mockModelRegistry(threadPool), - new ElasticInferenceServiceAuthorizationRequestHandler(eisGatewayUrl, threadPool) + new ElasticInferenceServiceAuthorizationRequestHandler(elasticInferenceServiceURL, threadPool) ); } }