Skip to content

Commit 93cc995

Browse files
committed
Address comments
1 parent 0feba86 commit 93cc995

File tree

9 files changed

+63
-77
lines changed

9 files changed

+63
-77
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceSparseEmbeddingsRequest.java

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,22 +18,19 @@
1818
import org.elasticsearch.xpack.inference.external.request.Request;
1919
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsModel;
2020
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
21-
import org.elasticsearch.xpack.inference.telemetry.TraceContextAware;
21+
import org.elasticsearch.xpack.inference.telemetry.TraceContextHandler;
2222

2323
import java.net.URI;
2424
import java.nio.charset.StandardCharsets;
2525
import java.util.Objects;
2626

27-
public class ElasticInferenceServiceSparseEmbeddingsRequest implements ElasticInferenceServiceRequest, TraceContextAware {
27+
public class ElasticInferenceServiceSparseEmbeddingsRequest implements ElasticInferenceServiceRequest {
2828

2929
private final URI uri;
30-
3130
private final ElasticInferenceServiceSparseEmbeddingsModel model;
32-
3331
private final Truncator.TruncationResult truncationResult;
3432
private final Truncator truncator;
35-
36-
private final TraceContext traceContext;
33+
private final TraceContextHandler traceContextHandler;
3734

3835
public ElasticInferenceServiceSparseEmbeddingsRequest(
3936
Truncator truncator,
@@ -45,7 +42,7 @@ public ElasticInferenceServiceSparseEmbeddingsRequest(
4542
this.truncationResult = truncationResult;
4643
this.model = Objects.requireNonNull(model);
4744
this.uri = model.uri();
48-
this.traceContext = traceContext;
45+
this.traceContextHandler = new TraceContextHandler(traceContext);
4946
}
5047

5148
@Override
@@ -56,13 +53,16 @@ public HttpRequest createHttpRequest() {
5653
ByteArrayEntity byteEntity = new ByteArrayEntity(requestEntity.getBytes(StandardCharsets.UTF_8));
5754
httpPost.setEntity(byteEntity);
5855

59-
propagateTraceContext(httpPost);
60-
56+
traceContextHandler.propagateTraceContext(httpPost);
6157
httpPost.setHeader(new BasicHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()));
6258

6359
return new HttpRequest(httpPost, getInferenceEntityId());
6460
}
6561

62+
public TraceContext getTraceContext() {
63+
return traceContextHandler.traceContext();
64+
}
65+
6666
@Override
6767
public String getInferenceEntityId() {
6868
return model.getInferenceEntityId();
@@ -73,20 +73,15 @@ public URI getURI() {
7373
return this.uri;
7474
}
7575

76-
@Override
77-
public TraceContext getTraceContext() {
78-
return traceContext;
79-
}
80-
8176
@Override
8277
public Request truncate() {
8378
var truncatedInput = truncator.truncate(truncationResult.input());
84-
85-
return new ElasticInferenceServiceSparseEmbeddingsRequest(truncator, truncatedInput, model, traceContext);
79+
return new ElasticInferenceServiceSparseEmbeddingsRequest(truncator, truncatedInput, model, traceContextHandler.traceContext());
8680
}
8781

8882
@Override
8983
public boolean[] getTruncationInfo() {
9084
return truncationResult.truncated().clone();
9185
}
86+
9287
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceUnifiedChatCompletionRequest.java

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,17 @@
1818
import org.elasticsearch.xpack.inference.external.request.Request;
1919
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel;
2020
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
21-
import org.elasticsearch.xpack.inference.telemetry.TraceContextAware;
21+
import org.elasticsearch.xpack.inference.telemetry.TraceContextHandler;
2222

2323
import java.net.URI;
2424
import java.nio.charset.StandardCharsets;
2525
import java.util.Objects;
2626

27-
public class ElasticInferenceServiceUnifiedChatCompletionRequest implements TraceContextAware, Request {
27+
public class ElasticInferenceServiceUnifiedChatCompletionRequest implements Request {
2828

2929
private final ElasticInferenceServiceCompletionModel model;
3030
private final UnifiedChatInput unifiedChatInput;
31-
private final URI uri;
32-
private final TraceContext traceContext;
31+
private final TraceContextHandler traceContextHandler;
3332

3433
public ElasticInferenceServiceUnifiedChatCompletionRequest(
3534
UnifiedChatInput unifiedChatInput,
@@ -38,33 +37,28 @@ public ElasticInferenceServiceUnifiedChatCompletionRequest(
3837
) {
3938
this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput);
4039
this.model = Objects.requireNonNull(model);
41-
this.uri = model.uri();
42-
this.traceContext = traceContext;
43-
40+
this.traceContextHandler = new TraceContextHandler(traceContext);
4441
}
4542

4643
@Override
4744
public HttpRequest createHttpRequest() {
48-
var httpPost = new HttpPost(uri);
45+
var httpPost = new HttpPost(model.uri());
4946
var requestEntity = Strings.toString(
5047
new ElasticInferenceServiceUnifiedChatCompletionRequestEntity(unifiedChatInput, model.getServiceSettings().modelId())
5148
);
5249

5350
ByteArrayEntity byteEntity = new ByteArrayEntity(requestEntity.getBytes(StandardCharsets.UTF_8));
5451
httpPost.setEntity(byteEntity);
5552

56-
if (traceContext != null) {
57-
propagateTraceContext(httpPost);
58-
}
59-
53+
traceContextHandler.propagateTraceContext(httpPost);
6054
httpPost.setHeader(new BasicHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()));
6155

6256
return new HttpRequest(httpPost, getInferenceEntityId());
6357
}
6458

6559
@Override
6660
public URI getURI() {
67-
return uri;
61+
return model.uri();
6862
}
6963

7064
@Override
@@ -88,9 +82,4 @@ public String getInferenceEntityId() {
8882
public boolean isStreaming() {
8983
return true;
9084
}
91-
92-
@Override
93-
public TraceContext getTraceContext() {
94-
return traceContext;
95-
}
9685
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/unified/UnifiedChatCompletionRequestEntity.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
111111
builder.field(MAX_COMPLETION_TOKENS_FIELD, unifiedRequest.maxCompletionTokens());
112112
}
113113

114-
// Underlying providers except OpenAI only return 1 possible choice.
114+
// Underlying providers expect OpenAI to only return 1 possible choice.
115115
builder.field(NUMBER_OF_RETURNED_CHOICES_FIELD, 1);
116116

117117
if (unifiedRequest.stop() != null && unifiedRequest.stop().isEmpty() == false) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsModel.java

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
import org.elasticsearch.xpack.inference.external.action.elastic.ElasticInferenceServiceActionVisitor;
2020
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
2121
import org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels;
22+
import org.elasticsearch.ElasticsearchStatusException;
23+
import org.elasticsearch.rest.RestStatus;
2224

2325
import java.net.URI;
2426
import java.net.URISyntaxException;
@@ -57,12 +59,7 @@ public ElasticInferenceServiceSparseEmbeddingsModel(
5759
ElasticInferenceServiceSparseEmbeddingsServiceSettings serviceSettings
5860
) {
5961
super(model, serviceSettings);
60-
61-
try {
62-
this.uri = createUri();
63-
} catch (URISyntaxException e) {
64-
throw new RuntimeException(e);
65-
}
62+
this.uri = createUri();
6663
}
6764

6865
ElasticInferenceServiceSparseEmbeddingsModel(
@@ -80,12 +77,7 @@ public ElasticInferenceServiceSparseEmbeddingsModel(
8077
serviceSettings,
8178
elasticInferenceServiceComponents
8279
);
83-
84-
try {
85-
this.uri = createUri();
86-
} catch (URISyntaxException e) {
87-
throw new RuntimeException(e);
88-
}
80+
this.uri = createUri();
8981
}
9082

9183
@Override
@@ -102,19 +94,29 @@ public URI uri() {
10294
return uri;
10395
}
10496

105-
private URI createUri() throws URISyntaxException {
97+
private URI createUri() throws ElasticsearchStatusException {
10698
String modelId = getServiceSettings().modelId();
10799
String modelIdUriPath;
108100

109101
switch (modelId) {
110102
case ElserModels.ELSER_V2_MODEL -> modelIdUriPath = "ELSERv2";
111-
default -> throw new IllegalArgumentException(
112-
String.format(Locale.ROOT, "Unsupported model for %s [%s]", ELASTIC_INFERENCE_SERVICE_IDENTIFIER, modelId)
103+
default -> throw new ElasticsearchStatusException(
104+
String.format(Locale.ROOT, "Unsupported model for %s [%s]", ELASTIC_INFERENCE_SERVICE_IDENTIFIER, modelId),
105+
RestStatus.BAD_REQUEST
113106
);
114107
}
115108

116-
return new URI(
117-
elasticInferenceServiceComponents().elasticInferenceServiceUrl() + "/api/v1/sparse-text-embeddings/" + modelIdUriPath
118-
);
109+
try {
110+
// TODO, consider transforming the base URL into a URI for better error handling.
111+
return new URI(
112+
elasticInferenceServiceComponents().elasticInferenceServiceUrl() + "/api/v1/sparse-text-embeddings/" + modelIdUriPath
113+
);
114+
} catch (URISyntaxException e) {
115+
throw new ElasticsearchStatusException(
116+
"Failed to create URI for sparse embeddings service: " + e.getMessage(),
117+
RestStatus.BAD_REQUEST,
118+
e
119+
);
120+
}
119121
}
120122
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModel.java

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
package org.elasticsearch.xpack.inference.services.elastic.completion;
99

10+
import org.elasticsearch.ElasticsearchStatusException;
1011
import org.elasticsearch.core.Nullable;
1112
import org.elasticsearch.inference.EmptySecretSettings;
1213
import org.elasticsearch.inference.EmptyTaskSettings;
@@ -16,6 +17,7 @@
1617
import org.elasticsearch.inference.TaskSettings;
1718
import org.elasticsearch.inference.TaskType;
1819
import org.elasticsearch.inference.UnifiedCompletionRequest;
20+
import org.elasticsearch.rest.RestStatus;
1921
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
2022
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents;
2123
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceModel;
@@ -68,12 +70,8 @@ public ElasticInferenceServiceCompletionModel(
6870
ElasticInferenceServiceCompletionServiceSettings serviceSettings
6971
) {
7072
super(model, serviceSettings);
73+
this.uri = createUri();
7174

72-
try {
73-
this.uri = createUri();
74-
} catch (URISyntaxException e) {
75-
throw new RuntimeException(e);
76-
}
7775
}
7876

7977
ElasticInferenceServiceCompletionModel(
@@ -92,11 +90,8 @@ public ElasticInferenceServiceCompletionModel(
9290
elasticInferenceServiceComponents
9391
);
9492

95-
try {
96-
this.uri = createUri();
97-
} catch (URISyntaxException e) {
98-
throw new RuntimeException(e);
99-
}
93+
this.uri = createUri();
94+
10095
}
10196

10297
@Override
@@ -108,9 +103,18 @@ public URI uri() {
108103
return uri;
109104
}
110105

111-
private URI createUri() throws URISyntaxException {
112-
return new URI(elasticInferenceServiceComponents().elasticInferenceServiceUrl() + "/api/v1/chat/completions");
106+
private URI createUri() throws ElasticsearchStatusException {
107+
try {
108+
// TODO, consider transforming the base URL into a URI for better error handling.
109+
return new URI(elasticInferenceServiceComponents().elasticInferenceServiceUrl() + "/api/v1/chat/completions");
110+
} catch (URISyntaxException e) {
111+
throw new ElasticsearchStatusException(
112+
"Failed to create URI for completion service: " + e.getMessage(),
113+
RestStatus.BAD_REQUEST,
114+
e
115+
);
116+
}
113117
}
114118

115-
// TODO create the Configuration class?
119+
// TODO create/refactor the Configuration class to be extensible for different task types (i.e completion, sparse embeddings).
116120
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionServiceSettings.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ public class ElasticInferenceServiceCompletionServiceSettings extends FilteredXC
3636
public static final String NAME = "elastic_inference_service_completion_service_settings";
3737

3838
// TODO what value do we put here?
39-
private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(1_000);
39+
private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(240L);
4040

4141
public static ElasticInferenceServiceCompletionServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
4242
ValidationException validationException = new ValidationException();
Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,15 @@
44
* 2.0; you may not use this file except in compliance with the Elastic License
55
* 2.0.
66
*/
7+
78
package org.elasticsearch.xpack.inference.telemetry;
89

910
import org.apache.http.client.methods.HttpPost;
1011
import org.elasticsearch.tasks.Task;
1112

12-
public interface TraceContextAware {
13-
TraceContext getTraceContext();
13+
public record TraceContextHandler(TraceContext traceContext) {
1414

15-
default void propagateTraceContext(HttpPost httpPost) {
16-
TraceContext traceContext = this.getTraceContext();
15+
public void propagateTraceContext(HttpPost httpPost) {
1716
if (traceContext == null) {
1817
return;
1918
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceUnifiedChatCompletionRequestEntityTests.java

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,7 @@
2626
public class ElasticInferenceServiceUnifiedChatCompletionRequestEntityTests extends ESTestCase {
2727

2828
private static final String ROLE = "user";
29-
private static final String USER = "a_user";
3029

31-
// TODO remove if EIS doesn't use the model and user fields
3230
public void testModelUserFieldsSerialization() throws IOException {
3331
UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message(
3432
new UnifiedCompletionRequest.ContentString("Hello, world!"),
@@ -43,7 +41,7 @@ public void testModelUserFieldsSerialization() throws IOException {
4341
var unifiedRequest = UnifiedCompletionRequest.of(messageList);
4442

4543
UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true);
46-
OpenAiChatCompletionModel model = createChatCompletionModel("test-url", "organizationId", "api-key", "test-endpoint", USER);
44+
OpenAiChatCompletionModel model = createChatCompletionModel("test-url", "organizationId", "api-key", "test-endpoint", null);
4745

4846
OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model);
4947

@@ -64,8 +62,7 @@ public void testModelUserFieldsSerialization() throws IOException {
6462
"stream": true,
6563
"stream_options": {
6664
"include_usage": true
67-
},
68-
"user": "a_user"
65+
}
6966
}
7067
""";
7168
assertJsonEquals(jsonString, expectedJson);

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/EISCompletionServiceSettingsTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ public void testFromMap() {
5252
ConfigurationParseContext.REQUEST
5353
);
5454

55-
assertThat(serviceSettings, is(new ElasticInferenceServiceCompletionServiceSettings(modelId, new RateLimitSettings(1000))));
55+
assertThat(serviceSettings, is(new ElasticInferenceServiceCompletionServiceSettings(modelId, new RateLimitSettings(240L))));
5656
}
5757

5858
public void testFromMap_MissingModelId_ThrowsException() {

0 commit comments

Comments
 (0)