Skip to content

Commit 77ba076

Browse files
lukebaumanncopybara-github
authored andcommitted
Refactoring wait_for_slices
PiperOrigin-RevId: 798375418
1 parent 4accd61 commit 77ba076

File tree

1 file changed

+40
-29
lines changed

1 file changed

+40
-29
lines changed

pathwaysutils/elastic/manager.py

Lines changed: 40 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -644,18 +644,19 @@ def maybe_reshard_up(
644644
def wait_for_slices(
645645
self,
646646
slice_count: int | None = None,
647-
wait_period: float | int = 10,
647+
poll_interval: float | int = 10,
648648
timeout: float | int | None = None,
649649
) -> set[int]:
650650
"""Waits until after at least `slice_count` slices become available.
651651
652652
Args:
653653
slice_count: The number of slices to wait for. If None, waits for all
654654
slices to become available.
655-
wait_period: The number of seconds to wait between availability checks.
656-
Defaults to 10 seconds.
657-
timeout: The maximum number of seconds to wait. If None, there is
658-
no timeout.
655+
poll_interval: The minimum number of seconds to wait between availability
656+
checks. If the check takes longer than this, the next check will start
657+
immediately after the current check completes. Defaults to 10 seconds.
658+
timeout: The maximum number of seconds to wait. If None, there is no
659+
timeout.
659660
660661
Returns:
661662
The good slice indices
@@ -669,42 +670,52 @@ def wait_for_slices(
669670

670671
start_time = time.time()
671672

672-
while (
673-
len(good_slice_indices := self.get_slice_availability()) < slice_count
674-
):
675-
elapsed_time = time.time() - start_time
676-
if timeout is not None and elapsed_time + wait_period >= timeout:
677-
raise TimeoutError(
678-
f"Timed out waiting for {slice_count} slices. Only"
679-
f" {len(good_slice_indices)} available after"
680-
f" {elapsed_time:.2f} seconds."
681-
f" Next check would occur after the timeout of {timeout}"
682-
" seconds."
673+
while True:
674+
check_start_time = time.time()
675+
676+
if (
677+
len(good_slice_indices := self.get_slice_availability())
678+
>= slice_count
679+
):
680+
_logger.info(
681+
"%s/%s slices are available",
682+
len(good_slice_indices),
683+
self.total_slice_count,
683684
)
685+
return good_slice_indices
684686

685687
_logger.info(
686-
"%s/%s slices available. Wanting at least %s/%s. Sleeping for %s"
687-
" seconds.",
688+
"%s/%s slices available. Wanting at least %s/%s.",
688689
len(good_slice_indices),
689690
self.total_slice_count,
690691
slice_count,
691692
self.total_slice_count,
692-
wait_period,
693693
)
694-
time.sleep(wait_period)
695694

696-
_logger.info(
697-
"%s/%s slices are available",
698-
len(good_slice_indices),
699-
self.total_slice_count,
700-
)
695+
time_to_sleep = max(0, poll_interval - (time.time() - check_start_time))
701696

702-
return good_slice_indices
697+
if (
698+
timeout is not None
699+
and (elapsed_time := time.time() - start_time) + time_to_sleep
700+
>= timeout
701+
):
702+
raise TimeoutError(
703+
f"Timed out waiting for {slice_count} slices. Only"
704+
f" {len(good_slice_indices)} available after"
705+
f" {elapsed_time:.2f} seconds."
706+
f" Next check would occur after the timeout of {timeout}"
707+
" seconds."
708+
)
709+
710+
if time_to_sleep > 0:
711+
_logger.info("Sleeping for %s seconds.", time_to_sleep)
712+
713+
time.sleep(time_to_sleep)
703714

704715
def pause_resume(
705716
self,
706717
max_retries: int,
707-
wait_period: float | int = 10,
718+
poll_interval: float | int = 10,
708719
timeout: float | None = None,
709720
) -> Any:
710721
"""Retries a function with pause/resume fault tolerance.
@@ -723,7 +734,7 @@ def pause_resume(
723734
724735
Args:
725736
max_retries: The maximum number of times to retry the function.
726-
wait_period: The number of seconds to wait between availability checks.
737+
poll_interval: The number of seconds to wait between availability checks.
727738
Defaults to 10 seconds.
728739
timeout: The maximum number of seconds to wait for slices to become
729740
available before each retry attempt. If None, there is no timeout.
@@ -745,7 +756,7 @@ def wrapper(*args, **kwargs):
745756
"Elastic attempt %d out of %d", retry_index + 1, max_retries
746757
)
747758

748-
self.wait_for_slices(wait_period=wait_period, timeout=timeout)
759+
self.wait_for_slices(poll_interval=poll_interval, timeout=timeout)
749760

750761
return func(*args, **kwargs)
751762
except jax.errors.JaxRuntimeError as error:

0 commit comments

Comments
 (0)