Skip to content

Commit f80f054

Browse files
Adding updateModelWithEmbeddingDetails tests
1 parent 9cbc5d3 commit f80f054

File tree

2 files changed: 107 additions and 22 deletions.

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java

Lines changed: 27 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
6060
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
6161
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
62+
import org.elasticsearch.xpack.inference.services.ServiceUtils;
6263
import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder;
6364

6465
import java.util.ArrayList;
@@ -504,30 +505,34 @@ public void checkModelConfig(Model model, ActionListener<Model> listener) {
504505

505506
@Override
506507
public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
507-
if (model instanceof CustomElandEmbeddingModel embeddingsModel) {
508-
var serviceSettings = embeddingsModel.getServiceSettings();
509-
510-
var updatedServiceSettings = new CustomElandInternalTextEmbeddingServiceSettings(
511-
serviceSettings.getNumAllocations(),
512-
serviceSettings.getNumThreads(),
513-
serviceSettings.modelId(),
514-
serviceSettings.getAdaptiveAllocationsSettings(),
515-
embeddingSize,
516-
serviceSettings.similarity(),
517-
serviceSettings.elementType()
518-
);
508+
if (model instanceof ElasticsearchInternalModel) {
509+
if (model instanceof CustomElandEmbeddingModel embeddingsModel) {
510+
var serviceSettings = embeddingsModel.getServiceSettings();
511+
512+
var updatedServiceSettings = new CustomElandInternalTextEmbeddingServiceSettings(
513+
serviceSettings.getNumAllocations(),
514+
serviceSettings.getNumThreads(),
515+
serviceSettings.modelId(),
516+
serviceSettings.getAdaptiveAllocationsSettings(),
517+
embeddingSize,
518+
serviceSettings.similarity(),
519+
serviceSettings.elementType()
520+
);
519521

520-
return new CustomElandEmbeddingModel(
521-
model.getInferenceEntityId(),
522-
model.getTaskType(),
523-
model.getConfigurations().getService(),
524-
updatedServiceSettings,
525-
model.getConfigurations().getChunkingSettings()
526-
);
522+
return new CustomElandEmbeddingModel(
523+
model.getInferenceEntityId(),
524+
model.getTaskType(),
525+
model.getConfigurations().getService(),
526+
updatedServiceSettings,
527+
model.getConfigurations().getChunkingSettings()
528+
);
529+
} else {
530+
// TODO: This is for the E5 case which is text embedding but we didn't previously update the dimensions. Figure out if we do
531+
// need to update the dimensions?
532+
return model;
533+
}
527534
} else {
528-
// TODO: This is for the E5 case which is text embedding but we didn't previously update the dimensions. Figure out if we do
529-
// need to update the dimensions?
530-
return model;
535+
throw ServiceUtils.invalidModelTypeForUpdateModelWithEmbeddingDetails(model.getClass());
531536
}
532537
}
533538

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,13 @@
6767
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests;
6868
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
6969
import org.elasticsearch.xpack.inference.services.ServiceFields;
70+
import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModelTests;
7071
import org.junit.After;
7172
import org.junit.Before;
7273
import org.mockito.ArgumentCaptor;
7374
import org.mockito.Mockito;
7475

76+
import java.io.IOException;
7577
import java.util.ArrayList;
7678
import java.util.EnumSet;
7779
import java.util.HashMap;
@@ -1509,6 +1511,84 @@ public void testParseRequestConfigEland_SetsDimensionsToOne() {
15091511
assertThat(model, is(expectedModel));
15101512
}
15111513

1514+
public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException {
1515+
var client = mock(Client.class);
1516+
try (var service = createService(client)) {
1517+
var model = OpenAiChatCompletionModelTests.createChatCompletionModel(
1518+
randomAlphaOfLength(10),
1519+
randomAlphaOfLength(10),
1520+
randomAlphaOfLength(10),
1521+
randomAlphaOfLength(10),
1522+
randomAlphaOfLength(10)
1523+
);
1524+
assertThrows(
1525+
ElasticsearchStatusException.class,
1526+
() -> { service.updateModelWithEmbeddingDetails(model, randomNonNegativeInt()); }
1527+
);
1528+
}
1529+
}
1530+
1531+
public void testUpdateModelWithEmbeddingDetails_NonElandModelProvided() throws IOException {
1532+
var client = mock(Client.class);
1533+
try (var service = createService(client)) {
1534+
var originalModel = new MultilingualE5SmallModel(
1535+
randomAlphaOfLength(10),
1536+
TaskType.TEXT_EMBEDDING,
1537+
randomAlphaOfLength(10),
1538+
new MultilingualE5SmallInternalServiceSettings(
1539+
randomNonNegativeInt(),
1540+
randomNonNegativeInt(),
1541+
randomAlphaOfLength(10),
1542+
null
1543+
),
1544+
null
1545+
);
1546+
1547+
var updatedModel = service.updateModelWithEmbeddingDetails(originalModel, randomNonNegativeInt());
1548+
assertEquals(originalModel, updatedModel);
1549+
}
1550+
}
1551+
1552+
public void testUpdateModelWithEmbeddingDetails_ElandModelProvided() throws IOException {
1553+
var client = mock(Client.class);
1554+
try (var service = createService(client)) {
1555+
var originalServiceSettings = new CustomElandInternalTextEmbeddingServiceSettings(
1556+
randomNonNegativeInt(),
1557+
randomNonNegativeInt(),
1558+
randomAlphaOfLength(10),
1559+
null
1560+
);
1561+
var originalModel = new CustomElandEmbeddingModel(
1562+
randomAlphaOfLength(10),
1563+
TaskType.TEXT_EMBEDDING,
1564+
randomAlphaOfLength(10),
1565+
originalServiceSettings,
1566+
ChunkingSettingsTests.createRandomChunkingSettings()
1567+
);
1568+
1569+
var embeddingSize = randomNonNegativeInt();
1570+
var expectedUpdatedServiceSettings = new CustomElandInternalTextEmbeddingServiceSettings(
1571+
originalServiceSettings.getNumAllocations(),
1572+
originalServiceSettings.getNumThreads(),
1573+
originalServiceSettings.modelId(),
1574+
originalServiceSettings.getAdaptiveAllocationsSettings(),
1575+
embeddingSize,
1576+
originalServiceSettings.similarity(),
1577+
originalServiceSettings.elementType()
1578+
);
1579+
var expectedUpdatedModel = new CustomElandEmbeddingModel(
1580+
originalModel.getInferenceEntityId(),
1581+
originalModel.getTaskType(),
1582+
originalModel.getConfigurations().getService(),
1583+
expectedUpdatedServiceSettings,
1584+
originalModel.getConfigurations().getChunkingSettings()
1585+
);
1586+
1587+
var actualUpdatedModel = service.updateModelWithEmbeddingDetails(originalModel, embeddingSize);
1588+
assertEquals(expectedUpdatedModel, actualUpdatedModel);
1589+
}
1590+
}
1591+
15121592
public void testModelVariantDoesNotMatchArchitecturesAndIsNotPlatformAgnostic() {
15131593
{
15141594
assertFalse(

0 commit comments

Comments
 (0)