Skip to content

Commit 1fead0a

Browse files
Starting completion model
1 parent 4c2573e commit 1fead0a

File tree

7 files changed

+316
-9
lines changed

7 files changed

+316
-9
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ protected void doInfer(
9797
TimeValue timeout,
9898
ActionListener<InferenceServiceResults> listener
9999
) {
100-
if (model instanceof ElasticInferenceServiceModel == false) {
100+
if (model instanceof ElasticInferenceServiceExecutableActionModel == false) {
101101
listener.onFailure(createInvalidModelException(model));
102102
return;
103103
}
@@ -107,7 +107,7 @@ protected void doInfer(
107107
// generating a different "traceparent" as every task and every REST request creates a new span).
108108
var currentTraceInfo = getCurrentTraceInfo();
109109

110-
ElasticInferenceServiceModel elasticInferenceServiceModel = (ElasticInferenceServiceModel) model;
110+
ElasticInferenceServiceExecutableActionModel elasticInferenceServiceModel = (ElasticInferenceServiceExecutableActionModel) model;
111111
var actionCreator = new ElasticInferenceServiceActionCreator(getSender(), getServiceComponents(), currentTraceInfo);
112112

113113
var action = elasticInferenceServiceModel.accept(actionCreator, taskSettings);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.elastic;
9+
10+
import org.elasticsearch.inference.ModelConfigurations;
11+
import org.elasticsearch.inference.ModelSecrets;
12+
import org.elasticsearch.inference.ServiceSettings;
13+
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
14+
import org.elasticsearch.xpack.inference.external.action.elastic.ElasticInferenceServiceActionVisitor;
15+
16+
import java.util.Map;
17+
18+
public abstract class ElasticInferenceServiceExecutableActionModel extends ElasticInferenceServiceModel {
19+
20+
public ElasticInferenceServiceExecutableActionModel(
21+
ModelConfigurations configurations,
22+
ModelSecrets secrets,
23+
ElasticInferenceServiceRateLimitServiceSettings rateLimitServiceSettings,
24+
ElasticInferenceServiceComponents elasticInferenceServiceComponents
25+
) {
26+
super(configurations, secrets, rateLimitServiceSettings, elasticInferenceServiceComponents);
27+
}
28+
29+
public ElasticInferenceServiceExecutableActionModel(
30+
ElasticInferenceServiceExecutableActionModel model,
31+
ServiceSettings serviceSettings
32+
) {
33+
super(model, serviceSettings);
34+
}
35+
36+
public abstract ExecutableAction accept(ElasticInferenceServiceActionVisitor visitor, Map<String, Object> taskSettings);
37+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceModel.java

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,7 @@
1111
import org.elasticsearch.inference.ModelConfigurations;
1212
import org.elasticsearch.inference.ModelSecrets;
1313
import org.elasticsearch.inference.ServiceSettings;
14-
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
15-
import org.elasticsearch.xpack.inference.external.action.elastic.ElasticInferenceServiceActionVisitor;
1614

17-
import java.util.Map;
1815
import java.util.Objects;
1916

2017
public abstract class ElasticInferenceServiceModel extends Model {
@@ -49,7 +46,4 @@ public ElasticInferenceServiceRateLimitServiceSettings rateLimitServiceSettings(
4946
public ElasticInferenceServiceComponents elasticInferenceServiceComponents() {
5047
return elasticInferenceServiceComponents;
5148
}
52-
53-
public abstract ExecutableAction accept(ElasticInferenceServiceActionVisitor visitor, Map<String, Object> taskSettings);
54-
5549
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSettings.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@
1212

1313
import java.util.List;
1414

15+
/**
16+
* Encapsulates settings using {@link Setting}. This does not represent service settings that are persisted
17+
* via {@link org.elasticsearch.inference.ServiceSettings}.
18+
*/
1519
public class ElasticInferenceServiceSettings {
1620

1721
@Deprecated

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsModel.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727

2828
import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.ELASTIC_INFERENCE_SERVICE_IDENTIFIER;
2929

30-
public class ElasticInferenceServiceSparseEmbeddingsModel extends ElasticInferenceServiceModel {
30+
public class ElasticInferenceServiceSparseEmbeddingsModel extends ElasticInferenceServiceExecutableActionModel {
3131

3232
private final URI uri;
3333

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.elastic.completion;
9+
10+
import org.elasticsearch.core.Nullable;
11+
import org.elasticsearch.inference.EmptySecretSettings;
12+
import org.elasticsearch.inference.EmptyTaskSettings;
13+
import org.elasticsearch.inference.ModelConfigurations;
14+
import org.elasticsearch.inference.ModelSecrets;
15+
import org.elasticsearch.inference.SecretSettings;
16+
import org.elasticsearch.inference.TaskSettings;
17+
import org.elasticsearch.inference.TaskType;
18+
import org.elasticsearch.inference.UnifiedCompletionRequest;
19+
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
20+
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents;
21+
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceModel;
22+
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsServiceSettings;
23+
import org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels;
24+
import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModel;
25+
import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionServiceSettings;
26+
27+
import java.net.URI;
28+
import java.net.URISyntaxException;
29+
import java.util.Locale;
30+
import java.util.Map;
31+
import java.util.Objects;
32+
33+
import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.ELASTIC_INFERENCE_SERVICE_IDENTIFIER;
34+
35+
public class ElasticInferenceServiceCompletionModel extends ElasticInferenceServiceModel {
36+
37+
public static ElasticInferenceServiceCompletionModel of(ElasticInferenceServiceCompletionModel model, UnifiedCompletionRequest request) {
38+
var originalModelServiceSettings = model.getServiceSettings();
39+
var overriddenServiceSettings = new ElasticInferenceServiceCompletionServiceSettings(
40+
Objects.requireNonNullElse(request.model(), originalModelServiceSettings.modelId()),
41+
originalModelServiceSettings.rateLimitSettings()
42+
);
43+
44+
return new ElasticInferenceServiceCompletionModel(
45+
model.getInferenceEntityId(),
46+
model.getTaskType(),
47+
model.getConfigurations().getService(),
48+
overriddenServiceSettings,
49+
model.getTaskSettings(),
50+
model.getSecretSettings()
51+
);
52+
}
53+
54+
private final URI uri;
55+
56+
public ElasticInferenceServiceCompletionModel(
57+
String inferenceEntityId,
58+
TaskType taskType,
59+
String service,
60+
Map<String, Object> serviceSettings,
61+
Map<String, Object> taskSettings,
62+
Map<String, Object> secrets,
63+
ElasticInferenceServiceComponents elasticInferenceServiceComponents,
64+
ConfigurationParseContext context
65+
) {
66+
this(
67+
inferenceEntityId,
68+
taskType,
69+
service,
70+
ElasticInferenceServiceSparseEmbeddingsServiceSettings.fromMap(serviceSettings, context),
71+
EmptyTaskSettings.INSTANCE,
72+
EmptySecretSettings.INSTANCE,
73+
elasticInferenceServiceComponents
74+
);
75+
}
76+
77+
public ElasticInferenceServiceCompletionModel(
78+
ElasticInferenceServiceCompletionModel model,
79+
ElasticInferenceServiceSparseEmbeddingsServiceSettings serviceSettings
80+
) {
81+
super(model, serviceSettings);
82+
83+
try {
84+
this.uri = createUri();
85+
} catch (URISyntaxException e) {
86+
throw new RuntimeException(e);
87+
}
88+
}
89+
90+
ElasticInferenceServiceCompletionModel(
91+
String inferenceEntityId,
92+
TaskType taskType,
93+
String service,
94+
ElasticInferenceServiceSparseEmbeddingsServiceSettings serviceSettings,
95+
@Nullable TaskSettings taskSettings,
96+
@Nullable SecretSettings secretSettings,
97+
ElasticInferenceServiceComponents elasticInferenceServiceComponents
98+
) {
99+
super(
100+
new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings),
101+
new ModelSecrets(secretSettings),
102+
serviceSettings,
103+
elasticInferenceServiceComponents
104+
);
105+
106+
try {
107+
this.uri = createUri();
108+
} catch (URISyntaxException e) {
109+
throw new RuntimeException(e);
110+
}
111+
}
112+
113+
@Override
114+
public ElasticInferenceServiceSparseEmbeddingsServiceSettings getServiceSettings() {
115+
return (ElasticInferenceServiceSparseEmbeddingsServiceSettings) super.getServiceSettings();
116+
}
117+
118+
public URI uri() {
119+
return uri;
120+
}
121+
122+
private URI createUri() throws URISyntaxException {
123+
String modelId = getServiceSettings().modelId();
124+
// String modelIdUriPath;
125+
//
126+
// switch (modelId) {
127+
// case ElserModels.ELSER_V2_MODEL -> modelIdUriPath = "ELSERv2";
128+
// default -> throw new IllegalArgumentException(
129+
// String.format(Locale.ROOT, "Unsupported model for %s [%s]", ELASTIC_INFERENCE_SERVICE_IDENTIFIER, modelId)
130+
// );
131+
// }
132+
133+
// TODO what is the url?
134+
// return new URI(elasticInferenceServiceComponents().elasticInferenceServiceUrl() + "/api/v1/completion/" + modelIdUriPath);
135+
136+
return
137+
}
138+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.elastic.completion;
9+
10+
import org.elasticsearch.TransportVersion;
11+
import org.elasticsearch.TransportVersions;
12+
import org.elasticsearch.common.ValidationException;
13+
import org.elasticsearch.common.io.stream.StreamInput;
14+
import org.elasticsearch.common.io.stream.StreamOutput;
15+
import org.elasticsearch.inference.ModelConfigurations;
16+
import org.elasticsearch.inference.ServiceSettings;
17+
import org.elasticsearch.xcontent.XContentBuilder;
18+
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
19+
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService;
20+
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceRateLimitServiceSettings;
21+
import org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels;
22+
import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
23+
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
24+
25+
import java.io.IOException;
26+
import java.util.Map;
27+
import java.util.Objects;
28+
29+
import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
30+
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString;
31+
32+
public class ElasticInferenceServiceCompletionServiceSettings extends FilteredXContentObject
33+
implements
34+
ServiceSettings,
35+
ElasticInferenceServiceRateLimitServiceSettings {
36+
37+
public static final String NAME = "elastic_inference_service_completion_service_settings";
38+
39+
// TODO what value do we put here?
40+
private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(1_000);
41+
42+
public static ElasticInferenceServiceCompletionServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
43+
ValidationException validationException = new ValidationException();
44+
45+
String modelId = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);
46+
RateLimitSettings rateLimitSettings = RateLimitSettings.of(
47+
map,
48+
DEFAULT_RATE_LIMIT_SETTINGS,
49+
validationException,
50+
ElasticInferenceService.NAME,
51+
context
52+
);
53+
54+
if (modelId != null && ElserModels.isValidEisModel(modelId) == false) {
55+
validationException.addValidationError("unknown ELSER model id [" + modelId + "]");
56+
}
57+
58+
if (validationException.validationErrors().isEmpty() == false) {
59+
throw validationException;
60+
}
61+
62+
return new ElasticInferenceServiceCompletionServiceSettings(modelId, rateLimitSettings);
63+
}
64+
65+
private final String modelId;
66+
private final RateLimitSettings rateLimitSettings;
67+
68+
public ElasticInferenceServiceCompletionServiceSettings(String modelId, RateLimitSettings rateLimitSettings) {
69+
this.modelId = Objects.requireNonNull(modelId);
70+
this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS);
71+
}
72+
73+
public ElasticInferenceServiceCompletionServiceSettings(StreamInput in) throws IOException {
74+
this.modelId = in.readString();
75+
this.rateLimitSettings = new RateLimitSettings(in);
76+
}
77+
78+
@Override
79+
public String getWriteableName() {
80+
return NAME;
81+
}
82+
83+
public String modelId() {
84+
return modelId;
85+
}
86+
87+
@Override
88+
public RateLimitSettings rateLimitSettings() {
89+
return rateLimitSettings;
90+
}
91+
92+
@Override
93+
public TransportVersion getMinimalSupportedVersion() {
94+
return TransportVersions.V_8_16_0;
95+
}
96+
97+
@Override
98+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
99+
builder.startObject();
100+
101+
toXContentFragmentOfExposedFields(builder, params);
102+
103+
builder.endObject();
104+
105+
return builder;
106+
}
107+
108+
@Override
109+
protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException {
110+
builder.field(MODEL_ID, modelId);
111+
rateLimitSettings.toXContent(builder, params);
112+
113+
return builder;
114+
}
115+
116+
@Override
117+
public void writeTo(StreamOutput out) throws IOException {
118+
out.writeString(modelId);
119+
rateLimitSettings.writeTo(out);
120+
}
121+
122+
@Override
123+
public boolean equals(Object object) {
124+
if (this == object) return true;
125+
if (object == null || getClass() != object.getClass()) return false;
126+
ElasticInferenceServiceCompletionServiceSettings that = (ElasticInferenceServiceCompletionServiceSettings) object;
127+
return Objects.equals(modelId, that.modelId) && Objects.equals(rateLimitSettings, that.rateLimitSettings);
128+
}
129+
130+
@Override
131+
public int hashCode() {
132+
return Objects.hash(modelId, rateLimitSettings);
133+
}
134+
}

0 commit comments

Comments
 (0)