Skip to content

Commit ca79955

Browse files
committed
Iter (product use case propagation works)
1 parent 2e8231a commit ca79955

File tree

9 files changed: +58 -22 lines changed

9 files changed: +58 -22 lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceActionProxy.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,8 @@ public ActionRequestValidationException validate() {
120120
@Override
121121
public void writeTo(StreamOutput out) throws IOException {
122122
super.writeTo(out);
123-
out.writeString(inferenceEntityId);
124123
taskType.writeTo(out);
124+
out.writeString(inferenceEntityId);
125125
out.writeBytesReference(content);
126126
XContentHelper.writeTo(out, contentType);
127127
out.writeTimeValue(timeout);

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import org.elasticsearch.inference.InputType;
1717
import org.elasticsearch.inference.TaskType;
1818
import org.elasticsearch.xcontent.json.JsonXContent;
19+
import org.elasticsearch.xpack.core.inference.InferenceContext;
1920
import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
2021

2122
import java.io.IOException;
@@ -57,7 +58,7 @@ public void testParsing() throws IOException {
5758
}
5859
""";
5960
try (var parser = createParser(JsonXContent.jsonXContent, singleInputRequest)) {
60-
var request = InferenceAction.Request.parseRequest("model_id", TaskType.SPARSE_EMBEDDING, parser).build();
61+
var request = InferenceAction.Request.parseRequest("model_id", TaskType.SPARSE_EMBEDDING, InferenceContext.empty(), parser).build();
6162
assertThat(request.getInput(), contains("single text input"));
6263
}
6364

@@ -67,7 +68,7 @@ public void testParsing() throws IOException {
6768
}
6869
""";
6970
try (var parser = createParser(JsonXContent.jsonXContent, multiInputRequest)) {
70-
var request = InferenceAction.Request.parseRequest("model_id", TaskType.ANY, parser).build();
71+
var request = InferenceAction.Request.parseRequest("model_id", TaskType.ANY, InferenceContext.empty(), parser).build();
7172
assertThat(request.getInput(), contains("an array", "of", "inputs"));
7273
}
7374
}
@@ -173,7 +174,7 @@ public void testParseRequest_DefaultsInputTypeToIngest() throws IOException {
173174
}
174175
""";
175176
try (var parser = createParser(JsonXContent.jsonXContent, singleInputRequest)) {
176-
var request = InferenceAction.Request.parseRequest("model_id", TaskType.SPARSE_EMBEDDING, parser).build();
177+
var request = InferenceAction.Request.parseRequest("model_id", TaskType.SPARSE_EMBEDDING,InferenceContext.empty(), parser).build();
177178
assertThat(request.getInputType(), is(InputType.UNSPECIFIED));
178179
}
179180
}

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionActionRequestTests.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import org.elasticsearch.core.TimeValue;
1515
import org.elasticsearch.inference.TaskType;
1616
import org.elasticsearch.inference.UnifiedCompletionRequest;
17+
import org.elasticsearch.xpack.core.inference.InferenceContext;
1718
import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
1819

1920
import java.io.IOException;
@@ -28,6 +29,7 @@ public void testValidation_ReturnsException_When_UnifiedCompletionRequestMessage
2829
"inference_id",
2930
TaskType.COMPLETION,
3031
UnifiedCompletionRequest.of(null),
32+
InferenceContext.empty(),
3133
TimeValue.timeValueSeconds(10)
3234
);
3335
var exception = request.validate();
@@ -39,6 +41,7 @@ public void testValidation_ReturnsException_When_UnifiedCompletionRequest_Is_Emp
3941
"inference_id",
4042
TaskType.COMPLETION,
4143
UnifiedCompletionRequest.of(List.of()),
44+
InferenceContext.empty(),
4245
TimeValue.timeValueSeconds(10)
4346
);
4447
var exception = request.validate();
@@ -50,6 +53,7 @@ public void testValidation_ReturnsException_When_TaskType_IsNot_Completion() {
5053
"inference_id",
5154
TaskType.SPARSE_EMBEDDING,
5255
UnifiedCompletionRequest.of(List.of(UnifiedCompletionRequestTests.randomMessage())),
56+
InferenceContext.empty(),
5357
TimeValue.timeValueSeconds(10)
5458
);
5559
var exception = request.validate();
@@ -61,6 +65,7 @@ public void testValidation_ReturnsNull_When_TaskType_IsAny() {
6165
"inference_id",
6266
TaskType.ANY,
6367
UnifiedCompletionRequest.of(List.of(UnifiedCompletionRequestTests.randomMessage())),
68+
InferenceContext.empty(),
6469
TimeValue.timeValueSeconds(10)
6570
);
6671
assertNull(request.validate());
@@ -71,6 +76,7 @@ public void testWriteTo_WhenVersionIsBeforeAdaptiveRateLimiting_ShouldSetHasBeen
7176
"model",
7277
TaskType.ANY,
7378
UnifiedCompletionRequest.of(List.of(UnifiedCompletionRequestTests.randomMessage())),
79+
InferenceContext.empty(),
7480
TimeValue.timeValueSeconds(10)
7581
);
7682

@@ -101,6 +107,7 @@ protected UnifiedCompletionAction.Request createTestInstance() {
101107
randomAlphaOfLength(10),
102108
randomFrom(TaskType.values()),
103109
UnifiedCompletionRequestTests.randomUnifiedCompletionRequest(),
110+
InferenceContext.empty(),
104111
TimeValue.timeValueMillis(randomLongBetween(1, 2048))
105112
);
106113
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
import java.io.IOException;
5050
import java.util.HashMap;
5151
import java.util.Map;
52+
import java.util.Objects;
5253
import java.util.Random;
5354
import java.util.concurrent.Executor;
5455
import java.util.concurrent.Flow;
@@ -147,8 +148,11 @@ protected void doExecute(Task task, Request request, ActionListener<InferenceAct
147148
}
148149

149150
// TODO: test
150-
threadPool.getThreadContext()
151-
.putHeader(InferencePlugin.X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER, request.getContext().productUseCase());
151+
var context = request.getContext();
152+
if(Objects.nonNull(context)){
153+
threadPool.getThreadContext()
154+
.putHeader(InferencePlugin.X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER, context.productUseCase());
155+
}
152156

153157
var service = serviceRegistry.getService(serviceName).get();
154158
var localNodeId = nodeClient.getLocalNodeId();

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionProxy.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import org.elasticsearch.action.ActionListener;
1212
import org.elasticsearch.action.support.ActionFilters;
1313
import org.elasticsearch.action.support.HandledTransportAction;
14-
import org.elasticsearch.client.internal.OriginSettingClient;
14+
import org.elasticsearch.client.internal.Client;
1515
import org.elasticsearch.common.util.concurrent.EsExecutors;
1616
import org.elasticsearch.common.xcontent.XContentHelper;
1717
import org.elasticsearch.inference.TaskType;
@@ -35,14 +35,14 @@
3535
// TODO: test
3636
public class TransportInferenceActionProxy extends HandledTransportAction<InferenceActionProxy.Request, InferenceAction.Response> {
3737
private final ModelRegistry modelRegistry;
38-
private final OriginSettingClient client;
38+
private final Client client;
3939

4040
@Inject
4141
public TransportInferenceActionProxy(
4242
TransportService transportService,
4343
ActionFilters actionFilters,
4444
ModelRegistry modelRegistry,
45-
OriginSettingClient client
45+
Client client
4646
) {
4747
super(
4848
InferenceActionProxy.NAME,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceRequest.java

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
1414
import org.elasticsearch.xpack.inference.external.request.Request;
1515

16+
import java.util.Objects;
17+
1618
import static org.elasticsearch.xpack.inference.InferencePlugin.X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER;
1719

1820
public abstract class ElasticInferenceServiceRequest implements Request {
@@ -31,8 +33,18 @@ public ElasticInferenceServiceRequestMetadata getMetadata() {
3133
public final HttpRequest createHttpRequest() {
3234
HttpRequestBase request = createHttpRequestBase();
3335
// TODO: consider moving tracing here, too
34-
request.setHeader(Task.X_ELASTIC_PRODUCT_ORIGIN_HTTP_HEADER, metadata.productOrigin());
35-
request.setHeader(X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER, metadata.productUseCase());
36+
37+
var productOrigin = metadata.productOrigin();
38+
var productUseCase = metadata.productUseCase();
39+
40+
if(Objects.nonNull(productOrigin) && productOrigin.isEmpty() == false){
41+
request.setHeader(Task.X_ELASTIC_PRODUCT_ORIGIN_HTTP_HEADER, metadata.productOrigin());
42+
}
43+
44+
if(Objects.nonNull(productUseCase) && productUseCase.isEmpty() == false){
45+
request.setHeader(X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER, metadata.productUseCase());
46+
}
47+
3648
return new HttpRequest(request, getInferenceEntityId());
3749
}
3850

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/BaseInferenceAction.java

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import org.elasticsearch.xpack.inference.InferencePlugin;
2121

2222
import java.io.IOException;
23+
import java.util.Objects;
2324

2425
import static org.elasticsearch.xpack.inference.rest.Paths.INFERENCE_ID;
2526
import static org.elasticsearch.xpack.inference.rest.Paths.TASK_TYPE_OR_INFERENCE_ID;
@@ -70,11 +71,17 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient
7071
private String extractProductUseCase(RestRequest restRequest) {
7172
var headers = restRequest.getHeaders();
7273

73-
if (headers.isEmpty()) {
74+
if (Objects.isNull(headers) || headers.isEmpty()) {
75+
return "";
76+
}
77+
78+
var productUseCaseHeaders = headers.get(InferencePlugin.X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER);
79+
80+
if(Objects.isNull(productUseCaseHeaders) || productUseCaseHeaders.isEmpty()){
7481
return "";
7582
}
7683

7784
// We always get the first value as the header doesn't allow multiple values
78-
return headers.get(InferencePlugin.X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER).getFirst();
85+
return productUseCaseHeaders.getFirst();
7986
}
8087
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionProxyTests.java

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import org.elasticsearch.action.ActionListener;
1111
import org.elasticsearch.action.support.ActionFilters;
1212
import org.elasticsearch.action.support.PlainActionFuture;
13-
import org.elasticsearch.client.internal.OriginSettingClient;
13+
import org.elasticsearch.client.internal.Client;
1414
import org.elasticsearch.common.bytes.BytesArray;
1515
import org.elasticsearch.core.TimeValue;
1616
import org.elasticsearch.inference.TaskType;
@@ -21,6 +21,7 @@
2121
import org.elasticsearch.threadpool.ThreadPool;
2222
import org.elasticsearch.transport.TransportService;
2323
import org.elasticsearch.xcontent.XContentType;
24+
import org.elasticsearch.xpack.core.inference.InferenceContext;
2425
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
2526
import org.elasticsearch.xpack.core.inference.action.InferenceActionProxy;
2627
import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction;
@@ -39,15 +40,15 @@
3940
import static org.mockito.Mockito.when;
4041

4142
public class TransportInferenceActionProxyTests extends ESTestCase {
42-
private OriginSettingClient client;
43+
private Client client;
4344
private ThreadPool threadPool;
4445
private TransportInferenceActionProxy action;
4546
private ModelRegistry modelRegistry;
4647

4748
@Before
4849
public void setUp() throws Exception {
4950
super.setUp();
50-
client = mock(OriginSettingClient.class);
51+
client = mock(Client.class);
5152
threadPool = new TestThreadPool("test");
5253
when(client.threadPool()).thenReturn(threadPool);
5354
modelRegistry = mock(ModelRegistry.class);
@@ -87,7 +88,8 @@ public void testExecutesAUnifiedCompletionRequest_WhenTaskTypeIsChatCompletion_I
8788
new BytesArray(requestJson),
8889
XContentType.JSON,
8990
TimeValue.ONE_MINUTE,
90-
true
91+
true,
92+
InferenceContext.empty()
9193
);
9294

9395
action.doExecute(mock(Task.class), request, listener);
@@ -129,7 +131,8 @@ public void testExecutesAUnifiedCompletionRequest_WhenTaskTypeIsChatCompletion_F
129131
new BytesArray(requestJson),
130132
XContentType.JSON,
131133
TimeValue.ONE_MINUTE,
132-
true
134+
true,
135+
InferenceContext.empty()
133136
);
134137

135138
action.doExecute(mock(Task.class), request, listener);
@@ -152,7 +155,8 @@ public void testExecutesAnInferenceAction_WhenTaskTypeIsCompletion_InRequest() {
152155
new BytesArray(requestJson),
153156
XContentType.JSON,
154157
TimeValue.ONE_MINUTE,
155-
true
158+
true,
159+
InferenceContext.empty()
156160
);
157161

158162
action.doExecute(mock(Task.class), request, listener);
@@ -181,7 +185,8 @@ public void testExecutesAnInferenceAction_WhenTaskTypeIsCompletion_FromStorage()
181185
new BytesArray(requestJson),
182186
XContentType.JSON,
183187
TimeValue.ONE_MINUTE,
184-
true
188+
true,
189+
InferenceContext.empty()
185190
);
186191

187192
action.doExecute(mock(Task.class), request, listener);

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceRequestTests.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,8 @@ public String getInferenceEntityId() {
7979

8080
public static ElasticInferenceServiceRequestMetadata randomElasticInferenceServiceRequestMetadata() {
8181
return new ElasticInferenceServiceRequestMetadata(
82-
randomFrom(null, randomAlphaOfLength(10)),
83-
randomFrom(null, randomAlphaOfLength(10))
82+
randomFrom(new String[]{null, randomAlphaOfLength(10)}),
83+
randomFrom(new String[]{null, randomAlphaOfLength(10)})
8484
);
8585
}
8686
}

0 commit comments

Comments
 (0)