Skip to content

Commit f16bdf5

Browse files
Fixing out of sync endpoints
1 parent c12ecb2 commit f16bdf5

File tree

2 files changed

+303
-33
lines changed

2 files changed

+303
-33
lines changed

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java

Lines changed: 180 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
import org.elasticsearch.index.engine.VersionConflictEngineException;
2727
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
2828
import org.elasticsearch.index.query.QueryBuilders;
29+
import org.elasticsearch.inference.EmptySecretSettings;
30+
import org.elasticsearch.inference.EmptyTaskSettings;
2931
import org.elasticsearch.inference.InferenceService;
3032
import org.elasticsearch.inference.InferenceServiceExtension;
3133
import org.elasticsearch.inference.MinimalServiceSettings;
@@ -48,13 +50,19 @@
4850
import org.elasticsearch.threadpool.ThreadPool;
4951
import org.elasticsearch.xcontent.ToXContentObject;
5052
import org.elasticsearch.xcontent.XContentBuilder;
53+
import org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsBuilder;
5154
import org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsTests;
5255
import org.elasticsearch.xpack.core.inference.results.ModelStoreResponse;
5356
import org.elasticsearch.xpack.inference.InferenceIndex;
5457
import org.elasticsearch.xpack.inference.InferenceSecretsIndex;
5558
import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
59+
import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension;
5660
import org.elasticsearch.xpack.inference.model.TestModel;
5761
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
62+
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService;
63+
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents;
64+
import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel;
65+
import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsServiceSettings;
5866
import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalModel;
5967
import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService;
6068
import org.elasticsearch.xpack.inference.services.elasticsearch.ElserInternalServiceSettingsTests;
@@ -84,6 +92,7 @@
8492
import static org.hamcrest.CoreMatchers.is;
8593
import static org.hamcrest.Matchers.containsString;
8694
import static org.hamcrest.Matchers.empty;
95+
import static org.hamcrest.Matchers.hasItem;
8796
import static org.hamcrest.Matchers.hasSize;
8897
import static org.hamcrest.Matchers.instanceOf;
8998
import static org.hamcrest.Matchers.not;
@@ -104,6 +113,11 @@ public void createComponents() {
104113
modelRegistry.clearDefaultIds();
105114
}
106115

