@@ -644,18 +644,19 @@ def maybe_reshard_up(
644644 def wait_for_slices (
645645 self ,
646646 slice_count : int | None = None ,
647- wait_period : float | int = 10 ,
647+ poll_interval : float | int = 10 ,
648648 timeout : float | int | None = None ,
649649 ) -> set [int ]:
650650 """Waits until after at least `slice_count` slices become available.
651651
652652 Args:
653653 slice_count: The number of slices to wait for. If None, waits for all
654654 slices to become available.
655- wait_period: The number of seconds to wait between availability checks.
656- Defaults to 10 seconds.
657- timeout: The maximum number of seconds to wait. If None, there is
658- no timeout.
655+ poll_interval: The minimum number of seconds to wait between availability
656+ checks. If the check takes longer than this, the next check will start
657+ immediately after the current check completes. Defaults to 10 seconds.
658+ timeout: The maximum number of seconds to wait. If None, there is no
659+ timeout.
659660
660661 Returns:
661662 The good slice indices
@@ -669,42 +670,52 @@ def wait_for_slices(
669670
670671 start_time = time .time ()
671672
672- while (
673- len ( good_slice_indices := self . get_slice_availability ()) < slice_count
674- ):
675- elapsed_time = time . time () - start_time
676- if timeout is not None and elapsed_time + wait_period >= timeout :
677- raise TimeoutError (
678- f"Timed out waiting for { slice_count } slices. Only"
679- f" { len ( good_slice_indices ) } available after"
680- f" { elapsed_time :.2f } seconds."
681- f" Next check would occur after the timeout of { timeout } "
682- " seconds."
673+ while True :
674+ check_start_time = time . time ()
675+
676+ if (
677+ len ( good_slice_indices := self . get_slice_availability ())
678+ >= slice_count
679+ ):
680+ _logger . info (
681+ "%s/%s slices are available" ,
682+ len ( good_slice_indices ),
683+ self . total_slice_count ,
683684 )
685+ return good_slice_indices
684686
685687 _logger .info (
686- "%s/%s slices available. Wanting at least %s/%s. Sleeping for %s"
687- " seconds." ,
688+ "%s/%s slices available. Wanting at least %s/%s." ,
688689 len (good_slice_indices ),
689690 self .total_slice_count ,
690691 slice_count ,
691692 self .total_slice_count ,
692- wait_period ,
693693 )
694- time .sleep (wait_period )
695694
696- _logger .info (
697- "%s/%s slices are available" ,
698- len (good_slice_indices ),
699- self .total_slice_count ,
700- )
695+ time_to_sleep = max (0 , poll_interval - (time .time () - check_start_time ))
701696
702- return good_slice_indices
697+ if (
698+ timeout is not None
699+ and (elapsed_time := time .time () - start_time ) + time_to_sleep
700+ >= timeout
701+ ):
702+ raise TimeoutError (
703+ f"Timed out waiting for { slice_count } slices. Only"
704+ f" { len (good_slice_indices )} available after"
705+ f" { elapsed_time :.2f} seconds."
706+ f" Next check would occur after the timeout of { timeout } "
707+ " seconds."
708+ )
709+
710+ if time_to_sleep > 0 :
711+ _logger .info ("Sleeping for %s seconds." , time_to_sleep )
712+
713+ time .sleep (time_to_sleep )
703714
704715 def pause_resume (
705716 self ,
706717 max_retries : int ,
707- wait_period : float | int = 10 ,
718+ poll_interval : float | int = 10 ,
708719 timeout : float | None = None ,
709720 ) -> Any :
710721 """Retries a function with pause/resume fault tolerance.
@@ -723,7 +734,7 @@ def pause_resume(
723734
724735 Args:
725736 max_retries: The maximum number of times to retry the function.
726- wait_period : The number of seconds to wait between availability checks.
737+ poll_interval : The number of seconds to wait between availability checks.
727738 Defaults to 10 seconds.
728739 timeout: The maximum number of seconds to wait for slices to become
729740 available before each retry attempt. If None, there is no timeout.
@@ -745,7 +756,7 @@ def wrapper(*args, **kwargs):
745756 "Elastic attempt %d out of %d" , retry_index + 1 , max_retries
746757 )
747758
748- self .wait_for_slices (wait_period = wait_period , timeout = timeout )
759+ self .wait_for_slices (poll_interval = poll_interval , timeout = timeout )
749760
750761 return func (* args , ** kwargs )
751762 except jax .errors .JaxRuntimeError as error :
0 commit comments