Skip to content

Commit 4accd61

Browse files
lukebaumanncopybara-github
authored andcommitted
Internal change
PiperOrigin-RevId: 798360103
1 parent b70421e commit 4accd61

File tree

2 files changed

+111
-0
lines changed

2 files changed

+111
-0
lines changed

pathwaysutils/elastic/manager.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,18 @@ def __init__(
8282
max_elastic_down_event_count: int | None = None,
8383
max_reshard_retry_count: int | None = None,
8484
) -> None:
85+
"""Initializes the manager.
86+
87+
Args:
88+
devices: The devices to use. If None, jax.devices() is used.
89+
reshard_check_period: The number of steps between reshard checks after a
90+
slice down event has occurred.
91+
snapshot_period: The number of steps between snapshots.
92+
max_elastic_down_event_count: The maximum number of elastic down events.
93+
If None, there is no limit.
94+
max_reshard_retry_count: The maximum number of consecutive reshard
95+
retries. If None, there is no limit.
96+
"""
8597
if devices is None:
8698
devices = jax.devices()
8799
self.devices = devices
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""A simulated manager for elastic training.
15+
16+
This module provides a simulated manager for elastic training. It can be used
17+
to test elastic training without needing to actually trigger elastic events.
18+
Instead, the user can control which slices are available at what times by
19+
calling `update_good_slice_indices`.
20+
"""
21+
22+
import logging
23+
from typing import Sequence
24+
25+
import jax
26+
from pathwaysutils.debug import timing
27+
from pathwaysutils.elastic import manager
28+
29+
30+
# Module-level logger, named after this module's import path.
_logger = logging.getLogger(__name__)
31+
32+
33+
class SimulatedManager(manager.Manager):
  """An elastic manager with settable slice availability.

  This class overrides `get_slice_availability` so that it reports whichever
  slice indices were most recently supplied via `update_good_slice_indices`.
  Until `update_good_slice_indices` is called, every slice index found on the
  devices passed to the constructor is reported as available.
  """

  _simulated_good_slice_indices: set[int]

  def __init__(
      self,
      devices: Sequence[jax.Device],
      reshard_check_period: int = 1,
      snapshot_period: int = 1,
      max_elastic_down_event_count: int | None = None,
      max_reshard_retry_count: int | None = None,
  ) -> None:
    """Initializes the simulated manager.

    Args:
      devices: The devices to use. Unlike the base manager, this must not be
        None: the `slice_index` of each device is read here to build the
        initial set of simulated good slice indices.
      reshard_check_period: The number of steps between reshard checks after a
        slice down event has occurred.
      snapshot_period: The number of steps between snapshots.
      max_elastic_down_event_count: The maximum number of elastic down events.
        If None, there is no limit.
      max_reshard_retry_count: The maximum number of consecutive reshard
        retries. If None, there is no limit.
    """
    # Populated before delegating to the base initializer so that the
    # attribute already exists if the base class queries slice availability
    # during construction (presumably it does -- confirm against Manager).
    self._simulated_good_slice_indices = {d.slice_index for d in devices}

    # Forward by keyword: the previous positional call passed snapshot_period
    # and reshard_check_period in the opposite order to this signature, which
    # silently swaps them whenever the two values differ.
    super().__init__(
        devices=devices,
        reshard_check_period=reshard_check_period,
        snapshot_period=snapshot_period,
        max_elastic_down_event_count=max_elastic_down_event_count,
        max_reshard_retry_count=max_reshard_retry_count,
    )

  def update_good_slice_indices(self, good_slice_indices: set[int]) -> None:
    """Sets the good slice indices.

    Subsequent calls to `get_slice_availability` will return these indices.

    Args:
      good_slice_indices: The simulated good slice indices.
    """
    self._simulated_good_slice_indices = good_slice_indices
    _logger.debug(
        "Updated: simulated_good_slice_indices=%s",
        self._simulated_good_slice_indices,
    )

  @timing.timeit
  def get_slice_availability(self) -> set[int]:
    """Returns the set of simulated good slice indices.

    Returns:
      The set most recently passed to `update_good_slice_indices`. If
      `update_good_slice_indices` has not been called, the set of slice
      indices of the devices given to the constructor.
    """
    good_slice_indices = self._simulated_good_slice_indices

    _logger.debug("good_slice_indices=%s", good_slice_indices)

    return good_slice_indices

0 commit comments

Comments
 (0)