116+
@Override
117+
protected boolean resetNodeAfterTest() {
118+
return true;
119+
}
120+
107121
@Override
108122
protected Collection<Class<? extends Plugin>> getPlugins() {
109123
return pluginList(ReindexPlugin.class, LocalStateInferencePlugin.class);
@@ -673,6 +687,16 @@ public void testStoreModels_StoresMultipleInferenceEndpoints() {
673687
assertModelAndMinimalSettingsWithSecrets(modelRegistry, model2, secrets);
674688
}
675689

690+
private static void assertModelAndMinimalSettingsWithoutSecrets(ModelRegistry registry, Model model) {
691+
assertMinimalServiceSettings(registry, model);
692+
693+
var listener = new PlainActionFuture<UnparsedModel>();
694+
registry.getModel(model.getInferenceEntityId(), listener);
695+
696+
var storedModel = listener.actionGet(TimeValue.THIRTY_SECONDS);
697+
assertModelWithoutSecrets(storedModel, model);
698+
}
699+
676700
private static void assertModelAndMinimalSettingsWithSecrets(ModelRegistry registry, Model model, String secrets) {
677701
assertMinimalServiceSettings(registry, model);
678702

@@ -684,16 +708,20 @@ private static void assertModelAndMinimalSettingsWithSecrets(ModelRegistry regis
684708
}
685709

686710
private static void assertModel(UnparsedModel model, Model expected, String secrets) {
687-
assertThat(model.inferenceEntityId(), Matchers.is(expected.getInferenceEntityId()));
688-
assertThat(model.service(), Matchers.is(expected.getConfigurations().getService()));
689-
assertThat(model.taskType(), Matchers.is(expected.getConfigurations().getTaskType()));
711+
assertModelWithoutSecrets(model, expected);
690712
assertThat(model.secrets().keySet(), hasSize(1));
691713
assertThat(model.secrets().get("secret_settings"), instanceOf(Map.class));
692714
@SuppressWarnings("unchecked")
693715
var secretSettings = (Map<String, Object>) model.secrets().get("secret_settings");
694716
assertThat(secretSettings.get("api_key"), Matchers.is(secrets));
695717
}
696718

719+
private static void assertModelWithoutSecrets(UnparsedModel model, Model expected) {
720+
assertThat(model.inferenceEntityId(), Matchers.is(expected.getInferenceEntityId()));
721+
assertThat(model.service(), Matchers.is(expected.getConfigurations().getService()));
722+
assertThat(model.taskType(), Matchers.is(expected.getConfigurations().getTaskType()));
723+
}
724+
697725
public void testStoreModels_StoresOneModel_FailsToStoreSecond_WhenVersionConflictExists() {
698726
var secrets = "secret";
699727

@@ -873,6 +901,154 @@ public void testStoreModels_OnFailure_RemovesPartialWritesOfInferenceEndpoint()
873901
assertIndicesContainExpectedDocsCount(model3, 2);
874902
}
875903

904+
public void testStoreModels_Adds_OutOfSyncEndpoints_ToClusterState() {
905+
var inferenceId1 = "1";
906+
907+
var model = new ElasticInferenceServiceSparseEmbeddingsModel(
908+
inferenceId1,
909+
TaskType.SPARSE_EMBEDDING,
910+
ElasticInferenceService.NAME,
911+
new ElasticInferenceServiceSparseEmbeddingsServiceSettings("model", null),
912+
EmptyTaskSettings.INSTANCE,
913+
EmptySecretSettings.INSTANCE,
914+
new ElasticInferenceServiceComponents("url"),
915+
ChunkingSettingsBuilder.DEFAULT_SETTINGS
916+
);
917+
918+
storeModelDirectlyInIndexWithoutRegistry(model);
919+
920+
assertThat(modelRegistry.getInferenceIds(), not(hasItem(inferenceId1)));
921+
922+
PlainActionFuture<List<ModelStoreResponse>> storeListener = new PlainActionFuture<>();
923+
modelRegistry.storeModels(List.of(model), storeListener, TimeValue.THIRTY_SECONDS);
924+
925+
var response = storeListener.actionGet(TimeValue.THIRTY_SECONDS);
926+
assertThat(response.size(), is(1));
927+
assertThat(response.get(0).inferenceId(), is(model.getInferenceEntityId()));
928+
assertThat(response.get(0).status(), is(RestStatus.CONFLICT));
929+
assertTrue(response.get(0).failed());
930+
931+
// Storing the model fails because it already exists, but the registry should now be aware of the inference id in
932+
// cluster state
933+
var cause = response.get(0).failureCause();
934+
assertNotNull(cause);
935+
assertThat(cause, instanceOf(VersionConflictEngineException.class));
936+
assertThat(cause.getMessage(), containsString("[model_1]: version conflict, document already exists"));
937+
938+
assertIndicesContainExpectedDocsCount(model, 2);
939+
assertMinimalServiceSettings(modelRegistry, model);
940+
941+
var getModelWithSecretsListener = new PlainActionFuture<UnparsedModel>();
942+
modelRegistry.getModelWithSecrets(model.getInferenceEntityId(), getModelWithSecretsListener);
943+
944+
var unparsedModel = getModelWithSecretsListener.actionGet(TimeValue.THIRTY_SECONDS);
945+
946+
assertThat(unparsedModel.inferenceEntityId(), is(model.getInferenceEntityId()));
947+
assertThat(unparsedModel.service(), is(model.getConfigurations().getService()));
948+
assertThat(unparsedModel.taskType(), is(model.getConfigurations().getTaskType()));
949+
950+
assertThat(modelRegistry.getInferenceIds(), hasItem(inferenceId1));
951+
}
952+
953+
private void storeModelDirectlyInIndexWithoutRegistry(Model model) {
954+
var listener = new PlainActionFuture<BulkResponse>();
955+
956+
client().prepareBulk()
957+
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
958+
.add(
959+
ModelRegistry.createIndexRequestBuilder(
960+
model.getInferenceEntityId(),
961+
InferenceIndex.INDEX_NAME,
962+
model.getConfigurations(),
963+
false,
964+
client()
965+
)
966+
)
967+
.add(
968+
ModelRegistry.createIndexRequestBuilder(
969+
model.getInferenceEntityId(),
970+
InferenceSecretsIndex.INDEX_NAME,
971+
model.getSecrets(),
972+
false,
973+
client()
974+
)
975+
)
976+
.execute(listener);
977+
978+
var bulkResponse = listener.actionGet(TimeValue.THIRTY_SECONDS);
979+
if (bulkResponse.hasFailures()) {
980+
fail("Failed to store model: " + bulkResponse.buildFailureMessage());
981+
}
982+
}
983+
984+
public void testStoreModels_Adds_OutOfSyncEndpoints_ToClusterState_MixedWithSuccessfulStore() {
985+
var inferenceId1 = "1";
986+
987+
var eisModel = new ElasticInferenceServiceSparseEmbeddingsModel(
988+
inferenceId1,
989+
TaskType.SPARSE_EMBEDDING,
990+
ElasticInferenceService.NAME,
991+
new ElasticInferenceServiceSparseEmbeddingsServiceSettings("model", null),
992+
EmptyTaskSettings.INSTANCE,
993+
EmptySecretSettings.INSTANCE,
994+
new ElasticInferenceServiceComponents("url"),
995+
ChunkingSettingsBuilder.DEFAULT_SETTINGS
996+
);
997+
998+
storeModelDirectlyInIndexWithoutRegistry(eisModel);
999+
1000+
assertThat(modelRegistry.getInferenceIds(), not(hasItem(inferenceId1)));
1001+
1002+
var testModelId1 = "test-1";
1003+
var testModelId2 = "test-2";
1004+
1005+
var testModel1 = new TestSparseInferenceServiceExtension.TestSparseModel(
1006+
testModelId1,
1007+
new TestSparseInferenceServiceExtension.TestServiceSettings("model", "hidden_field", false)
1008+
);
1009+
1010+
var testModel2 = new TestSparseInferenceServiceExtension.TestSparseModel(
1011+
testModelId2,
1012+
new TestSparseInferenceServiceExtension.TestServiceSettings("model", "hidden_field", false)
1013+
);
1014+
1015+
PlainActionFuture<List<ModelStoreResponse>> storeListener = new PlainActionFuture<>();
1016+
modelRegistry.storeModels(List.of(eisModel, testModel1, testModel2), storeListener, TimeValue.THIRTY_SECONDS);
1017+
1018+
var response = storeListener.actionGet(TimeValue.THIRTY_SECONDS);
1019+
assertThat(response.size(), is(3));
1020+
assertThat(response.get(0).inferenceId(), is(eisModel.getInferenceEntityId()));
1021+
assertThat(response.get(0).status(), is(RestStatus.CONFLICT));
1022+
assertTrue(response.get(0).failed());
1023+
1024+
assertThat(response.get(1), Matchers.is(new ModelStoreResponse(testModelId1, RestStatus.CREATED, null)));
1025+
assertThat(response.get(2), Matchers.is(new ModelStoreResponse(testModelId2, RestStatus.CREATED, null)));
1026+
1027+
// Storing the model fails because it already exists, but the registry should now be aware of the inference id in
1028+
// cluster state
1029+
var cause = response.get(0).failureCause();
1030+
assertNotNull(cause);
1031+
assertThat(cause, instanceOf(VersionConflictEngineException.class));
1032+
assertThat(cause.getMessage(), containsString("[model_1]: version conflict, document already exists"));
1033+
1034+
assertIndicesContainExpectedDocsCount(eisModel, 2);
1035+
assertMinimalServiceSettings(modelRegistry, eisModel);
1036+
1037+
var getModelWithSecretsListener = new PlainActionFuture<UnparsedModel>();
1038+
modelRegistry.getModelWithSecrets(eisModel.getInferenceEntityId(), getModelWithSecretsListener);
1039+
1040+
var unparsedModel = getModelWithSecretsListener.actionGet(TimeValue.THIRTY_SECONDS);
1041+
1042+
assertThat(unparsedModel.inferenceEntityId(), is(eisModel.getInferenceEntityId()));
1043+
assertThat(unparsedModel.service(), is(eisModel.getConfigurations().getService()));
1044+
assertThat(unparsedModel.taskType(), is(eisModel.getConfigurations().getTaskType()));
1045+
1046+
assertThat(modelRegistry.getInferenceIds(), is(Set.of(inferenceId1, testModelId1, testModelId2)));
1047+
1048+
assertModelAndMinimalSettingsWithoutSecrets(modelRegistry, testModel1);
1049+
assertModelAndMinimalSettingsWithoutSecrets(modelRegistry, testModel2);
1050+
}
1051+
8761052
public void testGetModelNoSecrets() {
8771053
var inferenceId = "1";
8781054

@@ -973,7 +1149,7 @@ private void storeCorruptedModelThenStoreModel(boolean storeSecrets) {
9731149
assertIndicesContainExpectedDocsCount(model, 0);
9741150
}
9751151

976-
private void assertIndicesContainExpectedDocsCount(TestModel model, int numberOfDocs) {
1152+
private void assertIndicesContainExpectedDocsCount(Model model, int numberOfDocs) {
9771153
SearchRequest modelSearch = client().prepareSearch(InferenceIndex.INDEX_PATTERN, InferenceSecretsIndex.INDEX_PATTERN)
9781154
.setQuery(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds(Model.documentId(model.getInferenceEntityId()))))
9791155
.setSize(2)

0 commit comments

Comments
 (0)