Skip to content

Commit 0b83425

Browse files
authored
[Inference API] Propagate product use case http header to EIS (#124025)
1 parent 43eee87 commit 0b83425

File tree

33 files changed

+886
-96
lines changed

33 files changed

+886
-96
lines changed

docs/changelog/124025.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 124025
2+
summary: "[Inference API] Propagate product use case http header to EIS"
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ static TransportVersion def(int id) {
146146
public static final TransportVersion VOYAGE_AI_INTEGRATION_ADDED_BACKPORT_8_X = def(8_841_0_05);
147147
public static final TransportVersion JINA_AI_EMBEDDING_TYPE_SUPPORT_ADDED_BACKPORT_8_19 = def(8_841_0_06);
148148
public static final TransportVersion RETRY_ILM_ASYNC_ACTION_REQUIRE_ERROR_8_19 = def(8_841_0_07);
149+
public static final TransportVersion INFERENCE_CONTEXT_8_X = def(8_841_0_08);
149150
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0 = def(9_000_0_00);
150151
public static final TransportVersion REMOVE_SNAPSHOT_FAILURES_90 = def(9_000_0_01);
151152
public static final TransportVersion TRANSPORT_STATS_HANDLING_TIME_REQUIRED_90 = def(9_000_0_02);
@@ -181,6 +182,7 @@ static TransportVersion def(int id) {
181182
public static final TransportVersion RETRY_ILM_ASYNC_ACTION_REQUIRE_ERROR = def(9_025_0_00);
182183
public static final TransportVersion ESQL_SERIALIZE_BLOCK_TYPE_CODE = def(9_026_0_00);
183184
public static final TransportVersion ESQL_THREAD_NAME_IN_DRIVER_PROFILE = def(9_027_0_00);
185+
public static final TransportVersion INFERENCE_CONTEXT = def(9_028_0_00);
184186

185187
/*
186188
* STOP! READ THIS FIRST! No, really,
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.core.inference;
9+
10+
import org.elasticsearch.common.io.stream.StreamInput;
11+
import org.elasticsearch.common.io.stream.StreamOutput;
12+
import org.elasticsearch.common.io.stream.Writeable;
13+
import org.elasticsearch.xcontent.ToXContent;
14+
import org.elasticsearch.xcontent.XContentBuilder;
15+
16+
import java.io.IOException;
17+
import java.util.Objects;
18+
19+
/**
20+
* Record for storing context alongside an inference request, typically used for metadata.
21+
* This is mainly used to pass along inference context on the transport layer without relying on
22+
* {@link org.elasticsearch.common.util.concurrent.ThreadContext}, which depending on the internal
23+
* {@link org.elasticsearch.client.internal.Client} throws away parts of the context, when passed along the transport layer.
24+
*
25+
* @param productUseCase - for now mainly used by Elastic Inference Service
26+
*/
27+
public record InferenceContext(String productUseCase) implements Writeable, ToXContent {
28+
29+
public static final InferenceContext EMPTY_INSTANCE = new InferenceContext("");
30+
31+
public InferenceContext {
32+
Objects.requireNonNull(productUseCase);
33+
}
34+
35+
public InferenceContext(StreamInput in) throws IOException {
36+
this(in.readString());
37+
}
38+
39+
@Override
40+
public void writeTo(StreamOutput out) throws IOException {
41+
out.writeString(productUseCase);
42+
}
43+
44+
@Override
45+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
46+
builder.startObject();
47+
48+
builder.field("product_use_case", productUseCase);
49+
50+
builder.endObject();
51+
52+
return builder;
53+
}
54+
55+
@Override
56+
public boolean equals(Object o) {
57+
if (this == o) return true;
58+
if (o == null || getClass() != o.getClass()) return false;
59+
InferenceContext that = (InferenceContext) o;
60+
return Objects.equals(productUseCase, that.productUseCase);
61+
}
62+
63+
@Override
64+
public int hashCode() {
65+
return Objects.hashCode(productUseCase);
66+
}
67+
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/BaseInferenceActionRequest.java

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,10 @@
1212
import org.elasticsearch.common.io.stream.StreamInput;
1313
import org.elasticsearch.common.io.stream.StreamOutput;
1414
import org.elasticsearch.inference.TaskType;
15+
import org.elasticsearch.xpack.core.inference.InferenceContext;
1516

1617
import java.io.IOException;
18+
import java.util.Objects;
1719

1820
/**
1921
* Base class for inference action requests. Tracks request routing state to prevent potential routing loops
@@ -23,8 +25,11 @@ public abstract class BaseInferenceActionRequest extends ActionRequest {
2325

2426
private boolean hasBeenRerouted;
2527

26-
public BaseInferenceActionRequest() {
28+
private final InferenceContext context;
29+
30+
/**
 * Creates a request that carries the given inference context.
 *
 * @param context context to attach to the request; {@code null} is replaced by
 *                {@link InferenceContext#EMPTY_INSTANCE}
 */
public BaseInferenceActionRequest(InferenceContext context) {
    super();
    // writeTo() dereferences the context unconditionally, so never store null. An absent context is
    // represented by EMPTY_INSTANCE, the same default the StreamInput constructor falls back to.
    this.context = Objects.requireNonNullElse(context, InferenceContext.EMPTY_INSTANCE);
}
2934

3035
public BaseInferenceActionRequest(StreamInput in) throws IOException {
@@ -36,6 +41,13 @@ public BaseInferenceActionRequest(StreamInput in) throws IOException {
3641
// a version pre-node-local-rate-limiting as already rerouted to maintain pre-node-local-rate-limiting behavior.
3742
this.hasBeenRerouted = true;
3843
}
44+
45+
if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_CONTEXT)
46+
|| in.getTransportVersion().isPatchFrom(TransportVersions.INFERENCE_CONTEXT_8_X)) {
47+
this.context = new InferenceContext(in);
48+
} else {
49+
this.context = InferenceContext.EMPTY_INSTANCE;
50+
}
3951
}
4052

4153
public abstract boolean isStreaming();
@@ -52,11 +64,33 @@ public boolean hasBeenRerouted() {
5264
return hasBeenRerouted;
5365
}
5466

67+
public InferenceContext getContext() {
68+
return context;
69+
}
70+
5571
@Override
5672
public void writeTo(StreamOutput out) throws IOException {
5773
super.writeTo(out);
5874
if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_REQUEST_ADAPTIVE_RATE_LIMITING)) {
5975
out.writeBoolean(hasBeenRerouted);
6076
}
77+
78+
if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_CONTEXT)
79+
|| out.getTransportVersion().isPatchFrom(TransportVersions.INFERENCE_CONTEXT_8_X)) {
80+
context.writeTo(out);
81+
}
82+
}
83+
84+
@Override
85+
public boolean equals(Object o) {
86+
if (this == o) return true;
87+
if (o == null || getClass() != o.getClass()) return false;
88+
BaseInferenceActionRequest that = (BaseInferenceActionRequest) o;
89+
return hasBeenRerouted == that.hasBeenRerouted && Objects.equals(context, that.context);
90+
}
91+
92+
@Override
93+
public int hashCode() {
94+
return Objects.hash(hasBeenRerouted, context);
6195
}
6296
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java

