Skip to content

Commit fdd92d8

Browse files
Fix ELAND endpoints not updating dimensions (elastic#126537) (elastic#126549)
* Fix ELAND endpoints not updating dimensions * Update docs/changelog/126537.yaml
1 parent c08f32d commit fdd92d8

File tree

3 files changed

+101
-18
lines changed

3 files changed

+101
-18
lines changed

docs/changelog/126537.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 126537
2+
summary: Fix ELAND endpoints not updating dimensions
3+
area: Machine Learning
4+
type: bug
5+
issues: []

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
5757
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
5858
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
59+
import org.elasticsearch.xpack.inference.services.ServiceUtils;
5960

6061
import java.util.ArrayList;
6162
import java.util.Collections;
@@ -535,25 +536,32 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M
535536
}
536537
}
537538

538-
private static CustomElandEmbeddingModel updateModelWithEmbeddingDetails(CustomElandEmbeddingModel model, int embeddingSize) {
539-
CustomElandInternalTextEmbeddingServiceSettings serviceSettings = new CustomElandInternalTextEmbeddingServiceSettings(
540-
model.getServiceSettings().getNumAllocations(),
541-
model.getServiceSettings().getNumThreads(),
542-
model.getServiceSettings().modelId(),
543-
model.getServiceSettings().getAdaptiveAllocationsSettings(),
544-
model.getServiceSettings().getDeploymentId(),
545-
embeddingSize,
546-
model.getServiceSettings().similarity(),
547-
model.getServiceSettings().elementType()
548-
);
539+
@Override
540+
public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
541+
if (model instanceof CustomElandEmbeddingModel customElandEmbeddingModel && model.getTaskType() == TaskType.TEXT_EMBEDDING) {
542+
CustomElandInternalTextEmbeddingServiceSettings serviceSettings = new CustomElandInternalTextEmbeddingServiceSettings(
543+
customElandEmbeddingModel.getServiceSettings().getNumAllocations(),
544+
customElandEmbeddingModel.getServiceSettings().getNumThreads(),
545+
customElandEmbeddingModel.getServiceSettings().modelId(),
546+
customElandEmbeddingModel.getServiceSettings().getAdaptiveAllocationsSettings(),
547+
customElandEmbeddingModel.getServiceSettings().getDeploymentId(),
548+
embeddingSize,
549+
customElandEmbeddingModel.getServiceSettings().similarity(),
550+
customElandEmbeddingModel.getServiceSettings().elementType()
551+
);
549552

550-
return new CustomElandEmbeddingModel(
551-
model.getInferenceEntityId(),
552-
model.getTaskType(),
553-
model.getConfigurations().getService(),
554-
serviceSettings,
555-
model.getConfigurations().getChunkingSettings()
556-
);
553+
return new CustomElandEmbeddingModel(
554+
customElandEmbeddingModel.getInferenceEntityId(),
555+
customElandEmbeddingModel.getTaskType(),
556+
customElandEmbeddingModel.getConfigurations().getService(),
557+
serviceSettings,
558+
customElandEmbeddingModel.getConfigurations().getChunkingSettings()
559+
);
560+
} else if (model instanceof ElasticsearchInternalModel) {
561+
return model;
562+
} else {
563+
throw ServiceUtils.invalidModelTypeForUpdateModelWithEmbeddingDetails(model.getClass());
564+
}
557565
}
558566

559567
@Override

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@
6868
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TokenizationConfigUpdate;
6969
import org.elasticsearch.xpack.inference.InferencePlugin;
7070
import org.elasticsearch.xpack.inference.InputTypeTests;
71+
import org.elasticsearch.xpack.inference.ModelConfigurationsTests;
7172
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests;
7273
import org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkingSettings;
7374
import org.elasticsearch.xpack.inference.services.ServiceFields;
@@ -859,6 +860,75 @@ public void testParsePersistedConfig() {
859860
}
860861
}
861862

863+
public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() {
864+
var service = createService(mock(Client.class));
865+
var model = new Model(ModelConfigurationsTests.createRandomInstance());
866+
867+
assertThrows(ElasticsearchStatusException.class, () -> { service.updateModelWithEmbeddingDetails(model, randomNonNegativeInt()); });
868+
}
869+
870+
public void testUpdateModelWithEmbeddingDetails_TextEmbeddingCustomElandEmbeddingsModelUpdatesDimensions() {
871+
var service = createService(mock(Client.class));
872+
var elandServiceSettings = new CustomElandInternalTextEmbeddingServiceSettings(
873+
1,
874+
4,
875+
"invalid",
876+
null,
877+
null,
878+
null,
879+
SimilarityMeasure.COSINE,
880+
DenseVectorFieldMapper.ElementType.FLOAT
881+
);
882+
var model = new CustomElandEmbeddingModel(
883+
randomAlphaOfLength(10),
884+
TaskType.TEXT_EMBEDDING,
885+
"elasticsearch",
886+
elandServiceSettings,
887+
null
888+
);
889+
890+
var embeddingSize = randomNonNegativeInt();
891+
var updatedModel = service.updateModelWithEmbeddingDetails(model, embeddingSize);
892+
893+
assertEquals(embeddingSize, updatedModel.getServiceSettings().dimensions().intValue());
894+
}
895+
896+
public void testUpdateModelWithEmbeddingDetails_NonTextEmbeddingCustomElandEmbeddingsModelNotModified() {
897+
var service = createService(mock(Client.class));
898+
var elandServiceSettings = new CustomElandInternalTextEmbeddingServiceSettings(
899+
1,
900+
4,
901+
"invalid",
902+
null,
903+
null,
904+
null,
905+
SimilarityMeasure.COSINE,
906+
DenseVectorFieldMapper.ElementType.FLOAT
907+
);
908+
var model = new CustomElandEmbeddingModel(
909+
randomAlphaOfLength(10),
910+
TaskType.SPARSE_EMBEDDING,
911+
"elasticsearch",
912+
elandServiceSettings,
913+
null
914+
);
915+
916+
var embeddingSize = randomNonNegativeInt();
917+
var updatedModel = service.updateModelWithEmbeddingDetails(model, embeddingSize);
918+
919+
assertEquals(model, updatedModel);
920+
}
921+
922+
public void testUpdateModelWithEmbeddingDetails_ElasticsearchInternalModelNotModified() {
923+
var service = createService(mock(Client.class));
924+
var model = mock(ElasticsearchInternalModel.class);
925+
926+
var updatedModel = service.updateModelWithEmbeddingDetails(model, randomNonNegativeInt());
927+
928+
assertEquals(model, updatedModel);
929+
verifyNoMoreInteractions(model);
930+
}
931+
862932
public void testChunkInfer_E5WithNullChunkingSettings() throws InterruptedException {
863933
testChunkInfer_e5(null);
864934
}

0 commit comments

Comments
 (0)