Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/124025.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 124025
summary: "[Inference API] Propagate product use case http header to EIS"
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ static TransportVersion def(int id) {
public static final TransportVersion VOYAGE_AI_INTEGRATION_ADDED_BACKPORT_8_X = def(8_841_0_05);
public static final TransportVersion JINA_AI_EMBEDDING_TYPE_SUPPORT_ADDED_BACKPORT_8_19 = def(8_841_0_06);
public static final TransportVersion RETRY_ILM_ASYNC_ACTION_REQUIRE_ERROR_8_19 = def(8_841_0_07);
public static final TransportVersion INFERENCE_CONTEXT_8_X = def(8_841_0_08);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.inference;

import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.Objects;

/**
 * Record for storing context alongside an inference request, typically used for metadata.
 * This is mainly used to pass along inference context on the transport layer without relying on
 * {@link org.elasticsearch.common.util.concurrent.ThreadContext}, which depending on the internal
 * {@link org.elasticsearch.client.internal.Client} throws away parts of the context, when passed along the transport layer.
 *
 * @param productUseCase - for now mainly used by Elastic Inference Service; never null (may be empty)
 */
public record InferenceContext(String productUseCase) implements Writeable, ToXContent {

    /** Shared sentinel used when no product use case is available (e.g. reads from older nodes). */
    public static final InferenceContext EMPTY_INSTANCE = new InferenceContext("");

    public InferenceContext {
        // The wire format writes the string unconditionally, so null can never be allowed.
        Objects.requireNonNull(productUseCase);
    }

    /**
     * Deserializing constructor; counterpart of {@link #writeTo(StreamOutput)}.
     */
    public InferenceContext(StreamInput in) throws IOException {
        this(in.readString());
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeString(productUseCase);
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        builder.field("product_use_case", productUseCase);
        builder.endObject();
        return builder;
    }

    // NOTE: the original hand-written equals()/hashCode() overrides were removed —
    // records auto-generate component-wise Objects.equals-based implementations
    // that are behaviorally identical for this single non-null String component.
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.core.inference.InferenceContext;

import java.io.IOException;
import java.util.Objects;

/**
* Base class for inference action requests. Tracks request routing state to prevent potential routing loops
Expand All @@ -23,8 +25,11 @@ public abstract class BaseInferenceActionRequest extends ActionRequest {

private boolean hasBeenRerouted;

public BaseInferenceActionRequest() {
private final InferenceContext context;

public BaseInferenceActionRequest(InferenceContext context) {
super();
this.context = context;
}

public BaseInferenceActionRequest(StreamInput in) throws IOException {
Expand All @@ -36,6 +41,12 @@ public BaseInferenceActionRequest(StreamInput in) throws IOException {
// a version pre-node-local-rate-limiting as already rerouted to maintain pre-node-local-rate-limiting behavior.
this.hasBeenRerouted = true;
}

if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_CONTEXT_8_X)) {
this.context = new InferenceContext(in);
} else {
this.context = InferenceContext.EMPTY_INSTANCE;
}
}

public abstract boolean isStreaming();
Expand All @@ -52,11 +63,32 @@ public boolean hasBeenRerouted() {
return hasBeenRerouted;
}

/**
 * Returns the {@link InferenceContext} carried with this request.
 * When the request was read from a node older than {@code INFERENCE_CONTEXT_8_X}
 * this is {@code InferenceContext.EMPTY_INSTANCE}.
 */
public InferenceContext getContext() {
return context;
}

/**
 * Serializes this request. Both the rerouted flag and the inference context are
 * version-gated: each is written only when the receiving node's transport version
 * understands it, mirroring the read side in the StreamInput constructor.
 * The write order (flag first, then context) must never change.
 */
@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
// Only nodes on/after the adaptive-rate-limiting version read this flag.
if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_REQUEST_ADAPTIVE_RATE_LIMITING)) {
out.writeBoolean(hasBeenRerouted);
}

// Older nodes do not know about InferenceContext; skip it for them
// (their reader falls back to InferenceContext.EMPTY_INSTANCE).
if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_CONTEXT_8_X)) {
context.writeTo(out);
}
}

/**
 * Two base requests are equal when they agree on the rerouted flag and on the
 * attached inference context; subclasses combine this with their own fields.
 */
@Override
public boolean equals(Object o) {
    if (o == this) {
        return true;
    }
    if (o == null || o.getClass() != getClass()) {
        return false;
    }
    BaseInferenceActionRequest other = (BaseInferenceActionRequest) o;
    return this.hasBeenRerouted == other.hasBeenRerouted && Objects.equals(this.context, other.context);
}

