diff --git a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/allocator/DesiredBalanceReconciler.java b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/allocator/DesiredBalanceReconciler.java
index 38c2806778dff..38707be18f381 100644
--- a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/allocator/DesiredBalanceReconciler.java
+++ b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/allocator/DesiredBalanceReconciler.java
@@ -35,6 +35,7 @@
 
 import java.util.Comparator;
 import java.util.Iterator;
+import java.util.List;
 import java.util.Set;
 import java.util.function.BiFunction;
 import java.util.stream.Collectors;
@@ -633,7 +634,9 @@ private DiscoveryNode findRelocationTarget(
         Set<String> desiredNodeIds,
         BiFunction<ShardRouting, RoutingNode, Decision> canAllocateDecider
     ) {
-        for (final var nodeId : desiredNodeIds) {
+        // First sort by allocation ordering so we distribute relocated shards evenly
+        final List<String> allocationPreference = allocationOrdering.sort(desiredNodeIds);
+        for (final var nodeId : allocationPreference) {
             // TODO consider ignored nodes here too?
             if (nodeId.equals(shardRouting.currentNodeId())) {
                 continue;
@@ -645,6 +648,7 @@ private DiscoveryNode findRelocationTarget(
             final var decision = canAllocateDecider.apply(shardRouting, node);
             logger.trace("relocate {} to {}: {}", shardRouting, nodeId, decision);
             if (decision.type() == Decision.Type.YES) {
+                allocationOrdering.recordAllocation(nodeId);
                 return node.node();
             }
         }
diff --git a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/allocator/DesiredBalanceReconcilerTests.java b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/allocator/DesiredBalanceReconcilerTests.java
index 4eed552d5f1af..1b4d98e6c6b91 100644
--- a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/allocator/DesiredBalanceReconcilerTests.java
+++ b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/allocator/DesiredBalanceReconcilerTests.java
@@ -74,6 +74,7 @@
 import org.elasticsearch.threadpool.ThreadPool;
 import org.junit.BeforeClass;
 
+import java.util.Arrays;
 import java.util.Comparator;
 import java.util.HashMap;
 import java.util.HashSet;
@@ -88,6 +89,7 @@
 import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.BiPredicate;
 import java.util.function.Consumer;
+import java.util.function.Function;
 import java.util.function.Predicate;
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
@@ -107,6 +109,7 @@
 import static org.hamcrest.Matchers.notNullValue;
 import static org.hamcrest.Matchers.nullValue;
 import static org.hamcrest.Matchers.oneOf;
+import static org.junit.Assert.assertTrue;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
 
@@ -982,6 +985,81 @@ public Decision canAllocate(ShardRouting shardRouting, RoutingAllocation allocat
         assertThat(shuttingDownState.getRoutingNodes().node("node-2").numberOfShardsWithState(ShardRoutingState.INITIALIZING), equalTo(1));
     }
 
+    /**
+     * Simulate many nodes leaving a cluster and ensure that
+     * their shards are relocated evenly among the desired
+     * candidates
+     */
+    public void testShardsAreRelocatedEvenly() {
+        final var numNodes = randomIntBetween(6, 10);
+        final var numToRemain = randomIntBetween(2, numNodes - 1);
+        final var discoveryNodes = discoveryNodes(numNodes);
+        final var numberOfIndices = randomIntBetween(6, 20);
+
+        final Metadata.Builder metadataBuilder = Metadata.builder();
+        final RoutingTable.Builder routingTableBuilder = RoutingTable.builder(TestShardRoutingRoleStrategies.DEFAULT_ROLE_ONLY);
+        for (int i = 0; i < numberOfIndices; i++) {
+            final String indexName = "index-" + i;
+            final IndexMetadata indexMetadata = randomPriorityIndex(indexName, 1, 0);
+            metadataBuilder.put(indexMetadata, true);
+            routingTableBuilder.addAsNew(indexMetadata);
+        }
+
+        var clusterState = ClusterState.builder(ClusterName.DEFAULT)
+            .nodes(discoveryNodes)
+            .metadata(metadataBuilder.build())
+            .routingTable(routingTableBuilder.build())
+            .build();
+
+        Function<String, Integer> nodeOrdinal = (String nodeName) -> Integer.parseInt(nodeName.split("-")[1]);
+
+        // Initially put all shards on the nodes that are going to leave the cluster
+        AtomicReference<DesiredBalance> db = new AtomicReference<>(
+            desiredBalance(clusterState, (shardId, nodeId) -> nodeOrdinal.apply(nodeId) >= numToRemain)
+        );
+        final var allocationService = createTestAllocationService(routingAllocation -> reconcile(routingAllocation, db.get()));
+        clusterState = fullyReconcile(allocationService, clusterState);
+        logger.info("Initial state: {}", shardCounts(clusterState));
+
+        // Recalculate desired balance, marking only remaining nodes as desired
+        db.set(desiredBalance(clusterState, (shardId, nodeId) -> nodeOrdinal.apply(nodeId) < numToRemain));
+
+        // Reconcile it
+        clusterState = fullyReconcile(allocationService, clusterState);
+        logger.info("State after shutdowns: {}", shardCounts(clusterState));
+
+        Map<String, Integer> allocationCounts = shardCounts(clusterState);
+
+        // Only the remaining nodes should have allocations
+        assertTrue(allocationCounts.keySet().stream().allMatch(nodeId -> nodeOrdinal.apply(nodeId) < numToRemain));
+
+        // ... and the shards should be spread as evenly as possible over them
+        int[] remainingNodeShardCounts = IntStream.range(0, numToRemain)
+            .map(ordinal -> allocationCounts.getOrDefault("node-" + ordinal, 0))
+            .toArray();
+        int minimumAllocationCount = Arrays.stream(remainingNodeShardCounts).min().orElse(0);
+        int maximumAllocationCount = Arrays.stream(remainingNodeShardCounts).max().orElse(Integer.MAX_VALUE);
+        assertThat(maximumAllocationCount - minimumAllocationCount, lessThanOrEqualTo(1));
+    }
+
+    /** Counts the shards in the routing table, keyed by the id of the node each shard is currently on. */
+    private Map<String, Integer> shardCounts(ClusterState clusterState) {
+        Map<String, Integer> shardCounts = new HashMap<>();
+        clusterState.routingTable().allShards().forEach(sr -> shardCounts.merge(sr.currentNodeId(), 1, Integer::sum));
+        return shardCounts;
+    }
+
+    /** Repeatedly starts initializing shards and reroutes until the returned cluster state stops changing. */
+    private static ClusterState fullyReconcile(AllocationService allocationService, ClusterState clusterState) {
+        boolean changed;
+        do {
+            final var newState = startInitializingShardsAndReroute(allocationService, clusterState);
+            changed = newState != clusterState;
+            clusterState = newState;
+        } while (changed);
+        return clusterState;
+    }
+
     public void testRebalance() {
         final var discoveryNodes = discoveryNodes(4);
         final var metadata = Metadata.builder();