Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/124025.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 124025
summary: "[Inference API] Propagate product use case http header to EIS"
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ static TransportVersion def(int id) {
public static final TransportVersion VOYAGE_AI_INTEGRATION_ADDED_BACKPORT_8_X = def(8_841_0_05);
public static final TransportVersion JINA_AI_EMBEDDING_TYPE_SUPPORT_ADDED_BACKPORT_8_19 = def(8_841_0_06);
public static final TransportVersion RETRY_ILM_ASYNC_ACTION_REQUIRE_ERROR_8_19 = def(8_841_0_07);
public static final TransportVersion INFERENCE_CONTEXT_8_X = def(8_841_0_08);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.inference;

import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.Objects;

/**
 * Record for storing context alongside an inference request, typically used for metadata.
 * <p>
 * The context travels with the request on the transport layer instead of relying on
 * {@link org.elasticsearch.common.util.concurrent.ThreadContext}, which, depending on the internal
 * {@link org.elasticsearch.client.internal.Client}, throws away parts of the context when it is passed
 * along the transport layer.
 * <p>
 * Equality and hashing are the record-generated implementations, i.e. they are based solely on
 * {@code productUseCase}.
 *
 * @param productUseCase the product use case, never {@code null}; for now mainly used by Elastic Inference Service
 */
public record InferenceContext(String productUseCase) implements Writeable, ToXContent {

    /** Context with an empty product use case; the default for requests that carry no context. */
    public static final InferenceContext EMPTY_INSTANCE = new InferenceContext("");

    /**
     * Validates the record components.
     *
     * @throws NullPointerException if {@code productUseCase} is {@code null}
     */
    public InferenceContext {
        Objects.requireNonNull(productUseCase, "productUseCase must not be null");
    }

    /**
     * Reads a context from the stream; the counterpart of {@link #writeTo(StreamOutput)}.
     *
     * @param in the stream to read the product use case from
     * @throws IOException if reading from the stream fails
     */
    public InferenceContext(StreamInput in) throws IOException {
        this(in.readString());
    }

    /**
     * Writes the product use case as a single string.
     *
     * @param out the stream to write to
     * @throws IOException if writing to the stream fails
     */
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeString(productUseCase);
    }

    /**
     * Renders the context as an object with a single {@code product_use_case} field.
     *
     * @return the given {@code builder}
     */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        builder.field("product_use_case", productUseCase);
        builder.endObject();
        return builder;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.core.inference.InferenceContext;

import java.io.IOException;
import java.util.Objects;

/**
* Base class for inference action requests. Tracks request routing state to prevent potential routing loops
Expand All @@ -23,8 +25,11 @@ public abstract class BaseInferenceActionRequest extends ActionRequest {

private boolean hasBeenRerouted;

public BaseInferenceActionRequest() {
private final InferenceContext context;

/**
 * Creates a request that carries the given inference context.
 *
 * @param context the context to pass along with the request; {@code null} is treated as
 *                {@link InferenceContext#EMPTY_INSTANCE} so that serializing the request never
 *                dereferences a missing context
 */
public BaseInferenceActionRequest(InferenceContext context) {
    super();
    // Same fallback the stream constructor uses for older transport versions: a missing context is
    // represented by the empty instance, never by null.
    this.context = Objects.requireNonNullElse(context, InferenceContext.EMPTY_INSTANCE);
}

public BaseInferenceActionRequest(StreamInput in) throws IOException {
Expand All @@ -36,6 +41,12 @@ public BaseInferenceActionRequest(StreamInput in) throws IOException {
// a version pre-node-local-rate-limiting as already rerouted to maintain pre-node-local-rate-limiting behavior.
this.hasBeenRerouted = true;
}

if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_CONTEXT_8_X)) {
this.context = new InferenceContext(in);
} else {
this.context = InferenceContext.EMPTY_INSTANCE;
}
}

public abstract boolean isStreaming();
Expand All @@ -52,11 +63,32 @@ public boolean hasBeenRerouted() {
return hasBeenRerouted;
}

