Skip to content

Commit 485232c

Browse files
committed
add integration test
1 parent fd15e68 commit 485232c

File tree

11 files changed

+220
-67
lines changed

11 files changed

+220
-67
lines changed

server/src/main/java/org/elasticsearch/inference/InferenceService.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,10 +129,10 @@ void chunkedInfer(
129129
/**
130130
* Stop the model deployment.
131131
* The default action does nothing except acknowledge the request (true).
132-
* @param modelId The ID of the model to be stopped
132+
* @param unparsedModel The unparsed model configuration
133133
* @param listener The listener
134134
*/
135-
default void stop(String modelId, ActionListener<Boolean> listener) {
135+
default void stop(UnparsedModel unparsedModel, ActionListener<Boolean> listener) {
136136
listener.onResponse(true);
137137
}
138138

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference;
9+
10+
import org.elasticsearch.client.Request;
11+
import org.elasticsearch.client.Response;
12+
import org.elasticsearch.core.Strings;
13+
import org.elasticsearch.inference.TaskType;
14+
15+
import java.io.IOException;
16+
import java.util.List;
17+
import java.util.Map;
18+
19+
import static org.hamcrest.Matchers.is;
20+
21+
public class CreateFromDeploymentIT extends CustomElandModelIT {
22+
23+
@SuppressWarnings("unchecked")
24+
public void testAttachToDeployment() throws IOException {
25+
var modelId = "attach_to_deployment";
26+
var deploymentId = "existing_deployment";
27+
28+
createMlNodeTextExpansionModel(modelId);
29+
var response = startMlNodeDeploymemnt(modelId, deploymentId);
30+
assertOkOrCreated(response);
31+
32+
var inferenceId = "inference_on_existing_deployment";
33+
var putModel = putModel(inferenceId, endpointConfig(deploymentId), TaskType.SPARSE_EMBEDDING);
34+
var serviceSettings = putModel.get("service_settings");
35+
assertThat(
36+
putModel.toString(),
37+
serviceSettings,
38+
is(Map.of("num_allocations", 1, "num_threads", 1, "model_id", "attach_to_deployment", "deployment_id", "existing_deployment"))
39+
);
40+
41+
var results = infer(inferenceId, List.of("washing machine"));
42+
assertNotNull(results.get("sparse_embedding"));
43+
44+
deleteModel(inferenceId);
45+
// assert deployment not stopped
46+
var stats = (List<Map<String, Object>>) getTrainedModelStats(modelId).get("trained_model_stats");
47+
var deploymentStats = stats.get(0).get("deployment_stats");
48+
assertNotNull(stats.toString(), deploymentStats);
49+
50+
stopMlNodeDeployment(deploymentId);
51+
}
52+
53+
public void testAttachWithModelId() throws IOException {
54+
var modelId = "attach_with_model_id";
55+
var deploymentId = "existing_deployment_with_model_id";
56+
57+
createMlNodeTextExpansionModel(modelId);
58+
var response = startMlNodeDeploymemnt(modelId, deploymentId);
59+
assertOkOrCreated(response);
60+
61+
var inferenceId = "inference_on_existing_deployment";
62+
var putModel = putModel(inferenceId, endpointConfig(modelId, deploymentId), TaskType.SPARSE_EMBEDDING);
63+
var serviceSettings = putModel.get("service_settings");
64+
assertThat(
65+
putModel.toString(),
66+
serviceSettings,
67+
is(
68+
Map.of(
69+
"num_allocations",
70+
1,
71+
"num_threads",
72+
1,
73+
"model_id",
74+
"attach_with_model_id",
75+
"deployment_id",
76+
"existing_deployment_with_model_id"
77+
)
78+
)
79+
);
80+
81+
var results = infer(inferenceId, List.of("washing machine"));
82+
assertNotNull(results.get("sparse_embedding"));
83+
84+
stopMlNodeDeployment(deploymentId);
85+
}
86+
87+
private String endpointConfig(String deploymentId) {
88+
return Strings.format("""
89+
{
90+
"service": "elasticsearch",
91+
"service_settings": {
92+
"deployment_id": "%s"
93+
}
94+
}
95+
""", deploymentId);
96+
}
97+
98+
private String endpointConfig(String modelId, String deploymentId) {
99+
return Strings.format("""
100+
{
101+
"service": "elasticsearch",
102+
"service_settings": {
103+
"model_id": "%s",
104+
"deployment_id": "%s"
105+
}
106+
}
107+
""", modelId, deploymentId);
108+
}
109+
110+
private Response startMlNodeDeploymemnt(String modelId, String deploymentId) throws IOException {
111+
String endPoint = "/_ml/trained_models/"
112+
+ modelId
113+
+ "/deployment/_start?timeout=10s&wait_for=started"
114+
+ "&threads_per_allocation=1"
115+
+ "&number_of_allocations=1";
116+
117+
if (deploymentId != null) {
118+
endPoint = endPoint + "&deployment_id=" + deploymentId;
119+
}
120+
121+
Request request = new Request("POST", endPoint);
122+
return client().performRequest(request);
123+
}
124+
125+
protected void stopMlNodeDeployment(String deploymentId) throws IOException {
126+
String endpoint = "/_ml/trained_models/" + deploymentId + "/deployment/_stop";
127+
Request request = new Request("POST", endpoint);
128+
request.addParameter("force", "true");
129+
client().performRequest(request);
130+
}
131+
132+
protected Map<String, Object> getTrainedModelStats(String modelId) throws IOException {
133+
Request request = new Request("GET", "/_ml/trained_models/" + modelId + "/_stats");
134+
return entityAsMap(client().performRequest(request));
135+
}
136+
}

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CustomElandModelIT.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,4 +131,15 @@ protected void putModelDefinition(String modelId, String base64EncodedModel, lon
131131
request.setJsonEntity(body);
132132
client().performRequest(request);
133133
}
134+
135+
// Create the model including definition and vocab
136+
protected void createMlNodeTextExpansionModel(String modelId) throws IOException {
137+
createTextExpansionModel(modelId);
138+
putModelDefinition(modelId, BASE_64_ENCODED_MODEL, RAW_MODEL_SIZE);
139+
putVocabulary(
140+
List.of("these", "are", "my", "words", "the", "washing", "machine", "is", "leaking", "octopus", "comforter", "smells"),
141+
modelId
142+
);
143+
}
144+
134145
}

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ protected void putSemanticText(String endpointId, String searchEndpointId, Strin
207207
}
208208

209209
protected Map<String, Object> putModel(String modelId, String modelConfig, TaskType taskType) throws IOException {
210-
String endpoint = Strings.format("_inference/%s/%s", taskType, modelId);
210+
String endpoint = Strings.format("_inference/%s/%s?error_trace", taskType, modelId);
211211
return putRequest(endpoint, modelConfig);
212212
}
213213

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@
99

1010
package org.elasticsearch.xpack.inference.action;
1111

12-
import org.apache.logging.log4j.LogManager;
13-
import org.apache.logging.log4j.Logger;
1412
import org.elasticsearch.ElasticsearchStatusException;
1513
import org.elasticsearch.action.ActionListener;
1614
import org.elasticsearch.action.ActionRunnable;
@@ -117,7 +115,7 @@ private void doExecuteForked(
117115

118116
var service = serviceRegistry.getService(unparsedModel.service());
119117
if (service.isPresent()) {
120-
service.get().stop(, listener);
118+
service.get().stop(unparsedModel, listener);
121119
} else {
122120
listener.onFailure(
123121
new ElasticsearchStatusException(

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import org.elasticsearch.inference.InputType;
2323
import org.elasticsearch.inference.Model;
2424
import org.elasticsearch.inference.TaskType;
25+
import org.elasticsearch.inference.UnparsedModel;
2526
import org.elasticsearch.xpack.core.ClientHelper;
2627
import org.elasticsearch.xpack.core.ml.MachineLearningField;
2728
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
@@ -119,16 +120,27 @@ public void start(Model model, ActionListener<Boolean> finalListener) {
119120
}
120121

121122
@Override
122-
public void stop(String deploymentId, ActionListener<Boolean> listener) {
123-
// TODO check if other inference endpoints are using this deployment
124-
// // get the model + deployment id and check if configured by deployment id or has a dedicated deployment
125-
var request = new StopTrainedModelDeploymentAction.Request(inferenceEntityId);
126-
request.setForce(true);
127-
client.execute(
128-
StopTrainedModelDeploymentAction.INSTANCE,
129-
request,
130-
listener.delegateFailureAndWrap((delegatedResponseListener, response) -> delegatedResponseListener.onResponse(Boolean.TRUE))
131-
);
123+
public void stop(UnparsedModel unparsedModel, ActionListener<Boolean> listener) {
124+
125+
var model = parsePersistedConfig(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings());
126+
if (model instanceof ElasticsearchInternalModel esModel) {
127+
128+
var serviceSettings = esModel.getServiceSettings();
129+
if (serviceSettings.getDeploymentId() != null) {
130+
// configured with an existing deployment so do not stop it
131+
listener.onResponse(Boolean.TRUE);
132+
}
133+
134+
var request = new StopTrainedModelDeploymentAction.Request(esModel.mlNodeDeploymentId());
135+
request.setForce(true);
136+
client.execute(
137+
StopTrainedModelDeploymentAction.INSTANCE,
138+
request,
139+
listener.delegateFailureAndWrap((delegatedResponseListener, response) -> delegatedResponseListener.onResponse(Boolean.TRUE))
140+
);
141+
} else {
142+
listener.onFailure(notElasticsearchModelException(model));
143+
}
132144
}
133145

134146
protected static IllegalStateException notElasticsearchModelException(Model model) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticDeployedModel.java

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,17 @@ public ElasticDeployedModel(
2020
String inferenceEntityId,
2121
TaskType taskType,
2222
String service,
23-
ElserInternalServiceSettings serviceSettings,
24-
ElserMlNodeTaskSettings taskSettings,
23+
ElasticsearchInternalServiceSettings serviceSettings,
2524
ChunkingSettings chunkingSettings
2625
) {
27-
super(inferenceEntityId, taskType, service, serviceSettings, taskSettings, chunkingSettings);
26+
super(inferenceEntityId, taskType, service, serviceSettings, chunkingSettings);
2827
}
2928

3029
@Override
31-
public ActionListener<CreateTrainedModelAssignmentAction.Response> getCreateTrainedModelAssignmentActionListener(Model model, ActionListener<Boolean> listener) {
30+
public ActionListener<CreateTrainedModelAssignmentAction.Response> getCreateTrainedModelAssignmentActionListener(
31+
Model model,
32+
ActionListener<Boolean> listener
33+
) {
3234
return new ActionListener<>() {
3335
@Override
3436
public void onResponse(CreateTrainedModelAssignmentAction.Response response) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java

Lines changed: 10 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -147,14 +147,7 @@ public void parseRequestConfig(
147147
String deploymentId = (String) serviceSettingsMap.get(ElasticsearchInternalServiceSettings.DEPLOYMENT_ID);
148148
if (deploymentId != null) {
149149
validateAgainstDeployment(modelId, deploymentId, taskType, modelListener.delegateFailureAndWrap((l, settings) -> {
150-
l.onResponse(new ElasticDeployedModel(
151-
inferenceEntityId,
152-
taskType,
153-
NAME,
154-
new ElserInternalServiceSettings(settings.build()),
155-
ElserMlNodeTaskSettings.DEFAULT,
156-
chunkingSettings
157-
));
150+
l.onResponse(new ElasticDeployedModel(inferenceEntityId, taskType, NAME, settings.build(), chunkingSettings));
158151
}));
159152
} else if (modelId == null) {
160153
if (OLD_ELSER_SERVICE_NAME.equals(serviceName)) {
@@ -586,13 +579,7 @@ public void inferSparseEmbedding(
586579
TimeValue timeout,
587580
ActionListener<InferenceServiceResults> listener
588581
) {
589-
var request = buildInferenceRequest(
590-
model.mlNodeDeploymentId(),
591-
TextExpansionConfigUpdate.EMPTY_UPDATE,
592-
inputs,
593-
inputType,
594-
timeout
595-
);
582+
var request = buildInferenceRequest(model.mlNodeDeploymentId(), TextExpansionConfigUpdate.EMPTY_UPDATE, inputs, inputType, timeout);
596583

597584
ActionListener<InferModelAction.Response> mlResultsListener = listener.delegateFailureAndWrap(
598585
(l, inferenceResult) -> l.onResponse(SparseEmbeddingResults.of(inferenceResult.getInferenceResults()))
@@ -614,13 +601,7 @@ public void inferRerank(
614601
Map<String, Object> requestTaskSettings,
615602
ActionListener<InferenceServiceResults> listener
616603
) {
617-
var request = buildInferenceRequest(
618-
model.mlNodeDeploymentId(),
619-
new TextSimilarityConfigUpdate(query),
620-
inputs,
621-
inputType,
622-
timeout
623-
);
604+
var request = buildInferenceRequest(model.mlNodeDeploymentId(), new TextSimilarityConfigUpdate(query), inputs, inputType, timeout);
624605

625606
var modelSettings = (CustomElandRerankTaskSettings) model.getTaskSettings();
626607
var requestSettings = CustomElandRerankTaskSettings.fromMap(requestTaskSettings);
@@ -917,7 +898,6 @@ static EmbeddingRequestChunker.EmbeddingType embeddingTypeFromTaskTypeAndSetting
917898
};
918899
}
919900

920-
921901
private void validateAgainstDeployment(
922902
String modelId,
923903
String deploymentId,
@@ -926,7 +906,7 @@ private void validateAgainstDeployment(
926906
) {
927907
getDeployment(deploymentId, listener.delegateFailureAndWrap((l, response) -> {
928908
if (response.isPresent()) {
929-
if (modelId.equals(response.get().getModelId()) == false) {
909+
if (modelId != null && modelId.equals(response.get().getModelId()) == false) {
930910
listener.onFailure(
931911
new ElasticsearchStatusException(
932912
"Deployment [{}] uses model [{}] which does not match the model [{}] in the request.",
@@ -940,18 +920,16 @@ private void validateAgainstDeployment(
940920
}
941921

942922
var updatedSettings = new ElasticsearchInternalServiceSettings.Builder().setNumAllocations(
943-
response.get().getNumberOfAllocations()
944-
)
923+
response.get().getNumberOfAllocations()
924+
)
945925
.setNumThreads(response.get().getThreadsPerAllocation())
946926
.setAdaptiveAllocationsSettings(response.get().getAdaptiveAllocationsSettings())
947927
.setDeploymentId(deploymentId)
948-
.setModelId(modelId);
928+
.setModelId(response.get().getModelId());
949929

950-
checkTaskTypeForMlNodeModel(
951-
response.get().getModelId(),
952-
taskType,
953-
l.delegateFailureAndWrap((l2, compatibleTaskType) -> { l2.onResponse(updatedSettings); })
954-
);
930+
checkTaskTypeForMlNodeModel(response.get().getModelId(), taskType, l.delegateFailureAndWrap((l2, compatibleTaskType) -> {
931+
l2.onResponse(updatedSettings);
932+
}));
955933
}
956934
}));
957935
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettings.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@ public ElasticsearchInternalServiceSettings(ElasticsearchInternalServiceSettings
153153
this.numThreads = other.numThreads;
154154
this.modelId = other.modelId;
155155
this.adaptiveAllocationsSettings = other.adaptiveAllocationsSettings;
156+
this.deploymentId = other.deploymentId;
156157
}
157158

158159
public ElasticsearchInternalServiceSettings(StreamInput in) throws IOException {

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettingsTests.java

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,18 @@ public class ElasticsearchInternalServiceSettingsTests extends AbstractWireSeria
2222

2323
public static ElasticsearchInternalServiceSettings validInstance(String modelId) {
2424
boolean useAdaptive = randomBoolean();
25+
var deploymentId = randomBoolean() ? null : randomAlphaOfLength(5);
2526
if (useAdaptive) {
2627
var adaptive = new AdaptiveAllocationsSettings(true, 1, randomIntBetween(2, 8));
27-
return new ElasticsearchInternalServiceSettings(randomBoolean() ? 1 : null, randomIntBetween(1, 16), modelId, adaptive);
28+
return new ElasticsearchInternalServiceSettings(
29+
randomBoolean() ? 1 : null,
30+
randomIntBetween(1, 16),
31+
modelId,
32+
adaptive,
33+
deploymentId
34+
);
2835
} else {
29-
return new ElasticsearchInternalServiceSettings(randomIntBetween(1, 10), randomIntBetween(1, 16), modelId, null);
36+
return new ElasticsearchInternalServiceSettings(randomIntBetween(1, 10), randomIntBetween(1, 16), modelId, null, deploymentId);
3037
}
3138
}
3239

0 commit comments

Comments
 (0)