@Override
public int hashCode() {
    // Manually unrolled equivalent of Objects.hash(hasBeenRerouted, context):
    // same resulting value, without the varargs Object[] allocation.
    int result = 31 + Boolean.hashCode(hasBeenRerouted);
    return result * 31 + Objects.hashCode(context);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.inference.InferenceContext;
import org.elasticsearch.xpack.core.inference.results.LegacyTextEmbeddingResults;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
Expand Down Expand Up @@ -74,12 +75,14 @@ public static class Request extends BaseInferenceActionRequest {
InputType.UNSPECIFIED
);

public static Builder parseRequest(String inferenceEntityId, TaskType taskType, XContentParser parser) throws IOException {
public static Builder parseRequest(String inferenceEntityId, TaskType taskType, InferenceContext context, XContentParser parser)
throws IOException {
Request.Builder builder = PARSER.apply(parser, null);
builder.setInferenceEntityId(inferenceEntityId);
builder.setTaskType(taskType);
// For rest requests we won't know what the input type is
builder.setInputType(InputType.UNSPECIFIED);
builder.setContext(context);
return builder;
}

Expand All @@ -102,6 +105,31 @@ public Request(
TimeValue inferenceTimeout,
boolean stream
) {
this(
taskType,
inferenceEntityId,
query,
input,
taskSettings,
inputType,
inferenceTimeout,
stream,
InferenceContext.EMPTY_INSTANCE
);
}

public Request(
TaskType taskType,
String inferenceEntityId,
String query,
List<String> input,
Map<String, Object> taskSettings,
InputType inputType,
TimeValue inferenceTimeout,
boolean stream,
InferenceContext context
) {
super(context);
this.taskType = taskType;
this.inferenceEntityId = inferenceEntityId;
this.query = query;
Expand Down Expand Up @@ -241,19 +269,31 @@ static InputType getInputTypeToWrite(InputType inputType, TransportVersion versi
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
if (super.equals(o) == false) return false;
Request request = (Request) o;
return taskType == request.taskType
return stream == request.stream
&& taskType == request.taskType
&& Objects.equals(inferenceEntityId, request.inferenceEntityId)
&& Objects.equals(query, request.query)
&& Objects.equals(input, request.input)
&& Objects.equals(taskSettings, request.taskSettings)
&& Objects.equals(inputType, request.inputType)
&& Objects.equals(query, request.query)
&& inputType == request.inputType
&& Objects.equals(inferenceTimeout, request.inferenceTimeout);
}

@Override
public int hashCode() {
return Objects.hash(taskType, inferenceEntityId, input, taskSettings, inputType, query, inferenceTimeout);
return Objects.hash(
super.hashCode(),
taskType,
inferenceEntityId,
query,
input,
taskSettings,
inputType,
inferenceTimeout,
stream
);
}

public static class Builder {
Expand All @@ -266,6 +306,7 @@ public static class Builder {
private String query;
private TimeValue timeout = DEFAULT_TIMEOUT;
private boolean stream = false;
private InferenceContext context;

private Builder() {}

Expand Down Expand Up @@ -313,8 +354,13 @@ public Builder setStream(boolean stream) {
return this;
}

/**
 * Sets the {@link InferenceContext} to be forwarded with the built request
 * (e.g. the product use case header propagated to the Elastic Inference Service).
 */
public Builder setContext(InferenceContext context) {
this.context = context;
return this;
}

/**
 * Builds the request from the collected builder state.
 * The context is optional on the builder and has no field default, so fall back
 * to {@link InferenceContext#EMPTY_INSTANCE} rather than propagating null into
 * the request (downstream serialization assumes a non-null context).
 */
public Request build() {
    InferenceContext effectiveContext = Objects.requireNonNullElse(context, InferenceContext.EMPTY_INSTANCE);
    return new Request(taskType, inferenceEntityId, query, input, taskSettings, inputType, timeout, stream, effectiveContext);
}
}

Expand All @@ -333,6 +379,8 @@ public String toString() {
+ this.getInputType()
+ ", timeout="
+ this.getInferenceTimeout()
+ ", context="
+ this.getContext()
+ ")";
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

package org.elasticsearch.xpack.core.inference.action;

import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.ActionType;
Expand All @@ -17,6 +18,7 @@
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.InferenceContext;

import java.io.IOException;
import java.util.Objects;
Expand Down Expand Up @@ -44,21 +46,24 @@ public static class Request extends ActionRequest {
private final XContentType contentType;
private final TimeValue timeout;
private final boolean stream;
private final InferenceContext context;

public Request(
TaskType taskType,
String inferenceEntityId,
BytesReference content,
XContentType contentType,
TimeValue timeout,
boolean stream
boolean stream,
InferenceContext context
) {
this.taskType = taskType;
this.inferenceEntityId = inferenceEntityId;
this.content = content;
this.contentType = contentType;
this.timeout = timeout;
this.stream = stream;
this.context = context;
}

public Request(StreamInput in) throws IOException {
Expand All @@ -71,6 +76,12 @@ public Request(StreamInput in) throws IOException {

// streaming is not supported yet for transport traffic
this.stream = false;

if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_CONTEXT_8_X)) {
this.context = new InferenceContext(in);
} else {
this.context = InferenceContext.EMPTY_INSTANCE;
}
}

public TaskType getTaskType() {
Expand All @@ -97,6 +108,10 @@ public boolean isStreaming() {
return stream;
}

/**
 * Returns the inference context attached to this request;
 * {@code InferenceContext.EMPTY_INSTANCE} when deserialized from a node
 * older than {@code INFERENCE_CONTEXT_8_X}.
 */
public InferenceContext getContext() {
return context;
}

@Override
public ActionRequestValidationException validate() {
return null;
Expand All @@ -110,6 +125,10 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeBytesReference(content);
XContentHelper.writeTo(out, contentType);
out.writeTimeValue(timeout);

if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_CONTEXT_8_X)) {
context.writeTo(out);
}
}

@Override
Expand All @@ -122,12 +141,13 @@ public boolean equals(Object o) {
&& Objects.equals(content, request.content)
&& contentType == request.contentType
&& timeout == request.timeout
&& stream == request.stream;
&& stream == request.stream
&& context == request.context;
}

@Override
public int hashCode() {
return Objects.hash(taskType, inferenceEntityId, content, contentType, timeout, stream);
return Objects.hash(taskType, inferenceEntityId, content, contentType, timeout, stream, context);
}
}
}
Loading