Skip to content

Commit 07552e8

Browse files
authored
[8.x] [Inference API] Propagate product use case http header to EIS (#124025) (#124666)
1 parent c36b500 commit 07552e8

File tree

33 files changed

+896
-101
lines changed

33 files changed

+896
-101
lines changed

docs/changelog/124025.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 124025
2+
summary: "[Inference API] Propagate product use case http header to EIS"
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,7 @@ static TransportVersion def(int id) {
192192
public static final TransportVersion VOYAGE_AI_INTEGRATION_ADDED_BACKPORT_8_X = def(8_841_0_05);
193193
public static final TransportVersion JINA_AI_EMBEDDING_TYPE_SUPPORT_ADDED_BACKPORT_8_19 = def(8_841_0_06);
194194
public static final TransportVersion RETRY_ILM_ASYNC_ACTION_REQUIRE_ERROR_8_19 = def(8_841_0_07);
195+
public static final TransportVersion INFERENCE_CONTEXT_8_X = def(8_841_0_08);
195196

196197
/*
197198
* STOP! READ THIS FIRST! No, really,
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/InferenceContext.java

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.core.inference;
9+
10+
import org.elasticsearch.common.io.stream.StreamInput;
11+
import org.elasticsearch.common.io.stream.StreamOutput;
12+
import org.elasticsearch.common.io.stream.Writeable;
13+
import org.elasticsearch.xcontent.ToXContent;
14+
import org.elasticsearch.xcontent.XContentBuilder;
15+
16+
import java.io.IOException;
17+
import java.util.Objects;
18+
19+
/**
20+
* Record for storing context alongside an inference request, typically used for metadata.
21+
* This is mainly used to pass along inference context on the transport layer without relying on
22+
* {@link org.elasticsearch.common.util.concurrent.ThreadContext}, which depending on the internal
23+
* {@link org.elasticsearch.client.internal.Client} throws away parts of the context, when passed along the transport layer.
24+
*
25+
* @param productUseCase - for now mainly used by Elastic Inference Service
26+
*/
27+
public record InferenceContext(String productUseCase) implements Writeable, ToXContent {
28+
29+
public static final InferenceContext EMPTY_INSTANCE = new InferenceContext("");
30+
31+
public InferenceContext {
32+
Objects.requireNonNull(productUseCase);
33+
}
34+
35+
public InferenceContext(StreamInput in) throws IOException {
36+
this(in.readString());
37+
}
38+
39+
@Override
40+
public void writeTo(StreamOutput out) throws IOException {
41+
out.writeString(productUseCase);
42+
}
43+
44+
@Override
45+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
46+
builder.startObject();
47+
48+
builder.field("product_use_case", productUseCase);
49+
50+
builder.endObject();
51+
52+
return builder;
53+
}
54+
55+
@Override
56+
public boolean equals(Object o) {
57+
if (this == o) return true;
58+
if (o == null || getClass() != o.getClass()) return false;
59+
InferenceContext that = (InferenceContext) o;
60+
return Objects.equals(productUseCase, that.productUseCase);
61+
}
62+
63+
@Override
64+
public int hashCode() {
65+
return Objects.hashCode(productUseCase);
66+
}
67+
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/BaseInferenceActionRequest.java

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,10 @@
1212
import org.elasticsearch.common.io.stream.StreamInput;
1313
import org.elasticsearch.common.io.stream.StreamOutput;
1414
import org.elasticsearch.inference.TaskType;
15+
import org.elasticsearch.xpack.core.inference.InferenceContext;
1516

1617
import java.io.IOException;
18+
import java.util.Objects;
1719

1820
/**
1921
* Base class for inference action requests. Tracks request routing state to prevent potential routing loops
@@ -23,8 +25,11 @@ public abstract class BaseInferenceActionRequest extends ActionRequest {
2325

2426
private boolean hasBeenRerouted;
2527

26-
public BaseInferenceActionRequest() {
28+
private final InferenceContext context;
29+
30+
public BaseInferenceActionRequest(InferenceContext context) {
2731
super();
32+
this.context = context;
2833
}
2934

3035
public BaseInferenceActionRequest(StreamInput in) throws IOException {
@@ -36,6 +41,12 @@ public BaseInferenceActionRequest(StreamInput in) throws IOException {
3641
// a version pre-node-local-rate-limiting as already rerouted to maintain pre-node-local-rate-limiting behavior.
3742
this.hasBeenRerouted = true;
3843
}
44+
45+
if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_CONTEXT_8_X)) {
46+
this.context = new InferenceContext(in);
47+
} else {
48+
this.context = InferenceContext.EMPTY_INSTANCE;
49+
}
3950
}
4051

4152
public abstract boolean isStreaming();
@@ -52,11 +63,32 @@ public boolean hasBeenRerouted() {
5263
return hasBeenRerouted;
5364
}
5465

66+
/**
 * Returns the {@link InferenceContext} held by this request: the value handed to the constructor, or the one
 * assigned while the request was read from a stream.
 */
public InferenceContext getContext() {
67+
return context;
68+
}
69+
5570
/**
 * Serializes the base request state after the superclass state.
 * <p>
 * Each field is written only when the receiving side's transport version is on or after the version that
 * introduced it: first {@code hasBeenRerouted}, then the inference context.
 *
 * @param out stream to write to
 * @throws IOException if writing to the stream fails
 */
@Override
5671
public void writeTo(StreamOutput out) throws IOException {
5772
super.writeTo(out);
5873
// Only written for receivers on or after INFERENCE_REQUEST_ADAPTIVE_RATE_LIMITING.
if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_REQUEST_ADAPTIVE_RATE_LIMITING)) {
5974
out.writeBoolean(hasBeenRerouted);
6075
}
76+
77+
// Only written for receivers on or after INFERENCE_CONTEXT_8_X.
// NOTE(review): context is dereferenced without a null check here.
if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_CONTEXT_8_X)) {
78+
context.writeTo(out);
79+
}
80+
}
81+
82+
@Override
83+
public boolean equals(Object o) {
84+
if (this == o) return true;
85+
if (o == null || getClass() != o.getClass()) return false;
86+
BaseInferenceActionRequest that = (BaseInferenceActionRequest) o;
87+
return hasBeenRerouted == that.hasBeenRerouted && Objects.equals(context, that.context);
88+
}
89+
90+
@Override
91+
public int hashCode() {
92+
return Objects.hash(hasBeenRerouted, context);
6193
}
6294
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java

