From 77325e578f9cfeb9c82a37d3fa7a00789b1aa938 Mon Sep 17 00:00:00 2001 From: Tim Grein Date: Fri, 21 Feb 2025 14:00:35 +0100 Subject: [PATCH 1/2] Auto-propagate product origin for every subclass of ElasticInferenceServiceRequest --- ...ServiceSparseEmbeddingsRequestManager.java | 5 ++ ...erviceUnifiedCompletionRequestManager.java | 6 +- ...cInferenceServiceAuthorizationRequest.java | 11 ++-- .../ElasticInferenceServiceRequest.java | 26 +++++++- ...ferenceServiceSparseEmbeddingsRequest.java | 11 ++-- ...ceServiceUnifiedChatCompletionRequest.java | 12 ++-- ...cInferenceServiceAuthorizationHandler.java | 3 +- .../http/sender/HttpRequestSenderTests.java | 6 +- ...renceServiceAuthorizationRequestTests.java | 2 +- .../ElasticInferenceServiceRequestTests.java | 61 +++++++++++++++++++ ...ceServiceSparseEmbeddingsRequestTests.java | 1 + 11 files changed, 125 insertions(+), 19 deletions(-) create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceRequestTests.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ElasticInferenceServiceSparseEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ElasticInferenceServiceSparseEmbeddingsRequestManager.java index 693a7ca36785c..c647b3aea4771 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ElasticInferenceServiceSparseEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ElasticInferenceServiceSparseEmbeddingsRequestManager.java @@ -12,6 +12,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; +import org.elasticsearch.tasks.Task; import org.elasticsearch.xpack.inference.common.Truncator; import org.elasticsearch.xpack.inference.external.elastic.ElasticInferenceServiceResponseHandler; import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; @@ -43,6 +44,8 @@ public class ElasticInferenceServiceSparseEmbeddingsRequestManager extends Elast private final InputType inputType; + private final String productOrigin; + private static ResponseHandler createSparseEmbeddingsHandler() { return new ElasticInferenceServiceResponseHandler( String.format(Locale.ROOT, "%s sparse embeddings", ELASTIC_INFERENCE_SERVICE_IDENTIFIER), @@ -60,6 +63,7 @@ public ElasticInferenceServiceSparseEmbeddingsRequestManager( this.model = model; this.truncator = serviceComponents.truncator(); this.traceContext = traceContext; + this.productOrigin = serviceComponents.threadPool().getThreadContext().getHeader(Task.X_ELASTIC_PRODUCT_ORIGIN_HTTP_HEADER); this.inputType = inputType; } @@ -78,6 +82,7 @@ public void execute( truncatedInput, model, traceContext, + productOrigin, inputType ); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ElasticInferenceServiceUnifiedCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ElasticInferenceServiceUnifiedCompletionRequestManager.java index 66314db1e05bd..226f00f56c7f4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ElasticInferenceServiceUnifiedCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ElasticInferenceServiceUnifiedCompletionRequestManager.java @@ -11,6 +11,7 @@ import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.tasks.Task; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.inference.external.elastic.ElasticInferenceServiceUnifiedChatCompletionResponseHandler; import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; @@ -43,6 +44,7 @@ public static ElasticInferenceServiceUnifiedCompletionRequestManager of( private final ElasticInferenceServiceCompletionModel model; private final TraceContext traceContext; + private final String productOrigin; private ElasticInferenceServiceUnifiedCompletionRequestManager( ElasticInferenceServiceCompletionModel model, @@ -52,6 +54,7 @@ private ElasticInferenceServiceUnifiedCompletionRequestManager( super(threadPool, model); this.model = model; this.traceContext = traceContext; + this.productOrigin = threadPool.getThreadContext().getHeader(Task.X_ELASTIC_PRODUCT_ORIGIN_HTTP_HEADER); } @Override @@ -65,7 +68,8 @@ public void execute( ElasticInferenceServiceUnifiedChatCompletionRequest request = new ElasticInferenceServiceUnifiedChatCompletionRequest( inferenceInputs.castTo(UnifiedChatInput.class), model, - traceContext + traceContext, + productOrigin ); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceAuthorizationRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceAuthorizationRequest.java index d46313755be00..f8dfbd1587b2e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceAuthorizationRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceAuthorizationRequest.java @@ -8,9 +8,9 @@ package org.elasticsearch.xpack.inference.external.request.elastic; import org.apache.http.client.methods.HttpGet; +import org.apache.http.client.methods.HttpRequestBase; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.xpack.inference.external.request.HttpRequest; import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; import org.elasticsearch.xpack.inference.telemetry.TraceContext; @@ -20,12 +20,13 @@ import java.net.URISyntaxException; import java.util.Objects; -public class ElasticInferenceServiceAuthorizationRequest implements ElasticInferenceServiceRequest { +public class ElasticInferenceServiceAuthorizationRequest extends ElasticInferenceServiceRequest { private final URI uri; private final TraceContextHandler traceContextHandler; - public ElasticInferenceServiceAuthorizationRequest(String url, TraceContext traceContext) { + public ElasticInferenceServiceAuthorizationRequest(String url, TraceContext traceContext, String productOrigin) { + super(productOrigin); this.uri = createUri(Objects.requireNonNull(url)); this.traceContextHandler = new TraceContextHandler(traceContext); } @@ -44,11 +45,11 @@ private URI createUri(String url) throws ElasticsearchStatusException { } @Override - public HttpRequest createHttpRequest() { + public HttpRequestBase createHttpRequestBase() { var httpGet = new HttpGet(uri); traceContextHandler.propagateTraceContext(httpGet); - return new HttpRequest(httpGet, getInferenceEntityId()); + return httpGet; } public TraceContext getTraceContext() { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceRequest.java index 03eec913a265f..cd152751499b4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceRequest.java @@ -7,6 +7,30 @@ package org.elasticsearch.xpack.inference.external.request.elastic; +import org.apache.http.client.methods.HttpRequestBase; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.xpack.inference.external.request.HttpRequest; import org.elasticsearch.xpack.inference.external.request.Request; -public interface ElasticInferenceServiceRequest extends Request {} +public abstract class ElasticInferenceServiceRequest implements Request { + + private final String productOrigin; + + public ElasticInferenceServiceRequest(String productOrigin) { + this.productOrigin = productOrigin; + } + + public String getProductOrigin() { + return productOrigin; + } + + @Override + public final HttpRequest createHttpRequest() { + HttpRequestBase request = createHttpRequestBase(); + // TODO: consider moving tracing here, too + request.setHeader(Task.X_ELASTIC_PRODUCT_ORIGIN_HTTP_HEADER, productOrigin); + return new HttpRequest(request, getInferenceEntityId()); + } + + protected abstract HttpRequestBase createHttpRequestBase(); +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceSparseEmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceSparseEmbeddingsRequest.java index 18fc7d9f8c32d..af44a6379f961 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceSparseEmbeddingsRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceSparseEmbeddingsRequest.java @@ -9,13 +9,13 @@ import org.apache.http.HttpHeaders; import org.apache.http.client.methods.HttpPost; +import org.apache.http.client.methods.HttpRequestBase; import org.apache.http.entity.ByteArrayEntity; import org.apache.http.message.BasicHeader; import org.elasticsearch.common.Strings; import org.elasticsearch.inference.InputType; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.inference.common.Truncator; -import org.elasticsearch.xpack.inference.external.request.HttpRequest; import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsModel; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceUsageContext; @@ -26,7 +26,7 @@ import java.nio.charset.StandardCharsets; import java.util.Objects; -public class ElasticInferenceServiceSparseEmbeddingsRequest implements ElasticInferenceServiceRequest { +public class ElasticInferenceServiceSparseEmbeddingsRequest extends ElasticInferenceServiceRequest { private final URI uri; private final ElasticInferenceServiceSparseEmbeddingsModel model; @@ -40,8 +40,10 @@ public ElasticInferenceServiceSparseEmbeddingsRequest( Truncator.TruncationResult truncationResult, ElasticInferenceServiceSparseEmbeddingsModel model, TraceContext traceContext, + String productOrigin, InputType inputType ) { + super(productOrigin); this.truncator = truncator; this.truncationResult = truncationResult; this.model = Objects.requireNonNull(model); @@ -51,7 +53,7 @@ public ElasticInferenceServiceSparseEmbeddingsRequest( } @Override - public HttpRequest createHttpRequest() { + public HttpRequestBase createHttpRequestBase() { var httpPost = new HttpPost(uri); var usageContext = inputTypeToUsageContext(inputType); var requestEntity = Strings.toString( @@ -68,7 +70,7 @@ public HttpRequest createHttpRequest() { traceContextHandler.propagateTraceContext(httpPost); httpPost.setHeader(new BasicHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType())); - return new HttpRequest(httpPost, getInferenceEntityId()); + return httpPost; } public TraceContext getTraceContext() { @@ -93,6 +95,7 @@ public Request truncate() { truncatedInput, model, traceContextHandler.traceContext(), + getProductOrigin(), inputType ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceUnifiedChatCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceUnifiedChatCompletionRequest.java index 112ead7057933..6610b1f38a4dc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceUnifiedChatCompletionRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceUnifiedChatCompletionRequest.java @@ -9,12 +9,12 @@ import org.apache.http.HttpHeaders; import org.apache.http.client.methods.HttpPost; +import org.apache.http.client.methods.HttpRequestBase; import org.apache.http.entity.ByteArrayEntity; import org.apache.http.message.BasicHeader; import org.elasticsearch.common.Strings; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; -import org.elasticsearch.xpack.inference.external.request.HttpRequest; import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel; import org.elasticsearch.xpack.inference.telemetry.TraceContext; @@ -24,7 +24,7 @@ import java.nio.charset.StandardCharsets; import java.util.Objects; -public class ElasticInferenceServiceUnifiedChatCompletionRequest implements Request { +public class ElasticInferenceServiceUnifiedChatCompletionRequest extends ElasticInferenceServiceRequest { private final ElasticInferenceServiceCompletionModel model; private final UnifiedChatInput unifiedChatInput; @@ -33,15 +33,17 @@ public class ElasticInferenceServiceUnifiedChatCompletionRequest implements Requ public ElasticInferenceServiceUnifiedChatCompletionRequest( UnifiedChatInput unifiedChatInput, ElasticInferenceServiceCompletionModel model, - TraceContext traceContext + TraceContext traceContext, + String productOrigin ) { + super(productOrigin); this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput); this.model = Objects.requireNonNull(model); this.traceContextHandler = new TraceContextHandler(traceContext); } @Override - public HttpRequest createHttpRequest() { + public HttpRequestBase createHttpRequestBase() { var httpPost = new HttpPost(model.uri()); var requestEntity = Strings.toString( new ElasticInferenceServiceUnifiedChatCompletionRequestEntity(unifiedChatInput, model.getServiceSettings().modelId()) @@ -53,7 +55,7 @@ public HttpRequest createHttpRequest() { traceContextHandler.propagateTraceContext(httpPost); httpPost.setHeader(new BasicHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType())); - return new HttpRequest(httpPost, getInferenceEntityId()); + return httpPost; } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandler.java index 4061e78c31dc4..f266d6fb9c485 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandler.java @@ -108,7 +108,8 @@ public void getAuthorization(ActionListener listener = new PlainActionFuture<>(); - var request = new ElasticInferenceServiceAuthorizationRequest(getUrl(webServer), new TraceContext("", "")); + var request = new ElasticInferenceServiceAuthorizationRequest( + getUrl(webServer), + new TraceContext("", ""), + randomAlphaOfLength(10) + ); var responseHandler = new ElasticInferenceServiceResponseHandler( String.format(Locale.ROOT, "%s sparse embeddings", ELASTIC_INFERENCE_SERVICE_IDENTIFIER), ElasticInferenceServiceAuthorizationResponseEntity::fromResponse diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceAuthorizationRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceAuthorizationRequestTests.java index 66819e10c55ba..bc79881505167 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceAuthorizationRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceAuthorizationRequestTests.java @@ -30,7 +30,7 @@ public void testCreateUriThrowsForInvalidBaseUrl() { ElasticsearchStatusException exception = assertThrows( ElasticsearchStatusException.class, - () -> new ElasticInferenceServiceAuthorizationRequest(invalidUrl, traceContext) + () -> new ElasticInferenceServiceAuthorizationRequest(invalidUrl, traceContext, randomAlphaOfLength(10)) ); assertThat(exception.status(), is(RestStatus.BAD_REQUEST)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceRequestTests.java new file mode 100644 index 0000000000000..49be3a6d43853 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceRequestTests.java @@ -0,0 +1,61 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.elastic; + +import org.apache.http.client.methods.HttpGet; +import org.apache.http.client.methods.HttpRequestBase; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.external.request.Request; + +import java.net.URI; + +import static org.hamcrest.Matchers.equalTo; + +public class ElasticInferenceServiceRequestTests extends ESTestCase { + + public void testElasticInferenceServiceRequestSubclasses_Decorate_HttpRequest_WithProductOrigin() { + var productOrigin = "elastic"; + var elasticInferenceServiceRequestWrapper = getDummyElasticInferenceServiceRequest(productOrigin); + var httpRequest = elasticInferenceServiceRequestWrapper.createHttpRequest(); + var productOriginHeader = httpRequest.httpRequestBase().getFirstHeader(Task.X_ELASTIC_PRODUCT_ORIGIN_HTTP_HEADER); + + // Make sure this header only exists one time + assertThat(httpRequest.httpRequestBase().getHeaders(Task.X_ELASTIC_PRODUCT_ORIGIN_HTTP_HEADER).length, equalTo(1)); + assertThat(productOriginHeader.getValue(), equalTo(productOrigin)); + } + + private static ElasticInferenceServiceRequest getDummyElasticInferenceServiceRequest(String productOrigin) { + return new ElasticInferenceServiceRequest(productOrigin) { + @Override + protected HttpRequestBase createHttpRequestBase() { + return new HttpGet("http://localhost:8080"); + } + + @Override + public URI getURI() { + return null; + } + + @Override + public Request truncate() { + return null; + } + + @Override + public boolean[] getTruncationInfo() { + return new boolean[0]; + } + + @Override + public String getInferenceEntityId() { + return ""; + } + }; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceSparseEmbeddingsRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceSparseEmbeddingsRequestTests.java index 9211b55236b10..d28a05d356ff8 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceSparseEmbeddingsRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceSparseEmbeddingsRequestTests.java @@ -123,6 +123,7 @@ public ElasticInferenceServiceSparseEmbeddingsRequest createRequest(String url, new Truncator.TruncationResult(List.of(input), new boolean[] { false }), embeddingsModel, new TraceContext(randomAlphaOfLength(10), randomAlphaOfLength(10)), + randomAlphaOfLength(10), inputType ); } From 82d6754202699e1213f1aac0daaa3a83e1d65839 Mon Sep 17 00:00:00 2001 From: Tim Grein Date: Fri, 21 Feb 2025 14:12:07 +0100 Subject: [PATCH 2/2] Correct comment --- .../request/elastic/ElasticInferenceServiceRequestTests.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceRequestTests.java index 49be3a6d43853..c5dd19a045390 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceRequestTests.java @@ -25,7 +25,7 @@ public void testElasticInferenceServiceRequestSubclasses_Decorate_HttpRequest_Wi var httpRequest = elasticInferenceServiceRequestWrapper.createHttpRequest(); var productOriginHeader = httpRequest.httpRequestBase().getFirstHeader(Task.X_ELASTIC_PRODUCT_ORIGIN_HTTP_HEADER); - // Make sure this header only exists one time + // Make sure this header only exists once assertThat(httpRequest.httpRequestBase().getHeaders(Task.X_ELASTIC_PRODUCT_ORIGIN_HTTP_HEADER).length, equalTo(1)); assertThat(productOriginHeader.getValue(), equalTo(productOrigin)); }