Skip to content

Commit 6600017

Browse files
committed
WIP
1 parent c3d53a8 commit 6600017

File tree

7 files changed

+186
-14
lines changed

7 files changed

+186
-14
lines changed

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,7 @@ static TransportVersion def(int id) {
243243
public static final TransportVersion CHUNK_SENTENCE_OVERLAP_SETTING_ADDED = def(8_767_00_0);
244244
public static final TransportVersion OPT_IN_ESQL_CCS_EXECUTION_INFO = def(8_768_00_0);
245245
public static final TransportVersion QUERY_RULE_TEST_API = def(8_769_00_0);
246+
public static final TransportVersion ML_INFERENCE_ATTACH_TO_EXISTSING_DEPLOYMENT = def(8_770_00_0);
246247

247248
/*
248249
* STOP! READ THIS FIRST! No, really,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@ public class TransportDeleteInferenceEndpointAction extends TransportMasterNodeA
4747

4848
private final ModelRegistry modelRegistry;
4949
private final InferenceServiceRegistry serviceRegistry;
50-
private static final Logger logger = LogManager.getLogger(TransportDeleteInferenceEndpointAction.class);
5150
private final Executor executor;
5251

5352
@Inject
@@ -118,7 +117,11 @@ private void doExecuteForked(
118117

119118
var service = serviceRegistry.getService(unparsedModel.service());
120119
if (service.isPresent()) {
121-
service.get().stop(request.getInferenceEndpointId(), listener);
120+
if (service.get().isInClusterService()) {
121+
// check for other models using this deployment
122+
} else {
123+
service.get().stop(request.getInferenceEndpointId(), listener);
124+
}
122125
} else {
123126
listener.onFailure(
124127
new ElasticsearchStatusException(

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ public void start(Model model, ActionListener<Boolean> finalListener) {
120120

121121
@Override
122122
public void stop(String inferenceEntityId, ActionListener<Boolean> listener) {
123+
// TODO check if other inference endpoints are using this deployment
123124
var request = new StopTrainedModelDeploymentAction.Request(inferenceEntityId);
124125
request.setForce(true);
125126
client.execute(

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalModel.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,4 +91,8 @@ public ElasticsearchInternalServiceSettings getServiceSettings() {
9191
public String toString() {
9292
return Strings.toString(this.getConfigurations());
9393
}
94+
95+
public String mlNodeDeploymentId() {
96+
return internalServiceSettings.getDeploymentId() == null ? getInferenceEntityId() : internalServiceSettings.getDeploymentId();
97+
}
9498
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java

Lines changed: 124 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,20 @@
3333
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
3434
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
3535
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
36+
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction;
3637
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
3738
import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings;
39+
import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats;
3840
import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
3941
import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
4042
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
4143
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.EmptyConfigUpdate;
44+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
45+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfig;
4246
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfigUpdate;
47+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextExpansionConfig;
4348
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextExpansionConfigUpdate;
49+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextSimilarityConfig;
4450
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextSimilarityConfigUpdate;
4551
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
4652
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
@@ -52,6 +58,7 @@
5258
import java.util.EnumSet;
5359
import java.util.List;
5460
import java.util.Map;
61+
import java.util.Optional;
5562
import java.util.Set;
5663
import java.util.function.Consumer;
5764
import java.util.function.Function;
@@ -134,7 +141,10 @@ public void parseRequestConfig(
134141
throwIfNotEmptyMap(config, name());
135142

136143
String modelId = (String) serviceSettingsMap.get(ElasticsearchInternalServiceSettings.MODEL_ID);
137-
if (modelId == null) {
144+
String deploymentId = (String) serviceSettingsMap.get(ElasticsearchInternalServiceSettings.DEPLOYMENT_ID);
145+
if (deploymentId != null) {
146+
validateAgainstDeployment(modelId, deploymentId, taskType, ) // TODO create model
147+
} else if (modelId == null) {
138148
if (OLD_ELSER_SERVICE_NAME.equals(serviceName)) {
139149
// TODO complete deprecation of null model ID
140150
// throw new ValidationException().addValidationError("Error parsing request config, model id is missing");
@@ -215,6 +225,8 @@ private void customElandCase(
215225
+ "]. You may need to load it into the cluster using eland."
216226
);
217227
} else {
228+
throwIfUnsupportedTaskType(modelId, taskType, response.getResources().results().get(0).getInferenceConfig());
229+
218230
var model = createCustomElandModel(
219231
inferenceEntityId,
220232
taskType,
@@ -553,7 +565,7 @@ public void inferTextEmbedding(
553565
ActionListener<InferenceServiceResults> listener
554566
) {
555567
var request = buildInferenceRequest(
556-
model.getConfigurations().getInferenceEntityId(),
568+
model.mlNodeDeploymentId(),
557569
TextEmbeddingConfigUpdate.EMPTY_INSTANCE,
558570
inputs,
559571
inputType,
@@ -579,7 +591,7 @@ public void inferSparseEmbedding(
579591
ActionListener<InferenceServiceResults> listener
580592
) {
581593
var request = buildInferenceRequest(
582-
model.getConfigurations().getInferenceEntityId(),
594+
model.mlNodeDeploymentId(),
583595
TextExpansionConfigUpdate.EMPTY_UPDATE,
584596
inputs,
585597
inputType,
@@ -607,7 +619,7 @@ public void inferRerank(
607619
ActionListener<InferenceServiceResults> listener
608620
) {
609621
var request = buildInferenceRequest(
610-
model.getConfigurations().getInferenceEntityId(),
622+
model.mlNodeDeploymentId(),
611623
new TextSimilarityConfigUpdate(query),
612624
inputs,
613625
inputType,
@@ -681,7 +693,7 @@ public void chunkedInfer(
681693

682694
for (var batch : batchedRequests) {
683695
var inferenceRequest = buildInferenceRequest(
684-
model.getConfigurations().getInferenceEntityId(),
696+
esModel.mlNodeDeploymentId(),
685697
EmptyConfigUpdate.INSTANCE,
686698
batch.batch().inputs(),
687699
inputType,
@@ -858,4 +870,111 @@ static EmbeddingRequestChunker.EmbeddingType embeddingTypeFromTaskTypeAndSetting
858870
);
859871
};
860872
}
873+
874+
875+
private void validateAgainstDeployment(
876+
String modelId,
877+
String deploymentId,
878+
TaskType taskType,
879+
ActionListener<ElasticsearchInternalServiceSettings.Builder> listener
880+
) {
881+
getDeployment(deploymentId, listener.delegateFailureAndWrap((l, response) -> {
882+
if (response.isPresent()) {
883+
if (modelId.equals(response.get().getModelId()) == false) {
884+
listener.onFailure(
885+
new ElasticsearchStatusException(
886+
"Deployment [{}] uses model [{}] which does not match the model [{}] in the request.",
887+
RestStatus.BAD_REQUEST, // TODO better message
888+
deploymentId,
889+
response.get().getModelId(),
890+
modelId
891+
)
892+
);
893+
return;
894+
}
895+
896+
var updatedSettings = new ElasticsearchInternalServiceSettings.Builder().setNumAllocations(
897+
response.get().getNumberOfAllocations()
898+
)
899+
.setNumThreads(response.get().getThreadsPerAllocation())
900+
.setAdaptiveAllocationsSettings(response.get().getAdaptiveAllocationsSettings())
901+
.setDeploymentId(deploymentId)
902+
.setModelId(modelId);
903+
904+
checkTaskTypeForMlNodeModel(
905+
response.get().getModelId(),
906+
taskType,
907+
l.delegateFailureAndWrap((l2, compatibleTaskType) -> { l2.onResponse(updatedSettings); })
908+
);
909+
}
910+
}));
911+
}
912+
913+
private void getDeployment(String deploymentId, ActionListener<Optional<AssignmentStats>> listener) {
914+
client.execute(
915+
GetTrainedModelsStatsAction.INSTANCE,
916+
new GetTrainedModelsStatsAction.Request(deploymentId),
917+
listener.delegateFailureAndWrap((l, response) -> {
918+
l.onResponse(
919+
response.getResources()
920+
.results()
921+
.stream()
922+
.filter(s -> s.getDeploymentStats() != null && s.getDeploymentStats().getDeploymentId().equals(deploymentId))
923+
.map(GetTrainedModelsStatsAction.Response.TrainedModelStats::getDeploymentStats)
924+
.findFirst()
925+
);
926+
})
927+
);
928+
}
929+
930+
private void checkTaskTypeForMlNodeModel(String modelId, TaskType taskType, ActionListener<Boolean> listener) {
931+
client.execute(
932+
GetTrainedModelsAction.INSTANCE,
933+
new GetTrainedModelsAction.Request(modelId),
934+
listener.delegateFailureAndWrap((l, response) -> {
935+
if (response.getResources().results().isEmpty()) {
936+
l.onFailure(new IllegalStateException("this shouldn't happen"));
937+
return;
938+
}
939+
940+
var inferenceConfig = response.getResources().results().get(0).getInferenceConfig();
941+
throwIfUnsupportedTaskType(modelId, taskType, inferenceConfig);
942+
l.onResponse(Boolean.TRUE);
943+
})
944+
);
945+
}
946+
947+
static void throwIfUnsupportedTaskType(String modelId, TaskType taskType, InferenceConfig inferenceConfig) {
948+
var deploymentTaskType = inferenceConfigToTaskType(inferenceConfig);
949+
if (deploymentTaskType == null) {
950+
throw new ElasticsearchStatusException(
951+
"Deployed model [{}] has type [{}] which does not map to any supported task types",
952+
RestStatus.BAD_REQUEST,
953+
modelId,
954+
inferenceConfig.getWriteableName()
955+
);
956+
}
957+
if (deploymentTaskType != taskType) {
958+
throw new ElasticsearchStatusException(
959+
"Deployed model [{}] with type [{}] does not match the requested task type [{}]",
960+
RestStatus.BAD_REQUEST,
961+
modelId,
962+
inferenceConfig.getWriteableName(),
963+
taskType
964+
);
965+
}
966+
967+
}
968+
969+
static TaskType inferenceConfigToTaskType(InferenceConfig config) {
970+
if (config instanceof TextExpansionConfig) {
971+
return TaskType.SPARSE_EMBEDDING;
972+
} else if (config instanceof TextEmbeddingConfig) {
973+
return TaskType.TEXT_EMBEDDING;
974+
} else if (config instanceof TextSimilarityConfig) {
975+
return TaskType.RERANK;
976+
} else {
977+
return null;
978+
}
979+
}
861980
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettings.java

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,14 @@ public class ElasticsearchInternalServiceSettings implements ServiceSettings {
3636
public static final String NUM_ALLOCATIONS = "num_allocations";
3737
public static final String NUM_THREADS = "num_threads";
3838
public static final String MODEL_ID = "model_id";
39+
public static final String DEPLOYMENT_ID = "deployment_id";
3940
public static final String ADAPTIVE_ALLOCATIONS = "adaptive_allocations";
4041

4142
private final Integer numAllocations;
4243
private final int numThreads;
4344
private final String modelId;
4445
private final AdaptiveAllocationsSettings adaptiveAllocationsSettings;
46+
private final String deploymentId;
4547

4648
public static ElasticsearchInternalServiceSettings fromPersistedMap(Map<String, Object> map) {
4749
return fromRequestMap(map).build();
@@ -95,12 +97,15 @@ protected static ElasticsearchInternalServiceSettings.Builder fromMap(
9597
);
9698
}
9799

100+
String deploymentId = extractOptionalString(map, DEPLOYMENT_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);
101+
98102
// if an error occurred while parsing, we'll set these to an invalid value, so we don't accidentally get a
99103
// null pointer when doing unboxing
100104
return new ElasticsearchInternalServiceSettings.Builder().setNumAllocations(numAllocations)
101105
.setNumThreads(Objects.requireNonNullElse(numThreads, FAILED_INT_PARSE_VALUE))
102106
.setModelId(modelId)
103-
.setAdaptiveAllocationsSettings(adaptiveAllocationsSettings);
107+
.setAdaptiveAllocationsSettings(adaptiveAllocationsSettings)
108+
.setDeploymentId(deploymentId);
104109
}
105110

106111
public ElasticsearchInternalServiceSettings(
@@ -113,13 +118,29 @@ public ElasticsearchInternalServiceSettings(
113118
this.numThreads = numThreads;
114119
this.modelId = Objects.requireNonNull(modelId);
115120
this.adaptiveAllocationsSettings = adaptiveAllocationsSettings;
121+
this.deploymentId = null;
122+
}
123+
124+
/**
 * Creates settings that may reference an existing deployment.
 *
 * @param numAllocations              number of allocations; stored as given
 * @param numThreads                  number of threads
 * @param modelId                     model id; must not be {@code null}
 * @param adaptiveAllocationsSettings adaptive allocations settings; stored as given
 * @param deploymentId                deployment id; stored as given, may be {@code null}
 * @throws NullPointerException if {@code modelId} is {@code null}
 */
public ElasticsearchInternalServiceSettings(
    Integer numAllocations,
    int numThreads,
    String modelId,
    AdaptiveAllocationsSettings adaptiveAllocationsSettings,
    String deploymentId
) {
    // the model id is the only mandatory value
    this.modelId = Objects.requireNonNull(modelId);
    this.deploymentId = deploymentId;

    this.numAllocations = numAllocations;
    this.numThreads = numThreads;
    this.adaptiveAllocationsSettings = adaptiveAllocationsSettings;
}
117137

118138
/**
 * Copy constructor: takes every field, including the deployment id, from {@code other}.
 *
 * @param other the settings to copy
 */
protected ElasticsearchInternalServiceSettings(ElasticsearchInternalServiceSettings other) {
    this.modelId = other.modelId;
    this.deploymentId = other.deploymentId;

    this.numAllocations = other.numAllocations;
    this.numThreads = other.numThreads;
    this.adaptiveAllocationsSettings = other.adaptiveAllocationsSettings;
}
124145

125146
/**
@@ -145,6 +166,9 @@ public ElasticsearchInternalServiceSettings(StreamInput in) throws IOException {
145166
this.adaptiveAllocationsSettings = in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS)
146167
? in.readOptionalWriteable(AdaptiveAllocationsSettings::new)
147168
: null;
169+
this.deploymentId = in.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_ATTACH_TO_EXISTSING_DEPLOYMENT)
170+
? in.readOptionalString()
171+
: null;
148172
}
149173

150174
@Override
@@ -159,6 +183,9 @@ public void writeTo(StreamOutput out) throws IOException {
159183
if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS)) {
160184
out.writeOptionalWriteable(getAdaptiveAllocationsSettings());
161185
}
186+
if (out.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_ATTACH_TO_EXISTSING_DEPLOYMENT)) {
187+
out.writeOptionalString(deploymentId);
188+
}
162189
}
163190

164191
@Override
@@ -178,6 +205,10 @@ public AdaptiveAllocationsSettings getAdaptiveAllocationsSettings() {
178205
return adaptiveAllocationsSettings;
179206
}
180207

208+
public String getDeploymentId() {
209+
return deploymentId;
210+
}
211+
181212
@Override
182213
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
183214
builder.startObject();
@@ -195,6 +226,9 @@ protected void addInternalSettingsToXContent(XContentBuilder builder, Params par
195226
if (adaptiveAllocationsSettings != null) {
196227
builder.field(ADAPTIVE_ALLOCATIONS, adaptiveAllocationsSettings);
197228
}
229+
if (deploymentId != null) {
230+
builder.field(DEPLOYMENT_ID, deploymentId);
231+
}
198232
}
199233

200234
@Override
@@ -217,9 +251,10 @@ public static class Builder {
217251
private int numThreads;
218252
private String modelId;
219253
private AdaptiveAllocationsSettings adaptiveAllocationsSettings;
254+
private String deploymentId;
220255

221256
public ElasticsearchInternalServiceSettings build() {
222-
return new ElasticsearchInternalServiceSettings(numAllocations, numThreads, modelId, adaptiveAllocationsSettings);
257+
return new ElasticsearchInternalServiceSettings(numAllocations, numThreads, modelId, adaptiveAllocationsSettings, deploymentId);
223258
}
224259

225260
public Builder setNumAllocations(Integer numAllocations) {
@@ -237,6 +272,11 @@ public Builder setModelId(String modelId) {
237272
return this;
238273
}
239274

275+
public Builder setDeploymentId(String deploymentId) {
276+
this.deploymentId = deploymentId;
277+
return this;
278+
}
279+
240280
public Builder setAdaptiveAllocationsSettings(AdaptiveAllocationsSettings adaptiveAllocationsSettings) {
241281
this.adaptiveAllocationsSettings = adaptiveAllocationsSettings;
242282
return this;
@@ -266,11 +306,12 @@ public boolean equals(Object o) {
266306
return Objects.equals(numAllocations, that.numAllocations)
267307
&& numThreads == that.numThreads
268308
&& Objects.equals(modelId, that.modelId)
269-
&& Objects.equals(adaptiveAllocationsSettings, that.adaptiveAllocationsSettings);
309+
&& Objects.equals(adaptiveAllocationsSettings, that.adaptiveAllocationsSettings)
310+
&& Objects.equals(deploymentId, that.deploymentId);
270311
}
271312

272313
@Override
273314
public int hashCode() {
274-
return Objects.hash(numAllocations, numThreads, modelId, adaptiveAllocationsSettings);
315+
return Objects.hash(numAllocations, numThreads, modelId, adaptiveAllocationsSettings, deploymentId);
275316
}
276317
}

0 commit comments

Comments
 (0)