Skip to content

Commit e93fcac

Browse files
[ML] Adding missing endpoint information to cluster state (elastic#138934)
* Fixing out of sync endpoints * Cleaning up * Addressing feedback * Adding comments and extracting method
1 parent dad01a6 commit e93fcac

File tree

2 files changed

+328
-38
lines changed

2 files changed

+328
-38
lines changed

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java

Lines changed: 180 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
import org.elasticsearch.index.engine.VersionConflictEngineException;
2727
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
2828
import org.elasticsearch.index.query.QueryBuilders;
29+
import org.elasticsearch.inference.EmptySecretSettings;
30+
import org.elasticsearch.inference.EmptyTaskSettings;
2931
import org.elasticsearch.inference.InferenceService;
3032
import org.elasticsearch.inference.InferenceServiceExtension;
3133
import org.elasticsearch.inference.MinimalServiceSettings;
@@ -48,13 +50,19 @@
4850
import org.elasticsearch.threadpool.ThreadPool;
4951
import org.elasticsearch.xcontent.ToXContentObject;
5052
import org.elasticsearch.xcontent.XContentBuilder;
53+
import org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsBuilder;
5154
import org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsTests;
5255
import org.elasticsearch.xpack.core.inference.results.ModelStoreResponse;
5356
import org.elasticsearch.xpack.inference.InferenceIndex;
5457
import org.elasticsearch.xpack.inference.InferenceSecretsIndex;
5558
import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
59+
import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension;
5660
import org.elasticsearch.xpack.inference.model.TestModel;
5761
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
62+
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService;
63+
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents;
64+
import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel;
65+
import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsServiceSettings;
5866
import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalModel;
5967
import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService;
6068
import org.elasticsearch.xpack.inference.services.elasticsearch.ElserInternalServiceSettingsTests;
@@ -84,6 +92,7 @@
8492
import static org.hamcrest.CoreMatchers.is;
8593
import static org.hamcrest.Matchers.containsString;
8694
import static org.hamcrest.Matchers.empty;
95+
import static org.hamcrest.Matchers.hasItem;
8796
import static org.hamcrest.Matchers.hasSize;
8897
import static org.hamcrest.Matchers.instanceOf;
8998
import static org.hamcrest.Matchers.not;
@@ -104,6 +113,11 @@ public void createComponents() {
104113
modelRegistry.clearDefaultIds();
105114
}
106115

116+
@Override
117+
protected boolean resetNodeAfterTest() {
118+
return true;
119+
}
120+
107121
@Override
108122
protected Collection<Class<? extends Plugin>> getPlugins() {
109123
return pluginList(ReindexPlugin.class, LocalStateInferencePlugin.class);
@@ -673,6 +687,16 @@ public void testStoreModels_StoresMultipleInferenceEndpoints() {
673687
assertModelAndMinimalSettingsWithSecrets(modelRegistry, model2, secrets);
674688
}
675689

690+
private static void assertModelAndMinimalSettingsWithoutSecrets(ModelRegistry registry, Model model) {
691+
assertMinimalServiceSettings(registry, model);
692+
693+
var listener = new PlainActionFuture<UnparsedModel>();
694+
registry.getModel(model.getInferenceEntityId(), listener);
695+
696+
var storedModel = listener.actionGet(TimeValue.THIRTY_SECONDS);
697+
assertModelWithoutSecrets(storedModel, model);
698+
}
699+
676700
private static void assertModelAndMinimalSettingsWithSecrets(ModelRegistry registry, Model model, String secrets) {
677701
assertMinimalServiceSettings(registry, model);
678702

@@ -684,16 +708,20 @@ private static void assertModelAndMinimalSettingsWithSecrets(ModelRegistry regis
684708
}
685709

686710
// Asserts that the stored model matches the expected model and carries the expected secret.
// The identity checks (inference id, service, task type) are delegated to assertModelWithoutSecrets; the remaining
// assertions require exactly one secrets entry, "secret_settings", that is a map whose "api_key" equals `secrets`.
private static void assertModel(UnparsedModel model, Model expected, String secrets) {
687-
assertThat(model.inferenceEntityId(), Matchers.is(expected.getInferenceEntityId()));
688-
assertThat(model.service(), Matchers.is(expected.getConfigurations().getService()));
689-
assertThat(model.taskType(), Matchers.is(expected.getConfigurations().getTaskType()));
711+
assertModelWithoutSecrets(model, expected);
690712
assertThat(model.secrets().keySet(), hasSize(1));
691713
assertThat(model.secrets().get("secret_settings"), instanceOf(Map.class));
692714
// The cast below is guarded by the instanceOf(Map.class) assertion on the previous line
@SuppressWarnings("unchecked")
693715
var secretSettings = (Map<String, Object>) model.secrets().get("secret_settings");
694716
assertThat(secretSettings.get("api_key"), Matchers.is(secrets));
695717
}
696718

719+
private static void assertModelWithoutSecrets(UnparsedModel model, Model expected) {
720+
assertThat(model.inferenceEntityId(), Matchers.is(expected.getInferenceEntityId()));
721+
assertThat(model.service(), Matchers.is(expected.getConfigurations().getService()));
722+
assertThat(model.taskType(), Matchers.is(expected.getConfigurations().getTaskType()));
723+
}
724+
697725
public void testStoreModels_StoresOneModel_FailsToStoreSecond_WhenVersionConflictExists() {
698726
var secrets = "secret";
699727

@@ -873,6 +901,154 @@ public void testStoreModels_OnFailure_RemovesPartialWritesOfInferenceEndpoint()
873901
assertIndicesContainExpectedDocsCount(model3, 2);
874902
}
875903

904+
public void testStoreModels_Adds_OutOfSyncEndpoints_ToClusterState() {
905+
var inferenceId1 = "1";
906+
907+
var model = new ElasticInferenceServiceSparseEmbeddingsModel(
908+
inferenceId1,
909+
TaskType.SPARSE_EMBEDDING,
910+
ElasticInferenceService.NAME,
911+
new ElasticInferenceServiceSparseEmbeddingsServiceSettings("model", null),
912+
EmptyTaskSettings.INSTANCE,
913+
EmptySecretSettings.INSTANCE,
914+
new ElasticInferenceServiceComponents("url"),
915+
ChunkingSettingsBuilder.DEFAULT_SETTINGS
916+
);
917+
918+
storeModelDirectlyInIndexWithoutRegistry(model);
919+
920+
assertThat(modelRegistry.getInferenceIds(), not(hasItem(inferenceId1)));
921+
922+
var storeListener = new PlainActionFuture<List<ModelStoreResponse>>();
923+
modelRegistry.storeModels(List.of(model), storeListener, TimeValue.THIRTY_SECONDS);
924+
925+
var response = storeListener.actionGet(TimeValue.THIRTY_SECONDS);
926+
assertThat(response.size(), is(1));
927+
assertThat(response.get(0).inferenceId(), is(model.getInferenceEntityId()));
928+
assertThat(response.get(0).status(), is(RestStatus.CONFLICT));
929+
assertTrue(response.get(0).failed());
930+
931+
// Storing the model fails because it already exists, but the registry should now be aware of the inference id in
932+
// cluster state
933+
var cause = response.get(0).failureCause();
934+
assertThat(cause, instanceOf(VersionConflictEngineException.class));
935+
assertThat(cause.getMessage(), containsString("[model_1]: version conflict, document already exists"));
936+
937+
assertIndicesContainExpectedDocsCount(model, 2);
938+
assertMinimalServiceSettings(modelRegistry, model);
939+
940+
var getModelWithSecretsListener = new PlainActionFuture<UnparsedModel>();
941+
modelRegistry.getModelWithSecrets(model.getInferenceEntityId(), getModelWithSecretsListener);
942+
943+
var unparsedModel = getModelWithSecretsListener.actionGet(TimeValue.THIRTY_SECONDS);
944+
945+
assertThat(unparsedModel.inferenceEntityId(), is(model.getInferenceEntityId()));
946+
assertThat(unparsedModel.service(), is(model.getConfigurations().getService()));
947+
assertThat(unparsedModel.taskType(), is(model.getConfigurations().getTaskType()));
948+
949+
assertThat(modelRegistry.getInferenceIds(), hasItem(inferenceId1));
950+
}
951+
952+
private void storeModelDirectlyInIndexWithoutRegistry(Model model) {
953+
var listener = new PlainActionFuture<BulkResponse>();
954+
955+
client().prepareBulk()
956+
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
957+
.add(
958+
ModelRegistry.createIndexRequestBuilder(
959+
model.getInferenceEntityId(),
960+
InferenceIndex.INDEX_NAME,
961+
model.getConfigurations(),
962+
false,
963+
client()
964+
)
965+
)
966+
.add(
967+
ModelRegistry.createIndexRequestBuilder(
968+
model.getInferenceEntityId(),
969+
InferenceSecretsIndex.INDEX_NAME,
970+
model.getSecrets(),
971+
false,
972+
client()
973+
)
974+
)
975+
.execute(listener);
976+
977+
var bulkResponse = listener.actionGet(TimeValue.THIRTY_SECONDS);
978+
if (bulkResponse.hasFailures()) {
979+
fail("Failed to store model: " + bulkResponse.buildFailureMessage());
980+
}
981+
}
982+
983+
public void testStoreModels_Adds_OutOfSyncEndpoints_ToClusterState_MixedWithSuccessfulStore() {
984+
var inferenceId1 = "1";
985+
986+
var eisModel = new ElasticInferenceServiceSparseEmbeddingsModel(
987+
inferenceId1,
988+
TaskType.SPARSE_EMBEDDING,
989+
ElasticInferenceService.NAME,
990+
new ElasticInferenceServiceSparseEmbeddingsServiceSettings("model", null),
991+
EmptyTaskSettings.INSTANCE,
992+
EmptySecretSettings.INSTANCE,
993+
new ElasticInferenceServiceComponents("url"),
994+
ChunkingSettingsBuilder.DEFAULT_SETTINGS
995+
);
996+
997+
storeModelDirectlyInIndexWithoutRegistry(eisModel);
998+
999+
assertThat(modelRegistry.getInferenceIds(), not(hasItem(inferenceId1)));
1000+
1001+
var testModelId1 = "test-1";
1002+
var testModelId2 = "test-2";
1003+
1004+
// Using these models because the mock inference plugin we use in this test only supports these test services and EIS
1005+
var testModel1 = new TestSparseInferenceServiceExtension.TestSparseModel(
1006+
testModelId1,
1007+
new TestSparseInferenceServiceExtension.TestServiceSettings("model", "hidden_field", false)
1008+
);
1009+
1010+
var testModel2 = new TestSparseInferenceServiceExtension.TestSparseModel(
1011+
testModelId2,
1012+
new TestSparseInferenceServiceExtension.TestServiceSettings("model", "hidden_field", false)
1013+
);
1014+
1015+
var storeListener = new PlainActionFuture<List<ModelStoreResponse>>();
1016+
modelRegistry.storeModels(List.of(eisModel, testModel1, testModel2), storeListener, TimeValue.THIRTY_SECONDS);
1017+
1018+
var response = storeListener.actionGet(TimeValue.THIRTY_SECONDS);
1019+
assertThat(response.size(), is(3));
1020+
assertThat(response.get(0).inferenceId(), is(eisModel.getInferenceEntityId()));
1021+
assertThat(response.get(0).status(), is(RestStatus.CONFLICT));
1022+
assertTrue(response.get(0).failed());
1023+
1024+
assertThat(response.get(1), Matchers.is(new ModelStoreResponse(testModelId1, RestStatus.CREATED, null)));
1025+
assertThat(response.get(2), Matchers.is(new ModelStoreResponse(testModelId2, RestStatus.CREATED, null)));
1026+
1027+
// Storing the model fails because it already exists, but the registry should now be aware of the inference id in
1028+
// cluster state
1029+
var cause = response.get(0).failureCause();
1030+
assertNotNull(cause);
1031+
assertThat(cause, instanceOf(VersionConflictEngineException.class));
1032+
assertThat(cause.getMessage(), containsString("[model_1]: version conflict, document already exists"));
1033+
1034+
assertIndicesContainExpectedDocsCount(eisModel, 2);
1035+
assertMinimalServiceSettings(modelRegistry, eisModel);
1036+
1037+
var getModelWithSecretsListener = new PlainActionFuture<UnparsedModel>();
1038+
modelRegistry.getModelWithSecrets(eisModel.getInferenceEntityId(), getModelWithSecretsListener);
1039+
1040+
var unparsedModel = getModelWithSecretsListener.actionGet(TimeValue.THIRTY_SECONDS);
1041+
1042+
assertThat(unparsedModel.inferenceEntityId(), is(eisModel.getInferenceEntityId()));
1043+
assertThat(unparsedModel.service(), is(eisModel.getConfigurations().getService()));
1044+
assertThat(unparsedModel.taskType(), is(eisModel.getConfigurations().getTaskType()));
1045+
1046+
assertThat(modelRegistry.getInferenceIds(), is(Set.of(inferenceId1, testModelId1, testModelId2)));
1047+
1048+
assertModelAndMinimalSettingsWithoutSecrets(modelRegistry, testModel1);
1049+
assertModelAndMinimalSettingsWithoutSecrets(modelRegistry, testModel2);
1050+
}
1051+
8761052
public void testGetModelNoSecrets() {
8771053
var inferenceId = "1";
8781054

@@ -973,10 +1149,9 @@ private void storeCorruptedModelThenStoreModel(boolean storeSecrets) {
9731149
assertIndicesContainExpectedDocsCount(model, 0);
9741150
}
9751151

976-
private void assertIndicesContainExpectedDocsCount(TestModel model, int numberOfDocs) {
1152+
private void assertIndicesContainExpectedDocsCount(Model model, int numberOfDocs) {
9771153
SearchRequest modelSearch = client().prepareSearch(InferenceIndex.INDEX_PATTERN, InferenceSecretsIndex.INDEX_PATTERN)
9781154
.setQuery(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds(Model.documentId(model.getInferenceEntityId()))))
979-
.setSize(2)
9801155
.setTrackTotalHits(false)
9811156
.request();
9821157
SearchResponse searchResponse = client().search(modelSearch).actionGet(TimeValue.THIRTY_SECONDS);

0 commit comments

Comments
 (0)