Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.xpack.inference.common.Truncator;
import org.elasticsearch.xpack.inference.external.elastic.ElasticInferenceServiceResponseHandler;
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
Expand Down Expand Up @@ -43,6 +44,8 @@ public class ElasticInferenceServiceSparseEmbeddingsRequestManager extends Elast

private final InputType inputType;

private final String productOrigin;

private static ResponseHandler createSparseEmbeddingsHandler() {
return new ElasticInferenceServiceResponseHandler(
String.format(Locale.ROOT, "%s sparse embeddings", ELASTIC_INFERENCE_SERVICE_IDENTIFIER),
Expand All @@ -60,6 +63,7 @@ public ElasticInferenceServiceSparseEmbeddingsRequestManager(
this.model = model;
this.truncator = serviceComponents.truncator();
this.traceContext = traceContext;
this.productOrigin = serviceComponents.threadPool().getThreadContext().getHeader(Task.X_ELASTIC_PRODUCT_ORIGIN_HTTP_HEADER);
this.inputType = inputType;
}

Expand All @@ -78,6 +82,7 @@ public void execute(
truncatedInput,
model,
traceContext,
productOrigin,
inputType
);
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.external.elastic.ElasticInferenceServiceUnifiedChatCompletionResponseHandler;
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
Expand Down Expand Up @@ -43,6 +44,7 @@ public static ElasticInferenceServiceUnifiedCompletionRequestManager of(

private final ElasticInferenceServiceCompletionModel model;
private final TraceContext traceContext;
private final String productOrigin;

private ElasticInferenceServiceUnifiedCompletionRequestManager(
ElasticInferenceServiceCompletionModel model,
Expand All @@ -52,6 +54,7 @@ private ElasticInferenceServiceUnifiedCompletionRequestManager(
super(threadPool, model);
this.model = model;
this.traceContext = traceContext;
this.productOrigin = threadPool.getThreadContext().getHeader(Task.X_ELASTIC_PRODUCT_ORIGIN_HTTP_HEADER);
}

@Override
Expand All @@ -65,7 +68,8 @@ public void execute(
ElasticInferenceServiceUnifiedChatCompletionRequest request = new ElasticInferenceServiceUnifiedChatCompletionRequest(
inferenceInputs.castTo(UnifiedChatInput.class),
model,
traceContext
traceContext,
productOrigin
);

execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
package org.elasticsearch.xpack.inference.external.request.elastic;

import org.apache.http.client.methods.HttpGet;
import org.apache.http.client.methods.HttpRequestBase;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService;
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
Expand All @@ -20,12 +20,13 @@
import java.net.URISyntaxException;
import java.util.Objects;

public class ElasticInferenceServiceAuthorizationRequest implements ElasticInferenceServiceRequest {
public class ElasticInferenceServiceAuthorizationRequest extends ElasticInferenceServiceRequest {

private final URI uri;
private final TraceContextHandler traceContextHandler;

public ElasticInferenceServiceAuthorizationRequest(String url, TraceContext traceContext) {
public ElasticInferenceServiceAuthorizationRequest(String url, TraceContext traceContext, String productOrigin) {
super(productOrigin);
this.uri = createUri(Objects.requireNonNull(url));
this.traceContextHandler = new TraceContextHandler(traceContext);
}
Expand All @@ -44,11 +45,11 @@ private URI createUri(String url) throws ElasticsearchStatusException {
}

@Override
public HttpRequest createHttpRequest() {
public HttpRequestBase createHttpRequestBase() {
var httpGet = new HttpGet(uri);
traceContextHandler.propagateTraceContext(httpGet);

return new HttpRequest(httpGet, getInferenceEntityId());
return httpGet;
}

public TraceContext getTraceContext() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,30 @@

package org.elasticsearch.xpack.inference.external.request.elastic;

import org.apache.http.client.methods.HttpRequestBase;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
import org.elasticsearch.xpack.inference.external.request.Request;

public interface ElasticInferenceServiceRequest extends Request {}
public abstract class ElasticInferenceServiceRequest implements Request {

private final String productOrigin;

public ElasticInferenceServiceRequest(String productOrigin) {
this.productOrigin = productOrigin;
}

public String getProductOrigin() {
return productOrigin;
}

@Override
public final HttpRequest createHttpRequest() {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Made this final so it's clear that a new subclass of ElasticInferenceServiceRequest should override createHttpRequestBase and not createHttpRequest.

HttpRequestBase request = createHttpRequestBase();
// TODO: consider moving tracing here, too
Copy link
Contributor Author

@timgrein timgrein Feb 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wanted to keep it out of scope for this PR, but I think it also makes sense to move the tracing logic here, so we don't risk forgetting it for new subclasses of ElasticInferenceServiceRequest.

request.setHeader(Task.X_ELASTIC_PRODUCT_ORIGIN_HTTP_HEADER, productOrigin);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is productOrigin null if it isn't present in the inbound header? I can't tell if defaultHeaders would include a default productOrigin, but if it doesn't we may want to either omit the header or send something like "Unknown"

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we don't really need to care as EIS handles the case, if this header is absent and/or empty and sets unknown if so. I would like to simply treat this as a forwarding logic without looking into the value.

return new HttpRequest(request, getInferenceEntityId());
}

protected abstract HttpRequestBase createHttpRequestBase();
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@

import org.apache.http.HttpHeaders;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.client.methods.HttpRequestBase;
import org.apache.http.entity.ByteArrayEntity;
import org.apache.http.message.BasicHeader;
import org.elasticsearch.common.Strings;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.common.Truncator;
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceUsageContext;
Expand All @@ -26,7 +26,7 @@
import java.nio.charset.StandardCharsets;
import java.util.Objects;

public class ElasticInferenceServiceSparseEmbeddingsRequest implements ElasticInferenceServiceRequest {
public class ElasticInferenceServiceSparseEmbeddingsRequest extends ElasticInferenceServiceRequest {

private final URI uri;
private final ElasticInferenceServiceSparseEmbeddingsModel model;
Expand All @@ -40,8 +40,10 @@ public ElasticInferenceServiceSparseEmbeddingsRequest(
Truncator.TruncationResult truncationResult,
ElasticInferenceServiceSparseEmbeddingsModel model,
TraceContext traceContext,
String productOrigin,
InputType inputType
) {
super(productOrigin);
this.truncator = truncator;
this.truncationResult = truncationResult;
this.model = Objects.requireNonNull(model);
Expand All @@ -51,7 +53,7 @@ public ElasticInferenceServiceSparseEmbeddingsRequest(
}

@Override
public HttpRequest createHttpRequest() {
public HttpRequestBase createHttpRequestBase() {
var httpPost = new HttpPost(uri);
var usageContext = inputTypeToUsageContext(inputType);
var requestEntity = Strings.toString(
Expand All @@ -68,7 +70,7 @@ public HttpRequest createHttpRequest() {
traceContextHandler.propagateTraceContext(httpPost);
httpPost.setHeader(new BasicHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()));

return new HttpRequest(httpPost, getInferenceEntityId());
return httpPost;
}

public TraceContext getTraceContext() {
Expand All @@ -93,6 +95,7 @@ public Request truncate() {
truncatedInput,
model,
traceContextHandler.traceContext(),
getProductOrigin(),
inputType
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@

import org.apache.http.HttpHeaders;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.client.methods.HttpRequestBase;
import org.apache.http.entity.ByteArrayEntity;
import org.apache.http.message.BasicHeader;
import org.elasticsearch.common.Strings;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel;
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
Expand All @@ -24,7 +24,7 @@
import java.nio.charset.StandardCharsets;
import java.util.Objects;

public class ElasticInferenceServiceUnifiedChatCompletionRequest implements Request {
public class ElasticInferenceServiceUnifiedChatCompletionRequest extends ElasticInferenceServiceRequest {

private final ElasticInferenceServiceCompletionModel model;
private final UnifiedChatInput unifiedChatInput;
Expand All @@ -33,15 +33,17 @@ public class ElasticInferenceServiceUnifiedChatCompletionRequest implements Requ
public ElasticInferenceServiceUnifiedChatCompletionRequest(
UnifiedChatInput unifiedChatInput,
ElasticInferenceServiceCompletionModel model,
TraceContext traceContext
TraceContext traceContext,
String productOrigin
) {
super(productOrigin);
this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput);
this.model = Objects.requireNonNull(model);
this.traceContextHandler = new TraceContextHandler(traceContext);
}

@Override
public HttpRequest createHttpRequest() {
public HttpRequestBase createHttpRequestBase() {
var httpPost = new HttpPost(model.uri());
var requestEntity = Strings.toString(
new ElasticInferenceServiceUnifiedChatCompletionRequestEntity(unifiedChatInput, model.getServiceSettings().modelId())
Expand All @@ -53,7 +55,7 @@ public HttpRequest createHttpRequest() {
traceContextHandler.propagateTraceContext(httpPost);
httpPost.setHeader(new BasicHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()));

return new HttpRequest(httpPost, getInferenceEntityId());
return httpPost;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,8 @@ public void getAuthorization(ActionListener<ElasticInferenceServiceAuthorization
requestCompleteLatch.countDown();
});

var request = new ElasticInferenceServiceAuthorizationRequest(baseUrl, getCurrentTraceInfo());
var productOrigin = threadPool.getThreadContext().getHeader(Task.X_ELASTIC_PRODUCT_ORIGIN_HTTP_HEADER);
var request = new ElasticInferenceServiceAuthorizationRequest(baseUrl, getCurrentTraceInfo(), productOrigin);

sender.sendWithoutQueuing(logger, request, AUTH_RESPONSE_HANDLER, DEFAULT_TIMEOUT, newListener);
} catch (Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,11 @@ public void testSendWithoutQueuing_SendsRequestAndReceivesResponse() throws Exce
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));

PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
var request = new ElasticInferenceServiceAuthorizationRequest(getUrl(webServer), new TraceContext("", ""));
var request = new ElasticInferenceServiceAuthorizationRequest(
getUrl(webServer),
new TraceContext("", ""),
randomAlphaOfLength(10)
);
var responseHandler = new ElasticInferenceServiceResponseHandler(
String.format(Locale.ROOT, "%s sparse embeddings", ELASTIC_INFERENCE_SERVICE_IDENTIFIER),
ElasticInferenceServiceAuthorizationResponseEntity::fromResponse
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ public void testCreateUriThrowsForInvalidBaseUrl() {

ElasticsearchStatusException exception = assertThrows(
ElasticsearchStatusException.class,
() -> new ElasticInferenceServiceAuthorizationRequest(invalidUrl, traceContext)
() -> new ElasticInferenceServiceAuthorizationRequest(invalidUrl, traceContext, randomAlphaOfLength(10))
);

assertThat(exception.status(), is(RestStatus.BAD_REQUEST));
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.request.elastic;

import org.apache.http.client.methods.HttpGet;
import org.apache.http.client.methods.HttpRequestBase;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.inference.external.request.Request;

import java.net.URI;

import static org.hamcrest.Matchers.equalTo;

public class ElasticInferenceServiceRequestTests extends ESTestCase {

public void testElasticInferenceServiceRequestSubclasses_Decorate_HttpRequest_WithProductOrigin() {
var productOrigin = "elastic";
var elasticInferenceServiceRequestWrapper = getDummyElasticInferenceServiceRequest(productOrigin);
var httpRequest = elasticInferenceServiceRequestWrapper.createHttpRequest();
var productOriginHeader = httpRequest.httpRequestBase().getFirstHeader(Task.X_ELASTIC_PRODUCT_ORIGIN_HTTP_HEADER);

// Make sure this header only exists one time
assertThat(httpRequest.httpRequestBase().getHeaders(Task.X_ELASTIC_PRODUCT_ORIGIN_HTTP_HEADER).length, equalTo(1));
assertThat(productOriginHeader.getValue(), equalTo(productOrigin));
}

private static ElasticInferenceServiceRequest getDummyElasticInferenceServiceRequest(String productOrigin) {
return new ElasticInferenceServiceRequest(productOrigin) {
@Override
protected HttpRequestBase createHttpRequestBase() {
return new HttpGet("http://localhost:8080");
}

@Override
public URI getURI() {
return null;
}

@Override
public Request truncate() {
return null;
}

@Override
public boolean[] getTruncationInfo() {
return new boolean[0];
}

@Override
public String getInferenceEntityId() {
return "";
}
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ public ElasticInferenceServiceSparseEmbeddingsRequest createRequest(String url,
new Truncator.TruncationResult(List.of(input), new boolean[] { false }),
embeddingsModel,
new TraceContext(randomAlphaOfLength(10), randomAlphaOfLength(10)),
randomAlphaOfLength(10),
inputType
);
}
Expand Down