Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.xpack.inference.common.Truncator;
import org.elasticsearch.xpack.inference.external.elastic.ElasticInferenceServiceResponseHandler;
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
Expand Down Expand Up @@ -43,6 +44,8 @@ public class ElasticInferenceServiceSparseEmbeddingsRequestManager extends Elast

private final InputType inputType;

private final String productOrigin;

private static ResponseHandler createSparseEmbeddingsHandler() {
return new ElasticInferenceServiceResponseHandler(
String.format(Locale.ROOT, "%s sparse embeddings", ELASTIC_INFERENCE_SERVICE_IDENTIFIER),
Expand All @@ -60,6 +63,7 @@ public ElasticInferenceServiceSparseEmbeddingsRequestManager(
this.model = model;
this.truncator = serviceComponents.truncator();
this.traceContext = traceContext;
this.productOrigin = serviceComponents.threadPool().getThreadContext().getHeader(Task.X_ELASTIC_PRODUCT_ORIGIN_HTTP_HEADER);
this.inputType = inputType;
}

Expand All @@ -78,6 +82,7 @@ public void execute(
truncatedInput,
model,
traceContext,
productOrigin,
inputType
);
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.external.elastic.ElasticInferenceServiceUnifiedChatCompletionResponseHandler;
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
Expand Down Expand Up @@ -43,6 +44,7 @@ public static ElasticInferenceServiceUnifiedCompletionRequestManager of(

private final ElasticInferenceServiceCompletionModel model;
private final TraceContext traceContext;
private final String productOrigin;

private ElasticInferenceServiceUnifiedCompletionRequestManager(
ElasticInferenceServiceCompletionModel model,
Expand All @@ -52,6 +54,7 @@ private ElasticInferenceServiceUnifiedCompletionRequestManager(
super(threadPool, model);
this.model = model;
this.traceContext = traceContext;
this.productOrigin = threadPool.getThreadContext().getHeader(Task.X_ELASTIC_PRODUCT_ORIGIN_HTTP_HEADER);
}

@Override
Expand All @@ -65,7 +68,8 @@ public void execute(
ElasticInferenceServiceUnifiedChatCompletionRequest request = new ElasticInferenceServiceUnifiedChatCompletionRequest(
inferenceInputs.castTo(UnifiedChatInput.class),
model,
traceContext
traceContext,
productOrigin
);

execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
package org.elasticsearch.xpack.inference.external.request.elastic;

import org.apache.http.client.methods.HttpGet;
import org.apache.http.client.methods.HttpRequestBase;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService;
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
Expand All @@ -20,12 +20,13 @@
import java.net.URISyntaxException;
import java.util.Objects;

public class ElasticInferenceServiceAuthorizationRequest implements ElasticInferenceServiceRequest {
public class ElasticInferenceServiceAuthorizationRequest extends ElasticInferenceServiceRequest {

private final URI uri;
private final TraceContextHandler traceContextHandler;

public ElasticInferenceServiceAuthorizationRequest(String url, TraceContext traceContext) {
public ElasticInferenceServiceAuthorizationRequest(String url, TraceContext traceContext, String productOrigin) {
super(productOrigin);
this.uri = createUri(Objects.requireNonNull(url));
this.traceContextHandler = new TraceContextHandler(traceContext);
}
Expand All @@ -44,11 +45,11 @@ private URI createUri(String url) throws ElasticsearchStatusException {
}

@Override
public HttpRequest createHttpRequest() {
public HttpRequestBase createHttpRequestBase() {
var httpGet = new HttpGet(uri);
traceContextHandler.propagateTraceContext(httpGet);

return new HttpRequest(httpGet, getInferenceEntityId());
return httpGet;
}

public TraceContext getTraceContext() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,30 @@

package org.elasticsearch.xpack.inference.external.request.elastic;

import org.apache.http.client.methods.HttpRequestBase;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
import org.elasticsearch.xpack.inference.external.request.Request;

public interface ElasticInferenceServiceRequest extends Request {}
public abstract class ElasticInferenceServiceRequest implements Request {

private final String productOrigin;

public ElasticInferenceServiceRequest(String productOrigin) {
this.productOrigin = productOrigin;
}

public String getProductOrigin() {
return productOrigin;
}

@Override
public final HttpRequest createHttpRequest() {
HttpRequestBase request = createHttpRequestBase();
// TODO: consider moving tracing here, too
request.setHeader(Task.X_ELASTIC_PRODUCT_ORIGIN_HTTP_HEADER, productOrigin);
return new HttpRequest(request, getInferenceEntityId());
}

protected abstract HttpRequestBase createHttpRequestBase();
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@

import org.apache.http.HttpHeaders;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.client.methods.HttpRequestBase;
import org.apache.http.entity.ByteArrayEntity;
import org.apache.http.message.BasicHeader;
import org.elasticsearch.common.Strings;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.common.Truncator;
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceUsageContext;
Expand All @@ -26,7 +26,7 @@
import java.nio.charset.StandardCharsets;
import java.util.Objects;

public class ElasticInferenceServiceSparseEmbeddingsRequest implements ElasticInferenceServiceRequest {
public class ElasticInferenceServiceSparseEmbeddingsRequest extends ElasticInferenceServiceRequest {

private final URI uri;
private final ElasticInferenceServiceSparseEmbeddingsModel model;
Expand All @@ -40,8 +40,10 @@ public ElasticInferenceServiceSparseEmbeddingsRequest(
Truncator.TruncationResult truncationResult,
ElasticInferenceServiceSparseEmbeddingsModel model,
TraceContext traceContext,
String productOrigin,
InputType inputType
) {
super(productOrigin);
this.truncator = truncator;
this.truncationResult = truncationResult;
this.model = Objects.requireNonNull(model);
Expand All @@ -51,7 +53,7 @@ public ElasticInferenceServiceSparseEmbeddingsRequest(
}

@Override
public HttpRequest createHttpRequest() {
public HttpRequestBase createHttpRequestBase() {
var httpPost = new HttpPost(uri);
var usageContext = inputTypeToUsageContext(inputType);
var requestEntity = Strings.toString(
Expand All @@ -68,7 +70,7 @@ public HttpRequest createHttpRequest() {
traceContextHandler.propagateTraceContext(httpPost);
httpPost.setHeader(new BasicHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()));

return new HttpRequest(httpPost, getInferenceEntityId());
return httpPost;
}

public TraceContext getTraceContext() {
Expand All @@ -93,6 +95,7 @@ public Request truncate() {
truncatedInput,
model,
traceContextHandler.traceContext(),
getProductOrigin(),
inputType
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@

import org.apache.http.HttpHeaders;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.client.methods.HttpRequestBase;
import org.apache.http.entity.ByteArrayEntity;
import org.apache.http.message.BasicHeader;
import org.elasticsearch.common.Strings;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel;
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
Expand All @@ -24,7 +24,7 @@
import java.nio.charset.StandardCharsets;
import java.util.Objects;

public class ElasticInferenceServiceUnifiedChatCompletionRequest implements Request {
public class ElasticInferenceServiceUnifiedChatCompletionRequest extends ElasticInferenceServiceRequest {

private final ElasticInferenceServiceCompletionModel model;
private final UnifiedChatInput unifiedChatInput;
Expand All @@ -33,15 +33,17 @@ public class ElasticInferenceServiceUnifiedChatCompletionRequest implements Requ
public ElasticInferenceServiceUnifiedChatCompletionRequest(
UnifiedChatInput unifiedChatInput,
ElasticInferenceServiceCompletionModel model,
TraceContext traceContext
TraceContext traceContext,
String productOrigin
) {
super(productOrigin);
this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput);
this.model = Objects.requireNonNull(model);
this.traceContextHandler = new TraceContextHandler(traceContext);
}

@Override
public HttpRequest createHttpRequest() {
public HttpRequestBase createHttpRequestBase() {
var httpPost = new HttpPost(model.uri());
var requestEntity = Strings.toString(
new ElasticInferenceServiceUnifiedChatCompletionRequestEntity(unifiedChatInput, model.getServiceSettings().modelId())
Expand All @@ -53,7 +55,7 @@ public HttpRequest createHttpRequest() {
traceContextHandler.propagateTraceContext(httpPost);
httpPost.setHeader(new BasicHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()));

return new HttpRequest(httpPost, getInferenceEntityId());
return httpPost;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,8 @@ public void getAuthorization(ActionListener<ElasticInferenceServiceAuthorization
requestCompleteLatch.countDown();
});

var request = new ElasticInferenceServiceAuthorizationRequest(baseUrl, getCurrentTraceInfo());
var productOrigin = threadPool.getThreadContext().getHeader(Task.X_ELASTIC_PRODUCT_ORIGIN_HTTP_HEADER);
var request = new ElasticInferenceServiceAuthorizationRequest(baseUrl, getCurrentTraceInfo(), productOrigin);

sender.sendWithoutQueuing(logger, request, AUTH_RESPONSE_HANDLER, DEFAULT_AUTH_TIMEOUT, newListener);
} catch (Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,11 @@ public void testSendWithoutQueuing_SendsRequestAndReceivesResponse() throws Exce
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));

PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
var request = new ElasticInferenceServiceAuthorizationRequest(getUrl(webServer), new TraceContext("", ""));
var request = new ElasticInferenceServiceAuthorizationRequest(
getUrl(webServer),
new TraceContext("", ""),
randomAlphaOfLength(10)
);
var responseHandler = new ElasticInferenceServiceResponseHandler(
String.format(Locale.ROOT, "%s sparse embeddings", ELASTIC_INFERENCE_SERVICE_IDENTIFIER),
ElasticInferenceServiceAuthorizationResponseEntity::fromResponse
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ public void testCreateUriThrowsForInvalidBaseUrl() {

ElasticsearchStatusException exception = assertThrows(
ElasticsearchStatusException.class,
() -> new ElasticInferenceServiceAuthorizationRequest(invalidUrl, traceContext)
() -> new ElasticInferenceServiceAuthorizationRequest(invalidUrl, traceContext, randomAlphaOfLength(10))
);

assertThat(exception.status(), is(RestStatus.BAD_REQUEST));
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.request.elastic;

import org.apache.http.client.methods.HttpGet;
import org.apache.http.client.methods.HttpRequestBase;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.inference.external.request.Request;

import java.net.URI;

import static org.hamcrest.Matchers.equalTo;

public class ElasticInferenceServiceRequestTests extends ESTestCase {

public void testElasticInferenceServiceRequestSubclasses_Decorate_HttpRequest_WithProductOrigin() {
var productOrigin = "elastic";
var elasticInferenceServiceRequestWrapper = getDummyElasticInferenceServiceRequest(productOrigin);
var httpRequest = elasticInferenceServiceRequestWrapper.createHttpRequest();
var productOriginHeader = httpRequest.httpRequestBase().getFirstHeader(Task.X_ELASTIC_PRODUCT_ORIGIN_HTTP_HEADER);

// Make sure this header only exists once
assertThat(httpRequest.httpRequestBase().getHeaders(Task.X_ELASTIC_PRODUCT_ORIGIN_HTTP_HEADER).length, equalTo(1));
assertThat(productOriginHeader.getValue(), equalTo(productOrigin));
}

private static ElasticInferenceServiceRequest getDummyElasticInferenceServiceRequest(String productOrigin) {
return new ElasticInferenceServiceRequest(productOrigin) {
@Override
protected HttpRequestBase createHttpRequestBase() {
return new HttpGet("http://localhost:8080");
}

@Override
public URI getURI() {
return null;
}

@Override
public Request truncate() {
return null;
}

@Override
public boolean[] getTruncationInfo() {
return new boolean[0];
}

@Override
public String getInferenceEntityId() {
return "";
}
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ public ElasticInferenceServiceSparseEmbeddingsRequest createRequest(String url,
new Truncator.TruncationResult(List.of(input), new boolean[] { false }),
embeddingsModel,
new TraceContext(randomAlphaOfLength(10), randomAlphaOfLength(10)),
randomAlphaOfLength(10),
inputType
);
}
Expand Down