/**
 * Returns the inference context stored on this request.
 *
 * @return the context given at construction time or read from the stream
 */
public InferenceContext getContext() {
return context;
}

/**
 * Serializes the base request state. Each field is only written when the receiving side's transport
 * version knows about it: first the rerouted flag, then the inference context.
 *
 * @param out the stream to write to
 * @throws IOException if writing to the stream fails
 */
@Override
public void writeTo(StreamOutput out) throws IOException {
    super.writeTo(out);

    var targetVersion = out.getTransportVersion();

    boolean supportsRerouteFlag = targetVersion.onOrAfter(TransportVersions.INFERENCE_REQUEST_ADAPTIVE_RATE_LIMITING);
    if (supportsRerouteFlag) {
        out.writeBoolean(hasBeenRerouted);
    }

    boolean supportsContext = targetVersion.onOrAfter(TransportVersions.INFERENCE_CONTEXT_8_X);
    if (supportsContext) {
        context.writeTo(out);
    }
}

/**
 * Two requests are equal when they are of the exact same class and agree on the rerouted flag and
 * on the inference context.
 */
@Override
public boolean equals(Object o) {
    if (this == o) {
        return true;
    }
    if (o == null || getClass() != o.getClass()) {
        return false;
    }
    BaseInferenceActionRequest other = (BaseInferenceActionRequest) o;
    if (hasBeenRerouted != other.hasBeenRerouted) {
        return false;
    }
    return Objects.equals(context, other.context);
}

/**
 * Hash code built from the same two fields that {@code equals} compares: the rerouted flag and the
 * inference context.
 */
