|
79 | 79 | import java.io.IOException;
|
80 | 80 | import java.util.ArrayList;
|
81 | 81 | import java.util.Arrays;
|
| 82 | +import java.util.Collections; |
82 | 83 | import java.util.EnumSet;
|
83 | 84 | import java.util.HashMap;
|
84 | 85 | import java.util.List;
|
|
109 | 110 | import static org.mockito.Mockito.doAnswer;
|
110 | 111 | import static org.mockito.Mockito.mock;
|
111 | 112 | import static org.mockito.Mockito.verify;
|
| 113 | +import static org.mockito.Mockito.verifyNoMoreInteractions; |
112 | 114 | import static org.mockito.Mockito.when;
|
113 | 115 |
|
114 | 116 | public class ElasticsearchInternalServiceTests extends ESTestCase {
|
@@ -1640,6 +1642,148 @@ public void testGetConfiguration() throws Exception {
|
1640 | 1642 | }
|
1641 | 1643 | }
|
1642 | 1644 |
|
| 1645 | + public void testUpdateModelsWithDynamicFields_NoModelsToUpdate() throws Exception { |
| 1646 | + ActionListener<List<Model>> resultsListener = ActionListener.<List<Model>>wrap( |
| 1647 | + updatedModels -> assertEquals(Collections.emptyList(), updatedModels), |
| 1648 | + e -> fail("Unexpected exception: " + e) |
| 1649 | + ); |
| 1650 | + |
| 1651 | + try (var service = createService(mock(Client.class))) { |
| 1652 | + service.updateModelsWithDynamicFields(List.of(), resultsListener); |
| 1653 | + } |
| 1654 | + } |
| 1655 | + |
| 1656 | + public void testUpdateModelsWithDynamicFields_InvalidModelProvided() throws IOException { |
| 1657 | + ActionListener<List<Model>> resultsListener = ActionListener.wrap( |
| 1658 | + updatedModels -> fail("Expected invalid model assertion error to be thrown"), |
| 1659 | + e -> fail("Expected invalid model assertion error to be thrown") |
| 1660 | + ); |
| 1661 | + |
| 1662 | + try (var service = createService(mock(Client.class))) { |
| 1663 | + assertThrows( |
| 1664 | + AssertionError.class, |
| 1665 | + () -> { service.updateModelsWithDynamicFields(List.of(mock(Model.class)), resultsListener); } |
| 1666 | + ); |
| 1667 | + } |
| 1668 | + } |
| 1669 | + |
| 1670 | + @SuppressWarnings("unchecked") |
| 1671 | + public void testUpdateModelsWithDynamicFields_FailsToRetrieveDeployments() throws IOException { |
| 1672 | + var deploymentId = randomAlphaOfLength(10); |
| 1673 | + var model = mock(ElasticsearchInternalModel.class); |
| 1674 | + when(model.mlNodeDeploymentId()).thenReturn(deploymentId); |
| 1675 | + when(model.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING); |
| 1676 | + |
| 1677 | + ActionListener<List<Model>> resultsListener = ActionListener.wrap(updatedModels -> { |
| 1678 | + assertEquals(updatedModels.size(), 1); |
| 1679 | + verify(model).mlNodeDeploymentId(); |
| 1680 | + verifyNoMoreInteractions(model); |
| 1681 | + }, e -> fail("Expected original models to be returned")); |
| 1682 | + |
| 1683 | + var client = mock(Client.class); |
| 1684 | + when(client.threadPool()).thenReturn(threadPool); |
| 1685 | + doAnswer(invocation -> { |
| 1686 | + var listener = (ActionListener<GetDeploymentStatsAction.Response>) invocation.getArguments()[2]; |
| 1687 | + listener.onFailure(new RuntimeException(randomAlphaOfLength(10))); |
| 1688 | + return null; |
| 1689 | + }).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), any(), any()); |
| 1690 | + |
| 1691 | + try (var service = createService(client)) { |
| 1692 | + service.updateModelsWithDynamicFields(List.of(model), resultsListener); |
| 1693 | + } |
| 1694 | + } |
| 1695 | + |
| 1696 | + public void testUpdateModelsWithDynamicFields_SingleModelToUpdate() throws IOException { |
| 1697 | + var deploymentId = randomAlphaOfLength(10); |
| 1698 | + var model = mock(ElasticsearchInternalModel.class); |
| 1699 | + when(model.mlNodeDeploymentId()).thenReturn(deploymentId); |
| 1700 | + when(model.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING); |
| 1701 | + |
| 1702 | + var modelsByDeploymentId = new HashMap<String, List<Model>>(); |
| 1703 | + modelsByDeploymentId.put(deploymentId, List.of(model)); |
| 1704 | + |
| 1705 | + testUpdateModelsWithDynamicFields(modelsByDeploymentId); |
| 1706 | + } |
| 1707 | + |
| 1708 | + public void testUpdateModelsWithDynamicFields_MultipleModelsWithDifferentDeploymentsToUpdate() throws IOException { |
| 1709 | + var deploymentId1 = randomAlphaOfLength(10); |
| 1710 | + var model1 = mock(ElasticsearchInternalModel.class); |
| 1711 | + when(model1.mlNodeDeploymentId()).thenReturn(deploymentId1); |
| 1712 | + when(model1.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING); |
| 1713 | + var deploymentId2 = randomAlphaOfLength(10); |
| 1714 | + var model2 = mock(ElasticsearchInternalModel.class); |
| 1715 | + when(model2.mlNodeDeploymentId()).thenReturn(deploymentId2); |
| 1716 | + when(model2.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING); |
| 1717 | + |
| 1718 | + var modelsByDeploymentId = new HashMap<String, List<Model>>(); |
| 1719 | + modelsByDeploymentId.put(deploymentId1, List.of(model1)); |
| 1720 | + modelsByDeploymentId.put(deploymentId2, List.of(model2)); |
| 1721 | + |
| 1722 | + testUpdateModelsWithDynamicFields(modelsByDeploymentId); |
| 1723 | + } |
| 1724 | + |
| 1725 | + public void testUpdateModelsWithDynamicFields_MultipleModelsWithSameDeploymentsToUpdate() throws IOException { |
| 1726 | + var deploymentId = randomAlphaOfLength(10); |
| 1727 | + var model1 = mock(ElasticsearchInternalModel.class); |
| 1728 | + when(model1.mlNodeDeploymentId()).thenReturn(deploymentId); |
| 1729 | + when(model1.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING); |
| 1730 | + var model2 = mock(ElasticsearchInternalModel.class); |
| 1731 | + when(model2.mlNodeDeploymentId()).thenReturn(deploymentId); |
| 1732 | + when(model2.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING); |
| 1733 | + |
| 1734 | + var modelsByDeploymentId = new HashMap<String, List<Model>>(); |
| 1735 | + modelsByDeploymentId.put(deploymentId, List.of(model1, model2)); |
| 1736 | + |
| 1737 | + testUpdateModelsWithDynamicFields(modelsByDeploymentId); |
| 1738 | + } |
| 1739 | + |
| 1740 | + @SuppressWarnings("unchecked") |
| 1741 | + private void testUpdateModelsWithDynamicFields(Map<String, List<Model>> modelsByDeploymentId) throws IOException { |
| 1742 | + var modelsToUpdate = new ArrayList<Model>(); |
| 1743 | + modelsByDeploymentId.values().forEach(modelsToUpdate::addAll); |
| 1744 | + |
| 1745 | + var updatedNumberOfAllocations = new HashMap<String, Integer>(); |
| 1746 | + modelsByDeploymentId.keySet().forEach(deploymentId -> updatedNumberOfAllocations.put(deploymentId, randomIntBetween(1, 10))); |
| 1747 | + |
| 1748 | + ActionListener<List<Model>> resultsListener = ActionListener.wrap(updatedModels -> { |
| 1749 | + assertEquals(updatedModels.size(), modelsToUpdate.size()); |
| 1750 | + modelsByDeploymentId.forEach((deploymentId, models) -> { |
| 1751 | + var expectedNumberOfAllocations = updatedNumberOfAllocations.get(deploymentId); |
| 1752 | + models.forEach(model -> { |
| 1753 | + verify((ElasticsearchInternalModel) model).updateNumAllocations(expectedNumberOfAllocations); |
| 1754 | + verify((ElasticsearchInternalModel) model).mlNodeDeploymentId(); |
| 1755 | + verifyNoMoreInteractions(model); |
| 1756 | + }); |
| 1757 | + }); |
| 1758 | + }, e -> fail("Unexpected exception: " + e)); |
| 1759 | + |
| 1760 | + var client = mock(Client.class); |
| 1761 | + when(client.threadPool()).thenReturn(threadPool); |
| 1762 | + doAnswer(invocation -> { |
| 1763 | + var listener = (ActionListener<GetDeploymentStatsAction.Response>) invocation.getArguments()[2]; |
| 1764 | + var mockAssignmentStats = new ArrayList<AssignmentStats>(); |
| 1765 | + modelsByDeploymentId.keySet().forEach(deploymentId -> { |
| 1766 | + var mockAssignmentStatsForDeploymentId = mock(AssignmentStats.class); |
| 1767 | + when(mockAssignmentStatsForDeploymentId.getDeploymentId()).thenReturn(deploymentId); |
| 1768 | + when(mockAssignmentStatsForDeploymentId.getNumberOfAllocations()).thenReturn(updatedNumberOfAllocations.get(deploymentId)); |
| 1769 | + mockAssignmentStats.add(mockAssignmentStatsForDeploymentId); |
| 1770 | + }); |
| 1771 | + listener.onResponse( |
| 1772 | + new GetDeploymentStatsAction.Response( |
| 1773 | + Collections.emptyList(), |
| 1774 | + Collections.emptyList(), |
| 1775 | + mockAssignmentStats, |
| 1776 | + mockAssignmentStats.size() |
| 1777 | + ) |
| 1778 | + ); |
| 1779 | + return null; |
| 1780 | + }).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), any(), any()); |
| 1781 | + |
| 1782 | + try (var service = createService(client)) { |
| 1783 | + service.updateModelsWithDynamicFields(modelsToUpdate, resultsListener); |
| 1784 | + } |
| 1785 | + } |
| 1786 | + |
1643 | 1787 | public void testUpdateWithoutMlEnabled() throws IOException, InterruptedException {
|
1644 | 1788 | var cs = mock(ClusterService.class);
|
1645 | 1789 | var cSettings = new ClusterSettings(Settings.EMPTY, Set.of(MachineLearningField.MAX_LAZY_ML_NODES));
|
|
0 commit comments