|
69 | 69 | import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TokenizationConfigUpdate; |
70 | 70 | import org.elasticsearch.xpack.inference.InferencePlugin; |
71 | 71 | import org.elasticsearch.xpack.inference.InputTypeTests; |
| 72 | +import org.elasticsearch.xpack.inference.ModelConfigurationsTests; |
72 | 73 | import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests; |
73 | 74 | import org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkingSettings; |
74 | 75 | import org.elasticsearch.xpack.inference.services.ServiceFields; |
@@ -886,6 +887,75 @@ public void testParsePersistedConfig() { |
886 | 887 | } |
887 | 888 | } |
888 | 889 |
|
| 890 | + public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() { |
| 891 | + var service = createService(mock(Client.class)); |
| 892 | + var model = new Model(ModelConfigurationsTests.createRandomInstance()); |
| 893 | + |
| 894 | + assertThrows(ElasticsearchStatusException.class, () -> { service.updateModelWithEmbeddingDetails(model, randomNonNegativeInt()); }); |
| 895 | + } |
| 896 | + |
| 897 | + public void testUpdateModelWithEmbeddingDetails_TextEmbeddingCustomElandEmbeddingsModelUpdatesDimensions() { |
| 898 | + var service = createService(mock(Client.class)); |
| 899 | + var elandServiceSettings = new CustomElandInternalTextEmbeddingServiceSettings( |
| 900 | + 1, |
| 901 | + 4, |
| 902 | + "invalid", |
| 903 | + null, |
| 904 | + null, |
| 905 | + null, |
| 906 | + SimilarityMeasure.COSINE, |
| 907 | + DenseVectorFieldMapper.ElementType.FLOAT |
| 908 | + ); |
| 909 | + var model = new CustomElandEmbeddingModel( |
| 910 | + randomAlphaOfLength(10), |
| 911 | + TaskType.TEXT_EMBEDDING, |
| 912 | + "elasticsearch", |
| 913 | + elandServiceSettings, |
| 914 | + null |
| 915 | + ); |
| 916 | + |
| 917 | + var embeddingSize = randomNonNegativeInt(); |
| 918 | + var updatedModel = service.updateModelWithEmbeddingDetails(model, embeddingSize); |
| 919 | + |
| 920 | + assertEquals(embeddingSize, updatedModel.getServiceSettings().dimensions().intValue()); |
| 921 | + } |
| 922 | + |
| 923 | + public void testUpdateModelWithEmbeddingDetails_NonTextEmbeddingCustomElandEmbeddingsModelNotModified() { |
| 924 | + var service = createService(mock(Client.class)); |
| 925 | + var elandServiceSettings = new CustomElandInternalTextEmbeddingServiceSettings( |
| 926 | + 1, |
| 927 | + 4, |
| 928 | + "invalid", |
| 929 | + null, |
| 930 | + null, |
| 931 | + null, |
| 932 | + SimilarityMeasure.COSINE, |
| 933 | + DenseVectorFieldMapper.ElementType.FLOAT |
| 934 | + ); |
| 935 | + var model = new CustomElandEmbeddingModel( |
| 936 | + randomAlphaOfLength(10), |
| 937 | + TaskType.SPARSE_EMBEDDING, |
| 938 | + "elasticsearch", |
| 939 | + elandServiceSettings, |
| 940 | + null |
| 941 | + ); |
| 942 | + |
| 943 | + var embeddingSize = randomNonNegativeInt(); |
| 944 | + var updatedModel = service.updateModelWithEmbeddingDetails(model, embeddingSize); |
| 945 | + |
| 946 | + assertEquals(model, updatedModel); |
| 947 | + } |
| 948 | + |
| 949 | + public void testUpdateModelWithEmbeddingDetails_ElasticsearchInternalModelNotModified() { |
| 950 | + var service = createService(mock(Client.class)); |
| 951 | + var model = mock(ElasticsearchInternalModel.class); |
| 952 | + |
| 953 | + var updatedModel = service.updateModelWithEmbeddingDetails(model, randomNonNegativeInt()); |
| 954 | + |
| 955 | + assertEquals(model, updatedModel); |
| 956 | + verifyNoMoreInteractions(model); |
| 957 | + } |
| 958 | + |
889 | 959 | public void testChunkInfer_E5WithNullChunkingSettings() throws InterruptedException { |
890 | 960 | testChunkInfer_e5(null); |
891 | 961 | } |
|
0 commit comments