Skip to content

Commit 44507cc

Browse files
Fix ELAND endpoints not updating dimensions (#126537)
* Fix ELAND endpoints not updating dimensions * Update docs/changelog/126537.yaml
1 parent 7d7fa76 commit 44507cc

File tree

3 files changed

+101
-18
lines changed

3 files changed

+101
-18
lines changed

docs/changelog/126537.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 126537
2+
summary: Fix ELAND endpoints not updating dimensions
3+
area: Machine Learning
4+
type: bug
5+
issues: []

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -561,25 +561,33 @@ private void migrateModelVersionToModelId(Map<String, Object> serviceSettingsMap
561561
}
562562
}
563563

564-
private static CustomElandEmbeddingModel updateModelWithEmbeddingDetails(CustomElandEmbeddingModel model, int embeddingSize) {
565-
CustomElandInternalTextEmbeddingServiceSettings serviceSettings = new CustomElandInternalTextEmbeddingServiceSettings(
566-
model.getServiceSettings().getNumAllocations(),
567-
model.getServiceSettings().getNumThreads(),
568-
model.getServiceSettings().modelId(),
569-
model.getServiceSettings().getAdaptiveAllocationsSettings(),
570-
model.getServiceSettings().getDeploymentId(),
571-
embeddingSize,
572-
model.getServiceSettings().similarity(),
573-
model.getServiceSettings().elementType()
574-
);
564+
@Override
565+
public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
566+
if (model instanceof CustomElandEmbeddingModel customElandEmbeddingModel && model.getTaskType() == TaskType.TEXT_EMBEDDING) {
567+
CustomElandInternalTextEmbeddingServiceSettings serviceSettings = new CustomElandInternalTextEmbeddingServiceSettings(
568+
customElandEmbeddingModel.getServiceSettings().getNumAllocations(),
569+
customElandEmbeddingModel.getServiceSettings().getNumThreads(),
570+
customElandEmbeddingModel.getServiceSettings().modelId(),
571+
customElandEmbeddingModel.getServiceSettings().getAdaptiveAllocationsSettings(),
572+
customElandEmbeddingModel.getServiceSettings().getDeploymentId(),
573+
embeddingSize,
574+
customElandEmbeddingModel.getServiceSettings().similarity(),
575+
customElandEmbeddingModel.getServiceSettings().elementType()
576+
);
577+
578+
return new CustomElandEmbeddingModel(
579+
customElandEmbeddingModel.getInferenceEntityId(),
580+
customElandEmbeddingModel.getTaskType(),
581+
customElandEmbeddingModel.getConfigurations().getService(),
582+
serviceSettings,
583+
customElandEmbeddingModel.getConfigurations().getChunkingSettings()
584+
);
585+
} else if (model instanceof ElasticsearchInternalModel) {
586+
return model;
587+
} else {
588+
throw ServiceUtils.invalidModelTypeForUpdateModelWithEmbeddingDetails(model.getClass());
589+
}
575590

576-
return new CustomElandEmbeddingModel(
577-
model.getInferenceEntityId(),
578-
model.getTaskType(),
579-
model.getConfigurations().getService(),
580-
serviceSettings,
581-
model.getConfigurations().getChunkingSettings()
582-
);
583591
}
584592

585593
@Override

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@
6969
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TokenizationConfigUpdate;
7070
import org.elasticsearch.xpack.inference.InferencePlugin;
7171
import org.elasticsearch.xpack.inference.InputTypeTests;
72+
import org.elasticsearch.xpack.inference.ModelConfigurationsTests;
7273
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests;
7374
import org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkingSettings;
7475
import org.elasticsearch.xpack.inference.services.ServiceFields;
@@ -886,6 +887,75 @@ public void testParsePersistedConfig() {
886887
}
887888
}
888889

890+
public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() {
891+
var service = createService(mock(Client.class));
892+
var model = new Model(ModelConfigurationsTests.createRandomInstance());
893+
894+
assertThrows(ElasticsearchStatusException.class, () -> { service.updateModelWithEmbeddingDetails(model, randomNonNegativeInt()); });
895+
}
896+
897+
public void testUpdateModelWithEmbeddingDetails_TextEmbeddingCustomElandEmbeddingsModelUpdatesDimensions() {
898+
var service = createService(mock(Client.class));
899+
var elandServiceSettings = new CustomElandInternalTextEmbeddingServiceSettings(
900+
1,
901+
4,
902+
"invalid",
903+
null,
904+
null,
905+
null,
906+
SimilarityMeasure.COSINE,
907+
DenseVectorFieldMapper.ElementType.FLOAT
908+
);
909+
var model = new CustomElandEmbeddingModel(
910+
randomAlphaOfLength(10),
911+
TaskType.TEXT_EMBEDDING,
912+
"elasticsearch",
913+
elandServiceSettings,
914+
null
915+
);
916+
917+
var embeddingSize = randomNonNegativeInt();
918+
var updatedModel = service.updateModelWithEmbeddingDetails(model, embeddingSize);
919+
920+
assertEquals(embeddingSize, updatedModel.getServiceSettings().dimensions().intValue());
921+
}
922+
923+
public void testUpdateModelWithEmbeddingDetails_NonTextEmbeddingCustomElandEmbeddingsModelNotModified() {
924+
var service = createService(mock(Client.class));
925+
var elandServiceSettings = new CustomElandInternalTextEmbeddingServiceSettings(
926+
1,
927+
4,
928+
"invalid",
929+
null,
930+
null,
931+
null,
932+
SimilarityMeasure.COSINE,
933+
DenseVectorFieldMapper.ElementType.FLOAT
934+
);
935+
var model = new CustomElandEmbeddingModel(
936+
randomAlphaOfLength(10),
937+
TaskType.SPARSE_EMBEDDING,
938+
"elasticsearch",
939+
elandServiceSettings,
940+
null
941+
);
942+
943+
var embeddingSize = randomNonNegativeInt();
944+
var updatedModel = service.updateModelWithEmbeddingDetails(model, embeddingSize);
945+
946+
assertEquals(model, updatedModel);
947+
}
948+
949+
public void testUpdateModelWithEmbeddingDetails_ElasticsearchInternalModelNotModified() {
950+
var service = createService(mock(Client.class));
951+
var model = mock(ElasticsearchInternalModel.class);
952+
953+
var updatedModel = service.updateModelWithEmbeddingDetails(model, randomNonNegativeInt());
954+
955+
assertEquals(model, updatedModel);
956+
verifyNoMoreInteractions(model);
957+
}
958+
889959
public void testChunkInfer_E5WithNullChunkingSettings() throws InterruptedException {
890960
testChunkInfer_e5(null);
891961
}

0 commit comments

Comments
 (0)