3535import org .elasticsearch .cluster .routing .ShardRoutingState ;
3636import org .elasticsearch .cluster .routing .allocation .AllocateUnassignedDecision ;
3737import org .elasticsearch .cluster .routing .allocation .RoutingAllocation ;
38+ import org .elasticsearch .cluster .routing .allocation .decider .AllocationDecider ;
3839import org .elasticsearch .cluster .routing .allocation .decider .AllocationDeciders ;
40+ import org .elasticsearch .cluster .routing .allocation .decider .Decision ;
3941import org .elasticsearch .cluster .routing .allocation .decider .ThrottlingAllocationDecider ;
4042import org .elasticsearch .common .UUIDs ;
4143import org .elasticsearch .common .settings .Settings ;
4951import org .elasticsearch .test .gateway .TestGatewayAllocator ;
5052
5153import java .util .Arrays ;
54+ import java .util .Collection ;
5255import java .util .Collections ;
5356import java .util .HashMap ;
5457import java .util .List ;
5558import java .util .Map ;
5659import java .util .Set ;
60+ import java .util .function .Function ;
5761import java .util .stream .Collector ;
5862import java .util .stream .Collectors ;
5963import java .util .stream .StreamSupport ;
@@ -611,6 +615,71 @@ public void testShardSizeDiscrepancyWithinIndex() {
611615 assertSame (clusterState , reroute (allocationService , clusterState ));
612616 }
613617
618+ public void testPartitionedClusterWithSeparateWeights () {
619+ var allocationService = new MockAllocationService (
620+ prefixAllocationDeciders (),
621+ new TestGatewayAllocator (),
622+ new BalancedShardsAllocator (
623+ BalancerSettings .DEFAULT ,
624+ TEST_WRITE_LOAD_FORECASTER ,
625+ new PrefixBalancingWeightsFactory (
626+ Map .of ("shardsOnly" , new WeightFunction (1 , 0 , 0 , 0 ), "weightsOnly" , new WeightFunction (0 , 0 , 1 , 0 ))
627+ )
628+ ),
629+ EmptyClusterInfoService .INSTANCE ,
630+ SNAPSHOT_INFO_SERVICE_WITH_NO_SHARD_SIZES
631+ );
632+
633+ var clusterState = applyStartedShardsUntilNoChange (
634+ createStateWithIndices (
635+ List .of ("shardsOnly-1" , "shardsOnly-2" , "weightsOnly-1" , "weightsOnly-2" ),
636+ shardId -> prefix (shardId .getIndexName ()) + "-1" ,
637+ anIndex ("weightsOnly-heavy-index" ).indexWriteLoadForecast (8.0 ),
638+ anIndex ("weightsOnly-light-index-1" ).indexWriteLoadForecast (1.0 ),
639+ anIndex ("weightsOnly-light-index-2" ).indexWriteLoadForecast (2.0 ),
640+ anIndex ("weightsOnly-light-index-3" ).indexWriteLoadForecast (3.0 ),
641+ anIndex ("weightsOnly-zero-write-load-index" ).indexWriteLoadForecast (0.0 ),
642+ anIndex ("weightsOnly-no-write-load-index" ),
643+ anIndex ("shardsOnly-heavy-index" ).indexWriteLoadForecast (8.0 ),
644+ anIndex ("shardsOnly-light-index-1" ).indexWriteLoadForecast (1.0 ),
645+ anIndex ("shardsOnly-light-index-2" ).indexWriteLoadForecast (2.0 ),
646+ anIndex ("shardsOnly-light-index-3" ).indexWriteLoadForecast (3.0 ),
647+ anIndex ("shardsOnly-zero-write-load-index" ).indexWriteLoadForecast (0.0 ),
648+ anIndex ("shardsOnly-no-write-load-index" )
649+ ),
650+ allocationService
651+ );
652+
653+ Map <String , Set <String >> shardsPerNode = getShardsPerNode (clusterState );
654+ Map <String , Set <String >> shardBalancedPartition = shardsPerNode .entrySet ()
655+ .stream ()
656+ .filter (e -> e .getKey ().startsWith ("shardsOnly" ))
657+ .collect (Collectors .toMap (Map .Entry ::getKey , Map .Entry ::getValue ));
658+ Map <String , Set <String >> weightBalancedPartition = shardsPerNode .entrySet ()
659+ .stream ()
660+ .filter (e -> e .getKey ().startsWith ("weightsOnly" ))
661+ .collect (Collectors .toMap (Map .Entry ::getKey , Map .Entry ::getValue ));
662+
663+ // The partition that balances on weights only is skewed
664+ assertThat (
665+ weightBalancedPartition .values (),
666+ containsInAnyOrder (
667+ Set .of ("weightsOnly-heavy-index" ),
668+ Set .of (
669+ "weightsOnly-light-index-1" ,
670+ "weightsOnly-light-index-2" ,
671+ "weightsOnly-light-index-3" ,
672+ "weightsOnly-zero-write-load-index" ,
673+ "weightsOnly-no-write-load-index"
674+ )
675+ )
676+ );
677+
678+ // The partition that balances on shard count only has an even distribution of shards
679+ assertThat (shardBalancedPartition .get ("shardsOnly-1" ), hasSize (3 ));
680+ assertThat (shardBalancedPartition .get ("shardsOnly-2" ), hasSize (3 ));
681+ }
682+
614683 private Map <String , Integer > getTargetShardPerNodeCount (IndexRoutingTable indexRoutingTable ) {
615684 var counts = new HashMap <String , Integer >();
616685 for (int shardId = 0 ; shardId < indexRoutingTable .size (); shardId ++) {
@@ -647,6 +716,14 @@ private static IndexMetadata.Builder anIndex(String name, Settings.Builder setti
647716 }
648717
649718 private static ClusterState createStateWithIndices (IndexMetadata .Builder ... indexMetadataBuilders ) {
719+ return createStateWithIndices (List .of ("node-1" , "node-2" ), shardId -> "node-1" , indexMetadataBuilders );
720+ }
721+
722+ private static ClusterState createStateWithIndices (
723+ List <String > nodeNames ,
724+ Function <ShardId , String > unbalancedAllocator ,
725+ IndexMetadata .Builder ... indexMetadataBuilders
726+ ) {
650727 var metadataBuilder = Metadata .builder ();
651728 var routingTableBuilder = RoutingTable .builder (TestShardRoutingRoleStrategies .DEFAULT_ROLE_ONLY );
652729 if (randomBoolean ()) {
@@ -663,19 +740,25 @@ private static ClusterState createStateWithIndices(IndexMetadata.Builder... inde
663740 var inSyncId = UUIDs .randomBase64UUID ();
664741 var indexMetadata = index .putInSyncAllocationIds (0 , Set .of (inSyncId )).build ();
665742 metadataBuilder .put (indexMetadata , false );
743+ ShardId shardId = new ShardId (indexMetadata .getIndex (), 0 );
666744 routingTableBuilder .add (
667745 IndexRoutingTable .builder (indexMetadata .getIndex ())
668746 .addShard (
669- shardRoutingBuilder (new ShardId ( indexMetadata . getIndex (), 0 ), "node-1" , true , ShardRoutingState .STARTED )
747+ shardRoutingBuilder (shardId , unbalancedAllocator . apply ( shardId ) , true , ShardRoutingState .STARTED )
670748 .withAllocationId (AllocationId .newInitializing (inSyncId ))
671749 .build ()
672750 )
673751 );
674752 }
675753 }
676754
755+ DiscoveryNodes .Builder discoveryNodesBuilder = DiscoveryNodes .builder ();
756+ for (String nodeName : nodeNames ) {
757+ discoveryNodesBuilder .add (newNode (nodeName ));
758+ }
759+
677760 return ClusterState .builder (ClusterName .DEFAULT )
678- .nodes (DiscoveryNodes . builder (). add ( newNode ( "node-1" )). add ( newNode ( "node-2" )) )
761+ .nodes (discoveryNodesBuilder )
679762 .metadata (metadataBuilder )
680763 .routingTable (routingTableBuilder )
681764 .build ();
@@ -712,4 +795,100 @@ private void addIndex(
712795 }
713796 routingTableBuilder .add (indexRoutingTableBuilder );
714797 }
798+
799+ /**
800+ * A {@link BalancingWeightsFactory} that assumes the cluster is partitioned by the prefix
801+ * of the node and shard names before the `-`.
802+ */
803+ class PrefixBalancingWeightsFactory implements BalancingWeightsFactory {
804+
805+ private final Map <String , WeightFunction > prefixWeights ;
806+
807+ PrefixBalancingWeightsFactory (Map <String , WeightFunction > prefixWeights ) {
808+ this .prefixWeights = prefixWeights ;
809+ }
810+
811+ @ Override
812+ public BalancingWeights create () {
813+ return new PrefixBalancingWeights ();
814+ }
815+
816+ class PrefixBalancingWeights implements BalancingWeights {
817+
818+ @ Override
819+ public WeightFunction weightFunctionForShard (ShardRouting shard ) {
820+ return prefixWeights .get (prefix (shard .getIndexName ()));
821+ }
822+
823+ @ Override
824+ public WeightFunction weightFunctionForNode (RoutingNode node ) {
825+ return prefixWeights .get (prefix (node .node ().getId ()));
826+ }
827+
828+ @ Override
829+ public NodeSorters createNodeSorters (
830+ BalancedShardsAllocator .ModelNode [] modelNodes ,
831+ BalancedShardsAllocator .Balancer balancer
832+ ) {
833+ final HashMap <String , BalancedShardsAllocator .NodeSorter > prefixNodeSorters = new HashMap <>();
834+ for (var entry : prefixWeights .entrySet ()) {
835+ prefixNodeSorters .put (
836+ entry .getKey (),
837+ new BalancedShardsAllocator .NodeSorter (
838+ Arrays .stream (modelNodes )
839+ .filter (node -> prefix (node .getRoutingNode ().node ().getId ()).equals (entry .getKey ()))
840+ .toArray (BalancedShardsAllocator .ModelNode []::new ),
841+ entry .getValue (),
842+ balancer
843+ )
844+ );
845+ }
846+ return new NodeSorters () {
847+ @ Override
848+ public Collection <BalancedShardsAllocator .NodeSorter > allNodeSorters () {
849+ return prefixNodeSorters .values ();
850+ }
851+
852+ @ Override
853+ public BalancedShardsAllocator .NodeSorter sorterForShard (ShardRouting shard ) {
854+ return prefixNodeSorters .get (prefix (shard .getIndexName ()));
855+ }
856+ };
857+ }
858+ }
859+ }
860+
861+ /**
862+ * Allocation deciders that only allow shards to be allocated to nodes whose names share the same prefix
863+ * as the index they're from
864+ */
865+ private AllocationDeciders prefixAllocationDeciders () {
866+ return new AllocationDeciders (List .of (new AllocationDecider () {
867+ @ Override
868+ public Decision canAllocate (ShardRouting shardRouting , RoutingNode node , RoutingAllocation allocation ) {
869+ return nodePrefixMatchesIndexPrefix (shardRouting , node );
870+ }
871+
872+ @ Override
873+ public Decision canRemain (
874+ IndexMetadata indexMetadata ,
875+ ShardRouting shardRouting ,
876+ RoutingNode node ,
877+ RoutingAllocation allocation
878+ ) {
879+ return nodePrefixMatchesIndexPrefix (shardRouting , node );
880+ }
881+
882+ private Decision nodePrefixMatchesIndexPrefix (ShardRouting shardRouting , RoutingNode node ) {
883+ var indexPrefix = prefix (shardRouting .index ().getName ());
884+ var nodePrefix = prefix (node .node ().getId ());
885+ return nodePrefix .equals (indexPrefix ) ? Decision .YES : Decision .NO ;
886+ }
887+ }));
888+ }
889+
890+ private static String prefix (String value ) {
891+ assert value != null && value .contains ("-" ) : "Invalid name passed: " + value ;
892+ return value .substring (0 , value .indexOf ("-" ));
893+ }
715894}
0 commit comments