|
67 | 67 | import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests; |
68 | 68 | import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; |
69 | 69 | import org.elasticsearch.xpack.inference.services.ServiceFields; |
| 70 | +import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModelTests; |
70 | 71 | import org.junit.After; |
71 | 72 | import org.junit.Before; |
72 | 73 | import org.mockito.ArgumentCaptor; |
73 | 74 | import org.mockito.Mockito; |
74 | 75 |
|
| 76 | +import java.io.IOException; |
75 | 77 | import java.util.ArrayList; |
76 | 78 | import java.util.EnumSet; |
77 | 79 | import java.util.HashMap; |
@@ -1509,6 +1511,84 @@ public void testParseRequestConfigEland_SetsDimensionsToOne() { |
1509 | 1511 | assertThat(model, is(expectedModel)); |
1510 | 1512 | } |
1511 | 1513 |
|
| 1514 | + public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException { |
| 1515 | + var client = mock(Client.class); |
| 1516 | + try (var service = createService(client)) { |
| 1517 | + var model = OpenAiChatCompletionModelTests.createChatCompletionModel( |
| 1518 | + randomAlphaOfLength(10), |
| 1519 | + randomAlphaOfLength(10), |
| 1520 | + randomAlphaOfLength(10), |
| 1521 | + randomAlphaOfLength(10), |
| 1522 | + randomAlphaOfLength(10) |
| 1523 | + ); |
| 1524 | + assertThrows( |
| 1525 | + ElasticsearchStatusException.class, |
| 1526 | + () -> { service.updateModelWithEmbeddingDetails(model, randomNonNegativeInt()); } |
| 1527 | + ); |
| 1528 | + } |
| 1529 | + } |
| 1530 | + |
| 1531 | + public void testUpdateModelWithEmbeddingDetails_NonElandModelProvided() throws IOException { |
| 1532 | + var client = mock(Client.class); |
| 1533 | + try (var service = createService(client)) { |
| 1534 | + var originalModel = new MultilingualE5SmallModel( |
| 1535 | + randomAlphaOfLength(10), |
| 1536 | + TaskType.TEXT_EMBEDDING, |
| 1537 | + randomAlphaOfLength(10), |
| 1538 | + new MultilingualE5SmallInternalServiceSettings( |
| 1539 | + randomNonNegativeInt(), |
| 1540 | + randomNonNegativeInt(), |
| 1541 | + randomAlphaOfLength(10), |
| 1542 | + null |
| 1543 | + ), |
| 1544 | + null |
| 1545 | + ); |
| 1546 | + |
| 1547 | + var updatedModel = service.updateModelWithEmbeddingDetails(originalModel, randomNonNegativeInt()); |
| 1548 | + assertEquals(originalModel, updatedModel); |
| 1549 | + } |
| 1550 | + } |
| 1551 | + |
| 1552 | + public void testUpdateModelWithEmbeddingDetails_ElandModelProvided() throws IOException { |
| 1553 | + var client = mock(Client.class); |
| 1554 | + try (var service = createService(client)) { |
| 1555 | + var originalServiceSettings = new CustomElandInternalTextEmbeddingServiceSettings( |
| 1556 | + randomNonNegativeInt(), |
| 1557 | + randomNonNegativeInt(), |
| 1558 | + randomAlphaOfLength(10), |
| 1559 | + null |
| 1560 | + ); |
| 1561 | + var originalModel = new CustomElandEmbeddingModel( |
| 1562 | + randomAlphaOfLength(10), |
| 1563 | + TaskType.TEXT_EMBEDDING, |
| 1564 | + randomAlphaOfLength(10), |
| 1565 | + originalServiceSettings, |
| 1566 | + ChunkingSettingsTests.createRandomChunkingSettings() |
| 1567 | + ); |
| 1568 | + |
| 1569 | + var embeddingSize = randomNonNegativeInt(); |
| 1570 | + var expectedUpdatedServiceSettings = new CustomElandInternalTextEmbeddingServiceSettings( |
| 1571 | + originalServiceSettings.getNumAllocations(), |
| 1572 | + originalServiceSettings.getNumThreads(), |
| 1573 | + originalServiceSettings.modelId(), |
| 1574 | + originalServiceSettings.getAdaptiveAllocationsSettings(), |
| 1575 | + embeddingSize, |
| 1576 | + originalServiceSettings.similarity(), |
| 1577 | + originalServiceSettings.elementType() |
| 1578 | + ); |
| 1579 | + var expectedUpdatedModel = new CustomElandEmbeddingModel( |
| 1580 | + originalModel.getInferenceEntityId(), |
| 1581 | + originalModel.getTaskType(), |
| 1582 | + originalModel.getConfigurations().getService(), |
| 1583 | + expectedUpdatedServiceSettings, |
| 1584 | + originalModel.getConfigurations().getChunkingSettings() |
| 1585 | + ); |
| 1586 | + |
| 1587 | + var actualUpdatedModel = service.updateModelWithEmbeddingDetails(originalModel, embeddingSize); |
| 1588 | + assertEquals(expectedUpdatedModel, actualUpdatedModel); |
| 1589 | + } |
| 1590 | + } |
| 1591 | + |
1512 | 1592 | public void testModelVariantDoesNotMatchArchitecturesAndIsNotPlatformAgnostic() { |
1513 | 1593 | { |
1514 | 1594 | assertFalse( |
|
0 commit comments