Skip to content
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
ab712dd
WIP
timgrein Mar 4, 2025
b617e0a
[CI] Auto commit changes from spotless
Mar 4, 2025
2e8231a
Iterate (propagate InferenceContext & put HTTP header into thread con…
timgrein Mar 6, 2025
ca79955
Iter (product use case propagation works)
timgrein Mar 6, 2025
2c12778
Iter (move request metadata extraction to parent request manager class)
timgrein Mar 6, 2025
cc52207
[CI] Auto commit changes from spotless
Mar 6, 2025
a1da2f6
Add docs to InferenceContext
timgrein Mar 6, 2025
a6e9297
Merge remote-tracking branch 'origin/read-and-propagate-product-use-c…
timgrein Mar 6, 2025
82ce713
Update InferenceContext.java
timgrein Mar 6, 2025
e5c5933
Remove duplicate context from InferenceAction.Request
timgrein Mar 7, 2025
634aaef
Merge remote-tracking branch 'origin/read-and-propagate-product-use-c…
timgrein Mar 7, 2025
0302468
Add InferenceContextTests
timgrein Mar 7, 2025
ef35491
Add additional test case for context in InferenceActionRequestTests
timgrein Mar 7, 2025
3abe303
Add new test cases to InferenceActionRequestTests
timgrein Mar 7, 2025
517c943
Add new test cases to UnifiedCompletionActionRequestTests
timgrein Mar 7, 2025
0bc06df
Add test to verify that the header is set in the thread context
timgrein Mar 7, 2025
7a515bc
Remove TODO
timgrein Mar 7, 2025
cf55cb5
Add product use case header extraction test cases to BaseInferenceAct…
timgrein Mar 7, 2025
8d44488
Remove addressed TODO and spotlessApply
timgrein Mar 7, 2025
139f7a5
Add product use case propagation tests in ElasticInferenceServiceTests
timgrein Mar 7, 2025
280f0cf
Merge branch 'main' into read-and-propagate-product-use-case-header-t…
timgrein Mar 7, 2025
0a3b0a0
Fix compilation error
timgrein Mar 7, 2025
cd1beee
Update docs/changelog/124025.yaml
timgrein Mar 7, 2025
1c0a3c4
Replace InferenceContext.empty() with InferenceContext.EMPTY_INSTANCE
timgrein Mar 10, 2025
d6a7bb9
Merge remote-tracking branch 'origin/read-and-propagate-product-use-c…
timgrein Mar 10, 2025
89a973a
Update x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/…
timgrein Mar 10, 2025
2f881e0
Use .get(0) instead of getFirst to avoid compilation errors in backport.
timgrein Mar 10, 2025
e5153a1
Add TransportVersion for 8_X
timgrein Mar 10, 2025
149a094
Add TODO to remove temporary product use case propagation
timgrein Mar 10, 2025
69687d5
Add comment with rationale explaining difference in header extraction
timgrein Mar 10, 2025
bc1a74f
Ensure that productUseCase field in InferenceContext is non-null
timgrein Mar 10, 2025
8ab08f9
Merge branch 'main' into read-and-propagate-product-use-case-header-t…
timgrein Mar 11, 2025
e34359c
Merge branch 'main' into read-and-propagate-product-use-case-header-t…
timgrein Mar 11, 2025
a7d0f27
spotlessApply
timgrein Mar 11, 2025
e494a39
fix checkstyle errors in xpack core plugin
timgrein Mar 11, 2025
5b90efb
Fix checkstyle errors in inference plugin
timgrein Mar 11, 2025
418241f
[CI] Auto commit changes from spotless
Mar 11, 2025
9163207
Fix test in InferenceActionRequestTests and adapt the structure to be…
timgrein Mar 12, 2025
5d4b92f
Merge remote-tracking branch 'origin/read-and-propagate-product-use-c…
timgrein Mar 12, 2025
fa47c8c
Add equals/hashCode to InferenceContext
timgrein Mar 12, 2025
0515e1a
Merge branch 'main' into read-and-propagate-product-use-case-header-t…
timgrein Mar 12, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/124025.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 124025
summary: "[Inference API] Propagate product use case http header to EIS"
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ static TransportVersion def(int id) {
public static final TransportVersion INCLUDE_INDEX_MODE_IN_GET_DATA_STREAM = def(9_023_0_00);
public static final TransportVersion MAX_OPERATION_SIZE_REJECTIONS_ADDED = def(9_024_0_00);
public static final TransportVersion RETRY_ILM_ASYNC_ACTION_REQUIRE_ERROR = def(9_025_0_00);
public static final TransportVersion INFERENCE_CONTEXT = def(9_026_0_00);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a reminder, if we do want to backport to 8.19 we'll need a TransportVersion for 8.x

for example: COHERE_BIT_EMBEDDING_TYPE_SUPPORT_ADDED_BACKPORT_8_X

We'll also need to change the onAfter() check. Here's an example:
https://github.com/elastic/elasticsearch/blob/main/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingType.java#L131-L132

The code in 8.x will look different too (since the 9.x transport version won't exist): https://github.com/elastic/elasticsearch/blob/8.x/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingType.java#L131

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the explanation and the code examples.

Adjusted with Add TransportVersion for 8_X.

In the backport I would then need to replace TransportVersions.INFERENCE_CONTEXT with TransportVersions.INFERENCE_CONTEXT_8_X, right?


/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.inference;

import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;

/**
* Record for storing context alongside an inference request, typically used for metadata.
* This is mainly used to pass along inference context on the transport layer without relying on {@link org.elasticsearch.common.util.concurrent.ThreadContext},
* which, depending on the internal {@link org.elasticsearch.client.internal.Client} implementation, may throw away parts of the context when it is passed along the transport layer.
*
* @param productUseCase the product use case associated with the request; for now mainly relevant to the Elastic Inference Service
*/
public record InferenceContext(String productUseCase) implements Writeable, ToXContent {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should verify non-null so the transport layer doesn't hate us for not using writeOptionalString:

public InferenceContext {
    Objects.requireNonNull(productUseCase);
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


/**
 * Deserializes the context from the transport stream; reads the single (non-optional)
 * string written by {@link #writeTo}.
 *
 * @param in the stream to read from
 * @throws IOException if reading from the stream fails
 */
public InferenceContext(StreamInput in) throws IOException {
this(in.readString());
}

public static InferenceContext empty() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about we create a static instance that way we don't create multiple empty ones? Something like this:

public static final InferenceContext EMPTY_INSTANCE = new InferenceContext("");

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

return new InferenceContext("");
}

@Override
public void writeTo(StreamOutput out) throws IOException {
// Plain writeString (not writeOptionalString): productUseCase is expected to be non-null.
out.writeString(productUseCase);
}

/**
 * Renders this context as an object with a single {@code product_use_case} field.
 */
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
    return builder.startObject().field("product_use_case", productUseCase).endObject();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.core.inference.InferenceContext;

import java.io.IOException;
import java.util.Objects;

/**
* Base class for inference action requests. Tracks request routing state to prevent potential routing loops
Expand All @@ -23,8 +25,11 @@ public abstract class BaseInferenceActionRequest extends ActionRequest {

private boolean hasBeenRerouted;

// Per-request metadata (e.g. product use case), propagated explicitly instead of via the
// ThreadContext so it survives being passed along the transport layer.
private final InferenceContext context;

/**
 * Creates a request carrying the given inference context.
 *
 * @param context metadata propagated alongside the request, e.g. the product use case
 */
public BaseInferenceActionRequest(InferenceContext context) {
super();
this.context = context;
}

public BaseInferenceActionRequest(StreamInput in) throws IOException {
Expand All @@ -36,6 +41,12 @@ public BaseInferenceActionRequest(StreamInput in) throws IOException {
// a version pre-node-local-rate-limiting as already rerouted to maintain pre-node-local-rate-limiting behavior.
this.hasBeenRerouted = true;
}

if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_CONTEXT)) {
this.context = new InferenceContext(in);
} else {
this.context = InferenceContext.empty();
}
}

public abstract boolean isStreaming();
Expand All @@ -52,11 +63,32 @@ public boolean hasBeenRerouted() {
return hasBeenRerouted;
}

/**
 * @return the {@link InferenceContext} carried by this request; an empty context when the
 * request was read from a node that predates {@code TransportVersions.INFERENCE_CONTEXT}
 */
public InferenceContext getContext() {
return context;
}

/**
 * Serializes this request. Field order must match the stream constructor: the rerouted
 * flag first, then the context, each guarded by the receiver's transport version.
 */
@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_REQUEST_ADAPTIVE_RATE_LIMITING)) {
out.writeBoolean(hasBeenRerouted);
}