Lines changed: 54 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import org.elasticsearch.xcontent.ParseField;
2929
import org.elasticsearch.xcontent.ToXContent;
3030
import org.elasticsearch.xcontent.XContentParser;
31+
import org.elasticsearch.xpack.core.inference.InferenceContext;
3132
import org.elasticsearch.xpack.core.inference.results.LegacyTextEmbeddingResults;
3233
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
3334
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
@@ -74,12 +75,14 @@ public static class Request extends BaseInferenceActionRequest {
7475
InputType.UNSPECIFIED
7576
);
7677

77-
public static Builder parseRequest(String inferenceEntityId, TaskType taskType, XContentParser parser) throws IOException {
78+
public static Builder parseRequest(String inferenceEntityId, TaskType taskType, InferenceContext context, XContentParser parser)
79+
throws IOException {
7880
Request.Builder builder = PARSER.apply(parser, null);
7981
builder.setInferenceEntityId(inferenceEntityId);
8082
builder.setTaskType(taskType);
8183
// For rest requests we won't know what the input type is
8284
builder.setInputType(InputType.UNSPECIFIED);
85+
builder.setContext(context);
8386
return builder;
8487
}
8588

@@ -102,6 +105,31 @@ public Request(
102105
TimeValue inferenceTimeout,
103106
boolean stream
104107
) {
108+
this(
109+
taskType,
110+
inferenceEntityId,
111+
query,
112+
input,
113+
taskSettings,
114+
inputType,
115+
inferenceTimeout,
116+
stream,
117+
InferenceContext.EMPTY_INSTANCE
118+
);
119+
}
120+
121+
public Request(
122+
TaskType taskType,
123+
String inferenceEntityId,
124+
String query,
125+
List<String> input,
126+
Map<String, Object> taskSettings,
127+
InputType inputType,
128+
TimeValue inferenceTimeout,
129+
boolean stream,
130+
InferenceContext context
131+
) {
132+
super(context);
105133
this.taskType = taskType;
106134
this.inferenceEntityId = inferenceEntityId;
107135
this.query = query;
@@ -241,19 +269,31 @@ static InputType getInputTypeToWrite(InputType inputType, TransportVersion versi
241269
public boolean equals(Object o) {
242270
if (this == o) return true;
243271
if (o == null || getClass() != o.getClass()) return false;
272+
if (super.equals(o) == false) return false;
244273
Request request = (Request) o;
245-
return taskType == request.taskType
274+
return stream == request.stream
275+
&& taskType == request.taskType
246276
&& Objects.equals(inferenceEntityId, request.inferenceEntityId)
277+
&& Objects.equals(query, request.query)
247278
&& Objects.equals(input, request.input)
248279
&& Objects.equals(taskSettings, request.taskSettings)
249-
&& Objects.equals(inputType, request.inputType)
250-
&& Objects.equals(query, request.query)
280+
&& inputType == request.inputType
251281
&& Objects.equals(inferenceTimeout, request.inferenceTimeout);
252282
}
253283

254284
@Override
255285
public int hashCode() {
256-
return Objects.hash(taskType, inferenceEntityId, input, taskSettings, inputType, query, inferenceTimeout);
286+
return Objects.hash(
287+
super.hashCode(),
288+
taskType,
289+
inferenceEntityId,
290+
query,
291+
input,
292+
taskSettings,
293+
inputType,
294+
inferenceTimeout,
295+
stream
296+
);
257297
}
258298

259299
public static class Builder {
@@ -266,6 +306,7 @@ public static class Builder {
266306
private String query;
267307
private TimeValue timeout = DEFAULT_TIMEOUT;
268308
private boolean stream = false;
309+
private InferenceContext context;
269310

270311
private Builder() {}
271312

@@ -313,8 +354,13 @@ public Builder setStream(boolean stream) {
313354
return this;
314355
}
315356

357+
public Builder setContext(InferenceContext context) {
358+
this.context = context;
359+
return this;
360+
}
361+
316362
public Request build() {
317-
return new Request(taskType, inferenceEntityId, query, input, taskSettings, inputType, timeout, stream);
363+
return new Request(taskType, inferenceEntityId, query, input, taskSettings, inputType, timeout, stream, context);
318364
}
319365
}
320366

@@ -333,6 +379,8 @@ public String toString() {
333379
+ this.getInputType()
334380
+ ", timeout="
335381
+ this.getInferenceTimeout()
382+
+ ", context="
383+
+ this.getContext()
336384
+ ")";
337385
}
338386
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceActionProxy.java

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
package org.elasticsearch.xpack.core.inference.action;
99

10+
import org.elasticsearch.TransportVersions;
1011
import org.elasticsearch.action.ActionRequest;
1112
import org.elasticsearch.action.ActionRequestValidationException;
1213
import org.elasticsearch.action.ActionType;
@@ -17,6 +18,7 @@
1718
import org.elasticsearch.core.TimeValue;
1819
import org.elasticsearch.inference.TaskType;
1920
import org.elasticsearch.xcontent.XContentType;
21+
import org.elasticsearch.xpack.core.inference.InferenceContext;
2022

2123
import java.io.IOException;
2224
import java.util.Objects;
@@ -44,21 +46,24 @@ public static class Request extends ActionRequest {
4446
private final XContentType contentType;
4547
private final TimeValue timeout;
4648
private final boolean stream;
49+
private final InferenceContext context;
4750

4851
/**
 * Creates a proxy request wrapping the raw request body.
 *
 * @param taskType          task type of the request
 * @param inferenceEntityId id of the inference entity the request targets
 * @param content           raw request body
 * @param contentType       content type of {@code content}
 * @param timeout           timeout for the request
 * @param stream            whether the request is a streaming request
 * @param context           inference context to attach; {@code null} is replaced by
 *                          {@link InferenceContext#EMPTY_INSTANCE}
 */
public Request(
    TaskType taskType,
    String inferenceEntityId,
    BytesReference content,
    XContentType contentType,
    TimeValue timeout,
    boolean stream,
    InferenceContext context
) {
    this.taskType = taskType;
    this.inferenceEntityId = inferenceEntityId;
    this.content = content;
    this.contentType = contentType;
    this.timeout = timeout;
    this.stream = stream;
    // writeTo() dereferences the context unconditionally, so never store null. An absent context is
    // represented by EMPTY_INSTANCE, the same default the StreamInput constructor falls back to.
    this.context = Objects.requireNonNullElse(context, InferenceContext.EMPTY_INSTANCE);
}
6368

6469
public Request(StreamInput in) throws IOException {
@@ -71,6 +76,13 @@ public Request(StreamInput in) throws IOException {
7176

7277
// streaming is not supported yet for transport traffic
7378
this.stream = false;
79+
80+
if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_CONTEXT)
81+
|| in.getTransportVersion().isPatchFrom(TransportVersions.INFERENCE_CONTEXT_8_X)) {
82+
this.context = new InferenceContext(in);
83+
} else {
84+
this.context = InferenceContext.EMPTY_INSTANCE;
85+
}
7486
}
7587

7688
public TaskType getTaskType() {
@@ -97,6 +109,10 @@ public boolean isStreaming() {
97109
return stream;
98110
}
99111

112+
public InferenceContext getContext() {
113+
return context;
114+
}
115+
100116
@Override
101117
public ActionRequestValidationException validate() {
102118
return null;
@@ -110,6 +126,11 @@ public void writeTo(StreamOutput out) throws IOException {
110126
out.writeBytesReference(content);
111127
XContentHelper.writeTo(out, contentType);
112128
out.writeTimeValue(timeout);
129+
130+
if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_CONTEXT)
131+
|| out.getTransportVersion().isPatchFrom(TransportVersions.INFERENCE_CONTEXT_8_X)) {
132+
context.writeTo(out);
133+
}
113134
}
114135

115136
@Override
@@ -122,12 +143,13 @@ public boolean equals(Object o) {
122143
&& Objects.equals(content, request.content)
123144
&& contentType == request.contentType
124145
&& timeout == request.timeout
125-
&& stream == request.stream;
146+
&& stream == request.stream
147+
&& context == request.context;
126148
}
127149

128150
@Override
129151
public int hashCode() {
130-
return Objects.hash(taskType, inferenceEntityId, content, contentType, timeout, stream);
152+
return Objects.hash(taskType, inferenceEntityId, content, contentType, timeout, stream, context);
131153
}
132154
}
133155
}

0 commit comments

Comments
 (0)