Lines changed: 54 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import org.elasticsearch.xcontent.ParseField;
2929
import org.elasticsearch.xcontent.ToXContent;
3030
import org.elasticsearch.xcontent.XContentParser;
31+
import org.elasticsearch.xpack.core.inference.InferenceContext;
3132
import org.elasticsearch.xpack.core.inference.results.LegacyTextEmbeddingResults;
3233
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
3334
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
@@ -74,12 +75,14 @@ public static class Request extends BaseInferenceActionRequest {
7475
InputType.UNSPECIFIED
7576
);
7677

77-
public static Builder parseRequest(String inferenceEntityId, TaskType taskType, XContentParser parser) throws IOException {
78+
public static Builder parseRequest(String inferenceEntityId, TaskType taskType, InferenceContext context, XContentParser parser)
79+
throws IOException {
7880
Request.Builder builder = PARSER.apply(parser, null);
7981
builder.setInferenceEntityId(inferenceEntityId);
8082
builder.setTaskType(taskType);
8183
// For rest requests we won't know what the input type is
8284
builder.setInputType(InputType.UNSPECIFIED);
85+
builder.setContext(context);
8386
return builder;
8487
}
8588

@@ -102,6 +105,31 @@ public Request(
102105
TimeValue inferenceTimeout,
103106
boolean stream
104107
) {
108+
this(
109+
taskType,
110+
inferenceEntityId,
111+
query,
112+
input,
113+
taskSettings,
114+
inputType,
115+
inferenceTimeout,
116+
stream,
117+
InferenceContext.EMPTY_INSTANCE
118+
);
119+
}
120+
121+
public Request(
122+
TaskType taskType,
123+
String inferenceEntityId,
124+
String query,
125+
List<String> input,
126+
Map<String, Object> taskSettings,
127+
InputType inputType,
128+
TimeValue inferenceTimeout,
129+
boolean stream,
130+
InferenceContext context
131+
) {
132+
super(context);
105133
this.taskType = taskType;
106134
this.inferenceEntityId = inferenceEntityId;
107135
this.query = query;
@@ -241,19 +269,31 @@ static InputType getInputTypeToWrite(InputType inputType, TransportVersion versi
241269
public boolean equals(Object o) {
242270
if (this == o) return true;
243271
if (o == null || getClass() != o.getClass()) return false;
272+
if (super.equals(o) == false) return false;
244273
Request request = (Request) o;
245-
return taskType == request.taskType
274+
return stream == request.stream
275+
&& taskType == request.taskType
246276
&& Objects.equals(inferenceEntityId, request.inferenceEntityId)
277+
&& Objects.equals(query, request.query)
247278
&& Objects.equals(input, request.input)
248279
&& Objects.equals(taskSettings, request.taskSettings)
249-
&& Objects.equals(inputType, request.inputType)
250-
&& Objects.equals(query, request.query)
280+
&& inputType == request.inputType
251281
&& Objects.equals(inferenceTimeout, request.inferenceTimeout);
252282
}
253283

254284
@Override
255285
public int hashCode() {
256-
return Objects.hash(taskType, inferenceEntityId, input, taskSettings, inputType, query, inferenceTimeout);
286+
return Objects.hash(
287+
super.hashCode(),
288+
taskType,
289+
inferenceEntityId,
290+
query,
291+
input,
292+
taskSettings,
293+
inputType,
294+
inferenceTimeout,
295+
stream
296+
);
257297
}
258298

259299
public static class Builder {
@@ -266,6 +306,7 @@ public static class Builder {
266306
private String query;
267307
private TimeValue timeout = DEFAULT_TIMEOUT;
268308
private boolean stream = false;
309+
private InferenceContext context;
269310

270311
private Builder() {}
271312

@@ -313,8 +354,13 @@ public Builder setStream(boolean stream) {
313354
return this;
314355
}
315356

357+
/**
 * Sets the {@link InferenceContext} that {@code build()} passes to the {@code Request} constructor.
 * <p>
 * NOTE(review): the builder's {@code context} field has no initial value, so {@code build()} passes
 * {@code null} when this setter is never called; confirm every caller sets a context.
 *
 * @param context context to attach to the request being built
 * @return this builder, for chaining
 */
public Builder setContext(InferenceContext context) {
358+
this.context = context;
359+
return this;
360+
}
361+
316362
public Request build() {
317-
return new Request(taskType, inferenceEntityId, query, input, taskSettings, inputType, timeout, stream);
363+
return new Request(taskType, inferenceEntityId, query, input, taskSettings, inputType, timeout, stream, context);
318364
}
319365
}
320366

@@ -333,6 +379,8 @@ public String toString() {
333379
+ this.getInputType()
334380
+ ", timeout="
335381
+ this.getInferenceTimeout()
382+
+ ", context="
383+
+ this.getContext()
336384
+ ")";
337385
}
338386
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceActionProxy.java

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
package org.elasticsearch.xpack.core.inference.action;
99

10+
import org.elasticsearch.TransportVersions;
1011
import org.elasticsearch.action.ActionRequest;
1112
import org.elasticsearch.action.ActionRequestValidationException;
1213
import org.elasticsearch.action.ActionType;
@@ -17,6 +18,7 @@
1718
import org.elasticsearch.core.TimeValue;
1819
import org.elasticsearch.inference.TaskType;
1920
import org.elasticsearch.xcontent.XContentType;
21+
import org.elasticsearch.xpack.core.inference.InferenceContext;
2022

2123
import java.io.IOException;
2224
import java.util.Objects;
@@ -44,21 +46,24 @@ public static class Request extends ActionRequest {
4446
private final XContentType contentType;
4547
private final TimeValue timeout;
4648
private final boolean stream;
49+
private final InferenceContext context;
4750

4851
public Request(
4952
TaskType taskType,
5053
String inferenceEntityId,
5154
BytesReference content,
5255
XContentType contentType,
5356
TimeValue timeout,
54-
boolean stream
57+
boolean stream,
58+
InferenceContext context
5559
) {
5660
this.taskType = taskType;
5761
this.inferenceEntityId = inferenceEntityId;
5862
this.content = content;
5963
this.contentType = contentType;
6064
this.timeout = timeout;
6165
this.stream = stream;
66+
this.context = context;
6267
}
6368

6469
public Request(StreamInput in) throws IOException {
@@ -71,6 +76,12 @@ public Request(StreamInput in) throws IOException {
7176

7277
// streaming is not supported yet for transport traffic
7378
this.stream = false;
79+
80+
if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_CONTEXT_8_X)) {
81+
this.context = new InferenceContext(in);
82+
} else {
83+
this.context = InferenceContext.EMPTY_INSTANCE;
84+
}
7485
}
7586

7687
public TaskType getTaskType() {
@@ -97,6 +108,10 @@ public boolean isStreaming() {
97108
return stream;
98109
}
99110

111+
/**
 * Returns the {@link InferenceContext} of this proxy request: the value supplied to the constructor, or, for a
 * request read from a stream, the deserialized context ({@code InferenceContext.EMPTY_INSTANCE} when the stream's
 * transport version is before {@code INFERENCE_CONTEXT_8_X}).
 */
public InferenceContext getContext() {
112+
return context;
113+
}
114+
100115
@Override
101116
public ActionRequestValidationException validate() {
102117
return null;
@@ -110,6 +125,10 @@ public void writeTo(StreamOutput out) throws IOException {
110125
out.writeBytesReference(content);
111126
XContentHelper.writeTo(out, contentType);
112127
out.writeTimeValue(timeout);
128+
129+
if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_CONTEXT_8_X)) {
130+
context.writeTo(out);
131+
}
113132
}
114133

115134
@Override
@@ -122,12 +141,13 @@ public boolean equals(Object o) {
122141
&& Objects.equals(content, request.content)
123142
&& contentType == request.contentType
124143
&& timeout == request.timeout
125-
&& stream == request.stream;
144+
&& stream == request.stream
145+
&& context == request.context;
126146
}
127147

128148
@Override
129149
public int hashCode() {
130-
return Objects.hash(taskType, inferenceEntityId, content, contentType, timeout, stream);
150+
return Objects.hash(taskType, inferenceEntityId, content, contentType, timeout, stream, context);
131151
}
132152
}
133153
}

0 commit comments

Comments
 (0)