|
68 | 68 | import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TokenizationConfigUpdate; |
69 | 69 | import org.elasticsearch.xpack.inference.InferencePlugin; |
70 | 70 | import org.elasticsearch.xpack.inference.InputTypeTests; |
| 71 | +import org.elasticsearch.xpack.inference.ModelConfigurationsTests; |
71 | 72 | import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests; |
72 | 73 | import org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkingSettings; |
73 | 74 | import org.elasticsearch.xpack.inference.services.ServiceFields; |
@@ -859,6 +860,75 @@ public void testParsePersistedConfig() { |
859 | 860 | } |
860 | 861 | } |
861 | 862 |
|
| 863 | + public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() { |
| 864 | + var service = createService(mock(Client.class)); |
| 865 | + var model = new Model(ModelConfigurationsTests.createRandomInstance()); |
| 866 | + |
| 867 | + assertThrows(ElasticsearchStatusException.class, () -> { service.updateModelWithEmbeddingDetails(model, randomNonNegativeInt()); }); |
| 868 | + } |
| 869 | + |
| 870 | + public void testUpdateModelWithEmbeddingDetails_TextEmbeddingCustomElandEmbeddingsModelUpdatesDimensions() { |
| 871 | + var service = createService(mock(Client.class)); |
| 872 | + var elandServiceSettings = new CustomElandInternalTextEmbeddingServiceSettings( |
| 873 | + 1, |
| 874 | + 4, |
| 875 | + "invalid", |
| 876 | + null, |
| 877 | + null, |
| 878 | + null, |
| 879 | + SimilarityMeasure.COSINE, |
| 880 | + DenseVectorFieldMapper.ElementType.FLOAT |
| 881 | + ); |
| 882 | + var model = new CustomElandEmbeddingModel( |
| 883 | + randomAlphaOfLength(10), |
| 884 | + TaskType.TEXT_EMBEDDING, |
| 885 | + "elasticsearch", |
| 886 | + elandServiceSettings, |
| 887 | + null |
| 888 | + ); |
| 889 | + |
| 890 | + var embeddingSize = randomNonNegativeInt(); |
| 891 | + var updatedModel = service.updateModelWithEmbeddingDetails(model, embeddingSize); |
| 892 | + |
| 893 | + assertEquals(embeddingSize, updatedModel.getServiceSettings().dimensions().intValue()); |
| 894 | + } |
| 895 | + |
| 896 | + public void testUpdateModelWithEmbeddingDetails_NonTextEmbeddingCustomElandEmbeddingsModelNotModified() { |
| 897 | + var service = createService(mock(Client.class)); |
| 898 | + var elandServiceSettings = new CustomElandInternalTextEmbeddingServiceSettings( |
| 899 | + 1, |
| 900 | + 4, |
| 901 | + "invalid", |
| 902 | + null, |
| 903 | + null, |
| 904 | + null, |
| 905 | + SimilarityMeasure.COSINE, |
| 906 | + DenseVectorFieldMapper.ElementType.FLOAT |
| 907 | + ); |
| 908 | + var model = new CustomElandEmbeddingModel( |
| 909 | + randomAlphaOfLength(10), |
| 910 | + TaskType.SPARSE_EMBEDDING, |
| 911 | + "elasticsearch", |
| 912 | + elandServiceSettings, |
| 913 | + null |
| 914 | + ); |
| 915 | + |
| 916 | + var embeddingSize = randomNonNegativeInt(); |
| 917 | + var updatedModel = service.updateModelWithEmbeddingDetails(model, embeddingSize); |
| 918 | + |
| 919 | + assertEquals(model, updatedModel); |
| 920 | + } |
| 921 | + |
| 922 | + public void testUpdateModelWithEmbeddingDetails_ElasticsearchInternalModelNotModified() { |
| 923 | + var service = createService(mock(Client.class)); |
| 924 | + var model = mock(ElasticsearchInternalModel.class); |
| 925 | + |
| 926 | + var updatedModel = service.updateModelWithEmbeddingDetails(model, randomNonNegativeInt()); |
| 927 | + |
| 928 | + assertEquals(model, updatedModel); |
| 929 | + verifyNoMoreInteractions(model); |
| 930 | + } |
| 931 | + |
862 | 932 | public void testChunkInfer_E5WithNullChunkingSettings() throws InterruptedException { |
863 | 933 | testChunkInfer_e5(null); |
864 | 934 | } |
|
0 commit comments