// Only nodes on INFERENCE_CONTEXT or later can read the extra field.
if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_CONTEXT)) {
context.writeTo(out);
}
}

/**
 * Two requests are equal when both the rerouted flag and the context match.
 */
@Override
public boolean equals(Object o) {
    if (o == this) {
        return true;
    }
    if (o == null || o.getClass() != getClass()) {
        return false;
    }
    BaseInferenceActionRequest other = (BaseInferenceActionRequest) o;
    return this.hasBeenRerouted == other.hasBeenRerouted && Objects.equals(this.context, other.context);
}

// Must stay consistent with equals(): both use hasBeenRerouted and context.
@Override
public int hashCode() {
return Objects.hash(hasBeenRerouted, context);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.inference.InferenceContext;
import org.elasticsearch.xpack.core.inference.results.LegacyTextEmbeddingResults;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
Expand Down Expand Up @@ -74,12 +75,14 @@ public static class Request extends BaseInferenceActionRequest {
InputType.UNSPECIFIED
);

public static Builder parseRequest(String inferenceEntityId, TaskType taskType, XContentParser parser) throws IOException {
public static Builder parseRequest(String inferenceEntityId, TaskType taskType, InferenceContext context, XContentParser parser)
throws IOException {
Request.Builder builder = PARSER.apply(parser, null);
builder.setInferenceEntityId(inferenceEntityId);
builder.setTaskType(taskType);
// For rest requests we won't know what the input type is
builder.setInputType(InputType.UNSPECIFIED);
builder.setContext(context);
return builder;
}

Expand All @@ -102,6 +105,21 @@ public Request(
TimeValue inferenceTimeout,
boolean stream
) {
this(taskType, inferenceEntityId, query, input, taskSettings, inputType, inferenceTimeout, stream, InferenceContext.empty());
}

public Request(
TaskType taskType,
String inferenceEntityId,
String query,
List<String> input,
Map<String, Object> taskSettings,
InputType inputType,
TimeValue inferenceTimeout,
boolean stream,
InferenceContext context
) {
super(context);
this.taskType = taskType;
this.inferenceEntityId = inferenceEntityId;
this.query = query;
Expand Down Expand Up @@ -241,19 +259,31 @@ static InputType getInputTypeToWrite(InputType inputType, TransportVersion versi
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
if (super.equals(o) == false) return false;
Request request = (Request) o;
return taskType == request.taskType
return stream == request.stream
&& taskType == request.taskType
&& Objects.equals(inferenceEntityId, request.inferenceEntityId)
&& Objects.equals(query, request.query)
&& Objects.equals(input, request.input)
&& Objects.equals(taskSettings, request.taskSettings)
&& Objects.equals(inputType, request.inputType)
&& Objects.equals(query, request.query)
&& inputType == request.inputType
&& Objects.equals(inferenceTimeout, request.inferenceTimeout);
}

@Override
public int hashCode() {
return Objects.hash(taskType, inferenceEntityId, input, taskSettings, inputType, query, inferenceTimeout);
return Objects.hash(
super.hashCode(),
taskType,
inferenceEntityId,
query,
input,
taskSettings,
inputType,
inferenceTimeout,
stream
);
}

public static class Builder {
Expand All @@ -266,6 +296,7 @@ public static class Builder {
private String query;
private TimeValue timeout = DEFAULT_TIMEOUT;
private boolean stream = false;
private InferenceContext context;

private Builder() {}

Expand Down Expand Up @@ -313,8 +344,13 @@ public Builder setStream(boolean stream) {
return this;
}

/**
 * Sets the {@link InferenceContext} to attach to the built request.
 *
 * @param context per-request metadata, e.g. the product use case
 * @return this builder
 */
public Builder setContext(InferenceContext context) {
this.context = context;
return this;
}

/**
 * Builds the {@link Request}. Falls back to an empty context when none was set via
 * {@link #setContext} so the request never carries a null context.
 */
public Request build() {
// Defensive default: the Builder's context field is only populated by setContext(...).
InferenceContext requestContext = context == null ? InferenceContext.empty() : context;
return new Request(taskType, inferenceEntityId, query, input, taskSettings, inputType, timeout, stream, requestContext);
}
}

Expand All @@ -333,6 +369,8 @@ public String toString() {
+ this.getInputType()
+ ", timeout="
+ this.getInferenceTimeout()
+ ", context="
+ this.getContext()
+ ")";
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

package org.elasticsearch.xpack.core.inference.action;

import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.ActionType;
Expand All @@ -17,6 +18,7 @@
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.InferenceContext;

import java.io.IOException;
import java.util.Objects;
Expand Down Expand Up @@ -44,21 +46,24 @@ public static class Request extends ActionRequest {
private final XContentType contentType;
private final TimeValue timeout;
private final boolean stream;
private final InferenceContext context;

/**
 * Creates a proxy inference request.
 *
 * @param taskType the task type to execute
 * @param inferenceEntityId the id of the inference endpoint
 * @param content the raw request body
 * @param contentType the content type of {@code content}
 * @param timeout how long to wait for the inference call to complete
 * @param stream whether the response should be streamed
 * @param context per-request metadata, e.g. the product use case
 */
public Request(
TaskType taskType,
String inferenceEntityId,
BytesReference content,
XContentType contentType,
TimeValue timeout,
boolean stream,
InferenceContext context
) {
this.taskType = taskType;
this.inferenceEntityId = inferenceEntityId;
this.content = content;
this.contentType = contentType;
this.timeout = timeout;
this.stream = stream;
this.context = context;
}

public Request(StreamInput in) throws IOException {
Expand All @@ -71,6 +76,12 @@ public Request(StreamInput in) throws IOException {

// streaming is not supported yet for transport traffic
this.stream = false;

if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_CONTEXT)) {
this.context = new InferenceContext(in);
} else {
this.context = InferenceContext.empty();
}
}

public TaskType getTaskType() {
Expand All @@ -97,6 +108,10 @@ public boolean isStreaming() {
return stream;
}

/**
 * @return the {@link InferenceContext} for this request; an empty context when the
 * request was read from a node that predates {@code TransportVersions.INFERENCE_CONTEXT}
 */
public InferenceContext getContext() {
return context;
}

@Override
public ActionRequestValidationException validate() {
return null;
Expand All @@ -110,6 +125,10 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeBytesReference(content);
XContentHelper.writeTo(out, contentType);
out.writeTimeValue(timeout);

if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_CONTEXT)) {
context.writeTo(out);
}
}

@Override
Expand All @@ -122,12 +141,13 @@ public boolean equals(Object o) {
&& Objects.equals(content, request.content)
&& contentType == request.contentType
&& timeout == request.timeout
&& stream == request.stream;
&& stream == request.stream
&& context == request.context;
}

@Override
public int hashCode() {
return Objects.hash(taskType, inferenceEntityId, content, contentType, timeout, stream);
return Objects.hash(taskType, inferenceEntityId, content, contentType, timeout, stream, context);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnifiedCompletionRequest;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.inference.InferenceContext;

import java.io.IOException;
import java.util.Objects;
Expand All @@ -28,10 +29,15 @@ public UnifiedCompletionAction() {
}

public static class Request extends BaseInferenceActionRequest {
public static Request parseRequest(String inferenceEntityId, TaskType taskType, TimeValue timeout, XContentParser parser)
throws IOException {
public static Request parseRequest(
String inferenceEntityId,
TaskType taskType,
TimeValue timeout,
InferenceContext context,
XContentParser parser
) throws IOException {
var unifiedRequest = UnifiedCompletionRequest.PARSER.apply(parser, null);
return new Request(inferenceEntityId, taskType, unifiedRequest, timeout);
return new Request(inferenceEntityId, taskType, unifiedRequest, context, timeout);
}

private final String inferenceEntityId;
Expand All @@ -40,6 +46,17 @@ public static Request parseRequest(String inferenceEntityId, TaskType taskType,
private final TimeValue timeout;

/**
 * Convenience constructor that attaches an empty {@link InferenceContext}.
 */
public Request(String inferenceEntityId, TaskType taskType, UnifiedCompletionRequest unifiedCompletionRequest, TimeValue timeout) {
this(inferenceEntityId, taskType, unifiedCompletionRequest, InferenceContext.empty(), timeout);
}

public Request(
String inferenceEntityId,
TaskType taskType,
UnifiedCompletionRequest unifiedCompletionRequest,
InferenceContext context,
TimeValue timeout
) {
super(context);
this.inferenceEntityId = Objects.requireNonNull(inferenceEntityId);
this.taskType = Objects.requireNonNull(taskType);
this.unifiedCompletionRequest = Objects.requireNonNull(unifiedCompletionRequest);
Expand Down
Loading