
Commit 467747f

Adding model
1 parent ccec39b commit 467747f

4 files changed, +62 -37 lines changed


x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequest.java

Lines changed: 9 additions & 1 deletion
@@ -44,7 +44,15 @@ public HttpRequest createHttpRequest() {
         HttpPost httpPost = new HttpPost(account.uri());

         ByteArrayEntity byteEntity = new ByteArrayEntity(
-            Strings.toString(new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model)).getBytes(StandardCharsets.UTF_8)
+            Strings.toString(
+                new OpenAiUnifiedChatCompletionRequestEntity(
+                    unifiedChatInput,
+                    new OpenAiUnifiedChatCompletionRequestEntity.ModelFields(
+                        model.getServiceSettings().modelId(),
+                        model.getTaskSettings().user()
+                    )
+                )
+            ).getBytes(StandardCharsets.UTF_8)
         );
         httpPost.setEntity(byteEntity);


x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java

Lines changed: 9 additions & 7 deletions
@@ -8,11 +8,11 @@
 package org.elasticsearch.xpack.inference.external.request.openai;

 import org.elasticsearch.common.Strings;
+import org.elasticsearch.core.Nullable;
 import org.elasticsearch.inference.UnifiedCompletionRequest;
 import org.elasticsearch.xcontent.ToXContentObject;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
-import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModel;

 import java.io.IOException;
 import java.util.Objects;

@@ -48,16 +48,18 @@ public class OpenAiUnifiedChatCompletionRequestEntity implements ToXContentObject {

     private final UnifiedCompletionRequest unifiedRequest;
     private final boolean stream;
-    private final OpenAiChatCompletionModel model;
+    private final ModelFields modelFields;

-    public OpenAiUnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, OpenAiChatCompletionModel model) {
+    public OpenAiUnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, ModelFields modelFields) {
         Objects.requireNonNull(unifiedChatInput);

         this.unifiedRequest = unifiedChatInput.getRequest();
         this.stream = unifiedChatInput.stream();
-        this.model = Objects.requireNonNull(model);
+        this.modelFields = Objects.requireNonNull(modelFields);
     }

+    public record ModelFields(String modelId, @Nullable String user) {}
+
     @Override
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
         builder.startObject();

@@ -111,7 +113,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
         }
         builder.endArray();

-        builder.field(MODEL_FIELD, model.getServiceSettings().modelId());
+        builder.field(MODEL_FIELD, modelFields.modelId());
         if (unifiedRequest.maxCompletionTokens() != null) {
             builder.field(MAX_COMPLETION_TOKENS_FIELD, unifiedRequest.maxCompletionTokens());
         }

@@ -168,8 +170,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
             builder.field(TOP_P_FIELD, unifiedRequest.topP());
         }

-        if (Strings.isNullOrEmpty(model.getTaskSettings().user()) == false) {
-            builder.field(USER_FIELD, model.getTaskSettings().user());
+        if (Strings.isNullOrEmpty(modelFields.user()) == false) {
+            builder.field(USER_FIELD, modelFields.user());
         }

         builder.field(STREAM_FIELD, stream);
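
For context, a minimal usage sketch of the new ModelFields record (not part of the commit): the entity no longer needs an OpenAiChatCompletionModel, only a model id and an optional user. The unifiedChatInput variable and the "gpt-4o" model id below are illustrative placeholders.

    // Hypothetical sketch: build the request body through the new record-based constructor.
    // ModelFields carries just the two values the entity serializes from the model.
    var fields = new OpenAiUnifiedChatCompletionRequestEntity.ModelFields(
        "gpt-4o",   // placeholder model id, written to the request's model field
        null        // user is @Nullable and is omitted from the body when null or empty
    );
    var entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, fields);
    String body = Strings.toString(entity);   // JSON payload used as the HTTP entity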

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java

Lines changed: 27 additions & 2 deletions
@@ -34,31 +34,35 @@
 import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults;
 import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
 import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
+import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
 import org.elasticsearch.xpack.inference.external.action.elastic.ElasticInferenceServiceActionCreator;
 import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
 import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
+import org.elasticsearch.xpack.inference.external.http.sender.OpenAiUnifiedCompletionRequestManager;
 import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
 import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.SenderService;
 import org.elasticsearch.xpack.inference.services.ServiceComponents;
+import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
 import org.elasticsearch.xpack.inference.telemetry.TraceContext;

 import java.util.EnumSet;
 import java.util.HashMap;
 import java.util.List;
+import java.util.Locale;
 import java.util.Map;

 import static org.elasticsearch.xpack.core.inference.results.ResultUtils.createInvalidChunkedResultException;
+import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
 import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS;
 import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
-import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation;

 public class ElasticInferenceService extends SenderService {

@@ -85,7 +89,28 @@ protected void doUnifiedCompletionInfer(
         TimeValue timeout,
         ActionListener<InferenceServiceResults> listener
     ) {
-        throwUnsupportedUnifiedCompletionOperation(NAME);
+        if (model instanceof ElasticInferenceServiceCompletionModel == false) {
+            listener.onFailure(createInvalidModelException(model));
+            return;
+        }
+
+        // We extract the trace context here as it's sufficient to propagate the trace information of the REST request,
+        // which handles the request to the inference API overall (including the outgoing request, which is started in a new thread
+        // generating a different "traceparent" as every task and every REST request creates a new span).
+        var currentTraceInfo = getCurrentTraceInfo();
+
+        var completionModel = (ElasticInferenceServiceCompletionModel) model;
+        var overriddenModel = ElasticInferenceServiceCompletionModel.of(completionModel, inputs.getRequest());
+        var errorMessage = constructFailedToSendRequestMessage(
+            overriddenModel.uri(),
+            String.format(Locale.ROOT, "%s completions", ELASTIC_INFERENCE_SERVICE_IDENTIFIER)
+        );
+
+        // TODO add the request manager that takes a trace context
+        var requestCreator = OpenAiUnifiedCompletionRequestManager.of(overriddenModel, getServiceComponents().threadPool());
+        var action = new SenderExecutableAction(getSender(), requestCreator, errorMessage);
+
+        action.execute(inputs, timeout, listener);
     }

     @Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModel.java

Lines changed: 17 additions & 27 deletions
@@ -16,39 +16,30 @@
 import org.elasticsearch.inference.TaskSettings;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.inference.UnifiedCompletionRequest;
+import org.elasticsearch.xpack.inference.external.request.openai.OpenAiUnifiedChatCompletionRequest;
 import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents;
 import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceModel;
 import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsServiceSettings;
-import org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels;
-import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModel;
-import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionServiceSettings;

 import java.net.URI;
 import java.net.URISyntaxException;
-import java.util.Locale;
 import java.util.Map;
 import java.util.Objects;

-import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.ELASTIC_INFERENCE_SERVICE_IDENTIFIER;
-
 public class ElasticInferenceServiceCompletionModel extends ElasticInferenceServiceModel {

-    public static ElasticInferenceServiceCompletionModel of(ElasticInferenceServiceCompletionModel model, UnifiedCompletionRequest request) {
+    public static ElasticInferenceServiceCompletionModel of(
+        ElasticInferenceServiceCompletionModel model,
+        UnifiedCompletionRequest request
+    ) {
         var originalModelServiceSettings = model.getServiceSettings();
         var overriddenServiceSettings = new ElasticInferenceServiceCompletionServiceSettings(
             Objects.requireNonNullElse(request.model(), originalModelServiceSettings.modelId()),
             originalModelServiceSettings.rateLimitSettings()
         );

-        return new ElasticInferenceServiceCompletionModel(
-            model.getInferenceEntityId(),
-            model.getTaskType(),
-            model.getConfigurations().getService(),
-            overriddenServiceSettings,
-            model.getTaskSettings(),
-            model.getSecretSettings()
-        );
+        return new ElasticInferenceServiceCompletionModel(model, overriddenServiceSettings);
     }

     private final URI uri;

@@ -76,7 +67,7 @@ public ElasticInferenceServiceCompletionModel(

     public ElasticInferenceServiceCompletionModel(
         ElasticInferenceServiceCompletionModel model,
-        ElasticInferenceServiceSparseEmbeddingsServiceSettings serviceSettings
+        ElasticInferenceServiceCompletionServiceSettings serviceSettings
     ) {
         super(model, serviceSettings);

@@ -121,18 +112,17 @@ public URI uri() {

     private URI createUri() throws URISyntaxException {
         String modelId = getServiceSettings().modelId();
-        // String modelIdUriPath;
-        //
-        // switch (modelId) {
-        // case ElserModels.ELSER_V2_MODEL -> modelIdUriPath = "ELSERv2";
-        // default -> throw new IllegalArgumentException(
-        // String.format(Locale.ROOT, "Unsupported model for %s [%s]", ELASTIC_INFERENCE_SERVICE_IDENTIFIER, modelId)
-        // );
-        // }
+        // String modelIdUriPath;
+        //
+        // switch (modelId) {
+        // case ElserModels.ELSER_V2_MODEL -> modelIdUriPath = "ELSERv2";
+        // default -> throw new IllegalArgumentException(
+        // String.format(Locale.ROOT, "Unsupported model for %s [%s]", ELASTIC_INFERENCE_SERVICE_IDENTIFIER, modelId)
+        // );
+        // }

         // TODO what is the url?
-        // return new URI(elasticInferenceServiceComponents().elasticInferenceServiceUrl() + "/api/v1/completion/" + modelIdUriPath);
-
-        return
+        // return new URI(elasticInferenceServiceComponents().elasticInferenceServiceUrl() + "/api/v1/completion/" + modelId);
+        return OpenAiUnifiedChatCompletionRequest.buildDefaultUri();
     }
 }
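
As a usage sketch (assumed, not taken from the commit), the reworked of(...) factory keeps the configured model id unless the unified completion request carries its own; the configuredModel and request variables below are assumed to already exist.

    // Hypothetical example: request.model() overrides the configured model id when present.
    var overridden = ElasticInferenceServiceCompletionModel.of(configuredModel, request);
    String expectedModelId = request.model() != null
        ? request.model()
        : configuredModel.getServiceSettings().modelId();
    assert overridden.getServiceSettings().modelId().equals(expectedModelId);

Note that createUri() still falls back to OpenAiUnifiedChatCompletionRequest.buildDefaultUri() while the Elastic Inference Service completion URL remains a TODO.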