@Override
public int hashCode() {
return Objects.hash(hasBeenRerouted, context);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.inference.InferenceContext;
import org.elasticsearch.xpack.core.inference.results.LegacyTextEmbeddingResults;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
Expand Down Expand Up @@ -74,12 +75,14 @@ public static class Request extends BaseInferenceActionRequest {
InputType.UNSPECIFIED
);

public static Builder parseRequest(String inferenceEntityId, TaskType taskType, XContentParser parser) throws IOException {
/**
 * Parses a request body with {@code PARSER} and completes the resulting builder with the values
 * supplied by the caller.
 *
 * @param inferenceEntityId the inference entity id to set on the builder
 * @param taskType          the task type to set on the builder
 * @param context           the inference context to set on the builder
 * @param parser            the parser positioned on the request body
 * @return the populated builder; the input type is always set to {@code InputType.UNSPECIFIED}
 * @throws IOException declared by the signature; presumably raised by parsing - confirm against {@code PARSER}
 */
public static Builder parseRequest(String inferenceEntityId, TaskType taskType, InferenceContext context, XContentParser parser)
throws IOException {
Request.Builder builder = PARSER.apply(parser, null);
builder.setInferenceEntityId(inferenceEntityId);
builder.setTaskType(taskType);
// For rest requests we won't know what the input type is
builder.setInputType(InputType.UNSPECIFIED);
// The context is handed in by the caller rather than read from the parsed body.
builder.setContext(context);
return builder;
}

Expand All @@ -102,6 +105,31 @@ public Request(
TimeValue inferenceTimeout,
boolean stream
) {
this(
taskType,
inferenceEntityId,
query,
input,
taskSettings,
inputType,
inferenceTimeout,
stream,
InferenceContext.EMPTY_INSTANCE
);
}

public Request(
TaskType taskType,
String inferenceEntityId,
String query,
List<String> input,
Map<String, Object> taskSettings,
InputType inputType,
TimeValue inferenceTimeout,
boolean stream,
InferenceContext context
) {
super(context);
this.taskType = taskType;
this.inferenceEntityId = inferenceEntityId;
this.query = query;
Expand Down Expand Up @@ -241,19 +269,31 @@ static InputType getInputTypeToWrite(InputType inputType, TransportVersion versi
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
if (super.equals(o) == false) return false;
Request request = (Request) o;
return taskType == request.taskType
return stream == request.stream
&& taskType == request.taskType
&& Objects.equals(inferenceEntityId, request.inferenceEntityId)
&& Objects.equals(query, request.query)
&& Objects.equals(input, request.input)
&& Objects.equals(taskSettings, request.taskSettings)
&& Objects.equals(inputType, request.inputType)
&& Objects.equals(query, request.query)
&& inputType == request.inputType
&& Objects.equals(inferenceTimeout, request.inferenceTimeout);
}

@Override
public int hashCode() {
return Objects.hash(taskType, inferenceEntityId, input, taskSettings, inputType, query, inferenceTimeout);
return Objects.hash(
super.hashCode(),
taskType,
inferenceEntityId,
query,
input,
taskSettings,
inputType,
inferenceTimeout,
stream
);
}

public static class Builder {
Expand All @@ -266,6 +306,7 @@ public static class Builder {
private String query;
private TimeValue timeout = DEFAULT_TIMEOUT;
private boolean stream = false;
private InferenceContext context;

private Builder() {}

Expand Down Expand Up @@ -313,8 +354,13 @@ public Builder setStream(boolean stream) {
return this;
}

/**
 * Sets the inference context that {@code build()} passes to the request.
 * NOTE(review): the builder field has no initializer, so the context stays {@code null} unless this
 * setter is called - confirm that every caller sets it.
 *
 * @param context the context to attach to the built request
 * @return this builder
 */
public Builder setContext(InferenceContext context) {
this.context = context;
return this;
}

public Request build() {
return new Request(taskType, inferenceEntityId, query, input, taskSettings, inputType, timeout, stream);
return new Request(taskType, inferenceEntityId, query, input, taskSettings, inputType, timeout, stream, context);
}
}

Expand All @@ -333,6 +379,8 @@ public String toString() {
+ this.getInputType()
+ ", timeout="
+ this.getInferenceTimeout()
+ ", context="
+ this.getContext()
+ ")";
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

package org.elasticsearch.xpack.core.inference.action;

import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.ActionType;
Expand All @@ -17,6 +18,7 @@
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.InferenceContext;

import java.io.IOException;
import java.util.Objects;
Expand Down Expand Up @@ -44,21 +46,24 @@ public static class Request extends ActionRequest {
private final XContentType contentType;
private final TimeValue timeout;
private final boolean stream;
private final InferenceContext context;

public Request(
TaskType taskType,
String inferenceEntityId,
BytesReference content,
XContentType contentType,
TimeValue timeout,
boolean stream
boolean stream,
InferenceContext context
) {
this.taskType = taskType;
this.inferenceEntityId = inferenceEntityId;
this.content = content;
this.contentType = contentType;
this.timeout = timeout;
this.stream = stream;
this.context = context;
}

public Request(StreamInput in) throws IOException {
Expand All @@ -71,6 +76,12 @@ public Request(StreamInput in) throws IOException {

// streaming is not supported yet for transport traffic
this.stream = false;

if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_CONTEXT_8_X)) {
this.context = new InferenceContext(in);
} else {
this.context = InferenceContext.EMPTY_INSTANCE;
}
}

public TaskType getTaskType() {
Expand All @@ -97,6 +108,10 @@ public boolean isStreaming() {
return stream;
}

/**
 * Returns the inference context stored on this proxy request.
 *
 * @return the context given at construction time or read from the stream
 */
public InferenceContext getContext() {
return context;
}

@Override
public ActionRequestValidationException validate() {
return null;
Expand All @@ -110,6 +125,10 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeBytesReference(content);
XContentHelper.writeTo(out, contentType);
out.writeTimeValue(timeout);

if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_CONTEXT_8_X)) {
context.writeTo(out);
}
}

@Override
Expand All @@ -122,12 +141,13 @@ public boolean equals(Object o) {
&& Objects.equals(content, request.content)
&& contentType == request.contentType
&& timeout == request.timeout
&& stream == request.stream;
&& stream == request.stream
&& context == request.context;
}

@Override
public int hashCode() {
return Objects.hash(taskType, inferenceEntityId, content, contentType, timeout, stream);
return Objects.hash(taskType, inferenceEntityId, content, contentType, timeout, stream, context);
}
}
}
Loading