3838import  java .util .Set ;
3939import  java .util .TreeMap ;
4040import  java .util .TreeSet ;
41+ import  java .util .function .LongSupplier ;
4142import  java .util .function .Predicate ;
4243
4344import  static  java .util .stream .Collectors .toUnmodifiableSet ;
@@ -49,8 +50,8 @@ public class DesiredBalanceComputer {
4950
5051    private  static  final  Logger  logger  = LogManager .getLogger (DesiredBalanceComputer .class );
5152
52-     private  final  ThreadPool  threadPool ;
5353    private  final  ShardsAllocator  delegateAllocator ;
54+     private  final  LongSupplier  timeSupplierMillis ;
5455
5556    // stats 
5657    protected  final  MeanMetric  iterations  = new  MeanMetric ();
@@ -63,12 +64,28 @@ public class DesiredBalanceComputer {
6364        Setting .Property .NodeScope 
6465    );
6566
67+     public  static  final  Setting <TimeValue > MAX_BALANCE_COMPUTATION_TIME_DURING_INDEX_CREATION_SETTING  = Setting .timeSetting (
68+         "cluster.routing.allocation.desired_balance.max_balance_computation_time_during_index_creation" ,
69+         TimeValue .timeValueSeconds (1 ),
70+         Setting .Property .Dynamic ,
71+         Setting .Property .NodeScope 
72+     );
73+ 
6674    private  TimeValue  progressLogInterval ;
75+     private  long  maxBalanceComputationTimeDuringIndexCreationMillis ;
6776
6877    public  DesiredBalanceComputer (ClusterSettings  clusterSettings , ThreadPool  threadPool , ShardsAllocator  delegateAllocator ) {
69-         this .threadPool  = threadPool ;
78+         this (clusterSettings , delegateAllocator , threadPool ::relativeTimeInMillis );
79+     }
80+ 
81+     DesiredBalanceComputer (ClusterSettings  clusterSettings , ShardsAllocator  delegateAllocator , LongSupplier  timeSupplierMillis ) {
7082        this .delegateAllocator  = delegateAllocator ;
83+         this .timeSupplierMillis  = timeSupplierMillis ;
7184        clusterSettings .initializeAndWatch (PROGRESS_LOG_INTERVAL_SETTING , value  -> this .progressLogInterval  = value );
85+         clusterSettings .initializeAndWatch (
86+             MAX_BALANCE_COMPUTATION_TIME_DURING_INDEX_CREATION_SETTING ,
87+             value  -> this .maxBalanceComputationTimeDuringIndexCreationMillis  = value .millis ()
88+         );
7289    }
7390
7491    public  DesiredBalance  compute (
@@ -77,7 +94,6 @@ public DesiredBalance compute(
7794        Queue <List <MoveAllocationCommand >> pendingDesiredBalanceMoves ,
7895        Predicate <DesiredBalanceInput > isFresh 
7996    ) {
80- 
8197        if  (logger .isTraceEnabled ()) {
8298            logger .trace (
8399                "Recomputing desired balance for [{}]: {}, {}, {}, {}" ,
@@ -97,9 +113,10 @@ public DesiredBalance compute(
97113        final  var  changes  = routingAllocation .changes ();
98114        final  var  ignoredShards  = getIgnoredShardsWithDiscardedAllocationStatus (desiredBalanceInput .ignoredShards ());
99115        final  var  clusterInfoSimulator  = new  ClusterInfoSimulator (routingAllocation );
116+         DesiredBalance .ComputationFinishReason  finishReason  = DesiredBalance .ComputationFinishReason .CONVERGED ;
100117
101118        if  (routingNodes .size () == 0 ) {
102-             return  new  DesiredBalance (desiredBalanceInput .index (), Map .of ());
119+             return  new  DesiredBalance (desiredBalanceInput .index (), Map .of (),  Map . of (),  finishReason );
103120        }
104121
105122        // we assume that all ongoing recoveries will complete 
@@ -263,11 +280,12 @@ public DesiredBalance compute(
263280
264281        final  int  iterationCountReportInterval  = computeIterationCountReportInterval (routingAllocation );
265282        final  long  timeWarningInterval  = progressLogInterval .millis ();
266-         final  long  computationStartedTime  = threadPool . relativeTimeInMillis ();
283+         final  long  computationStartedTime  = timeSupplierMillis . getAsLong ();
267284        long  nextReportTime  = computationStartedTime  + timeWarningInterval ;
268285
269286        int  i  = 0 ;
270287        boolean  hasChanges  = false ;
288+         boolean  assignedNewlyCreatedPrimaryShards  = false ;
271289        while  (true ) {
272290            if  (hasChanges ) {
273291                // Not the first iteration, so every remaining unassigned shard has been ignored, perhaps due to throttling. We must bring 
@@ -293,6 +311,15 @@ public DesiredBalance compute(
293311                for  (final  var  shardRouting  : routingNode ) {
294312                    if  (shardRouting .initializing ()) {
295313                        hasChanges  = true ;
314+                         if  (shardRouting .primary ()
315+                             && shardRouting .unassignedInfo () != null 
316+                             && shardRouting .unassignedInfo ().reason () == UnassignedInfo .Reason .INDEX_CREATED ) {
317+                             // TODO: we could include more cases that would cause early publishing of desired balance in case of a long 
318+                             // computation. e.g.: 
319+                             // - unassigned search replicas in case the shard has no assigned shard replicas 
320+                             // - other reasons for an unassigned shard such as NEW_INDEX_RESTORED 
321+                             assignedNewlyCreatedPrimaryShards  = true ;
322+                         }
296323                        clusterInfoSimulator .simulateShardStarted (shardRouting );
297324                        routingNodes .startShard (shardRouting , changes , 0L );
298325                    }
@@ -301,14 +328,14 @@ public DesiredBalance compute(
301328
302329            i ++;
303330            final  int  iterations  = i ;
304-             final  long  currentTime  = threadPool . relativeTimeInMillis ();
331+             final  long  currentTime  = timeSupplierMillis . getAsLong ();
305332            final  boolean  reportByTime  = nextReportTime  <= currentTime ;
306333            final  boolean  reportByIterationCount  = i  % iterationCountReportInterval  == 0 ;
307334            if  (reportByTime  || reportByIterationCount ) {
308335                nextReportTime  = currentTime  + timeWarningInterval ;
309336            }
310337
311-             if  (hasChanges  ==  false ) {
338+             if  (hasComputationConverged ( hasChanges ,  i ) ) {
312339                logger .debug (
313340                    "Desired balance computation for [{}] converged after [{}] and [{}] iterations" ,
314341                    desiredBalanceInput .index (),
@@ -324,9 +351,25 @@ public DesiredBalance compute(
324351                    "Desired balance computation for [{}] interrupted after [{}] and [{}] iterations as newer cluster state received. " 
325352                        + "Publishing intermediate desired balance and restarting computation" ,
326353                    desiredBalanceInput .index (),
354+                     TimeValue .timeValueMillis (currentTime  - computationStartedTime ).toString (),
355+                     i 
356+                 );
357+                 finishReason  = DesiredBalance .ComputationFinishReason .YIELD_TO_NEW_INPUT ;
358+                 break ;
359+             }
360+ 
361+             if  (assignedNewlyCreatedPrimaryShards 
362+                 && currentTime  - computationStartedTime  >= maxBalanceComputationTimeDuringIndexCreationMillis ) {
363+                 logger .info (
364+                     "Desired balance computation for [{}] interrupted after [{}] and [{}] iterations " 
365+                         + "in order to not delay assignment of newly created index shards for more than [{}]. " 
366+                         + "Publishing intermediate desired balance and restarting computation" ,
367+                     desiredBalanceInput .index (),
368+                     TimeValue .timeValueMillis (currentTime  - computationStartedTime ).toString (),
327369                    i ,
328-                     TimeValue .timeValueMillis (currentTime  -  computationStartedTime ).toString ()
370+                     TimeValue .timeValueMillis (maxBalanceComputationTimeDuringIndexCreationMillis ).toString ()
329371                );
372+                 finishReason  = DesiredBalance .ComputationFinishReason .STOP_EARLY ;
330373                break ;
331374            }
332375
@@ -368,7 +411,12 @@ public DesiredBalance compute(
368411        }
369412
370413        long  lastConvergedIndex  = hasChanges  ? previousDesiredBalance .lastConvergedIndex () : desiredBalanceInput .index ();
371-         return  new  DesiredBalance (lastConvergedIndex , assignments , routingNodes .getBalanceWeightStatsPerNode ());
414+         return  new  DesiredBalance (lastConvergedIndex , assignments , routingNodes .getBalanceWeightStatsPerNode (), finishReason );
415+     }
416+ 
417+     // visible for testing 
418+     boolean  hasComputationConverged (boolean  hasRoutingChanges , int  currentIteration ) {
419+         return  hasRoutingChanges  == false ;
372420    }
373421
374422    private  static  Map <ShardId , ShardAssignment > collectShardAssignments (RoutingNodes  routingNodes ) {
0 commit comments