|
| 1 | +# Copyright 2025 Google LLC |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# https://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +"""A simulated manager for elastic training. |
| 15 | +
|
| 16 | +This module provides a simulated manager for elastic training. It can be used |
| 17 | +to test elastic training without needing to actually trigger elastic events. |
| 18 | +Instead, the user can control which slices are available at what times by |
| 19 | +calling `update_good_slice_indices`. |
| 20 | +""" |
| 21 | + |
| 22 | +import logging |
| 23 | +from typing import Sequence |
| 24 | + |
| 25 | +import jax |
| 26 | +from pathwaysutils.debug import timing |
| 27 | +from pathwaysutils.elastic import manager |
| 28 | + |
| 29 | + |
| 30 | +_logger = logging.getLogger(__name__) |
| 31 | + |
| 32 | + |
| 33 | +class SimulatedManager(manager.Manager): |
| 34 | + """An elastic manager with settable slice availability. |
| 35 | +
|
| 36 | + This class can be used to modify which slices are marked as available by |
| 37 | + overloading the `get_slice_availability` function. |
| 38 | + """ |
| 39 | + |
| 40 | + _simulated_good_slice_indices: set[int] |
| 41 | + |
| 42 | + def __init__( |
| 43 | + self, |
| 44 | + devices: Sequence[jax.Device], |
| 45 | + reshard_check_period: int = 1, |
| 46 | + snapshot_period: int = 1, |
| 47 | + max_elastic_down_event_count: int | None = None, |
| 48 | + max_reshard_retry_count: int | None = None, |
| 49 | + ) -> None: |
| 50 | + """Initializes the simulated manager. |
| 51 | +
|
| 52 | + Args: |
| 53 | + devices: The devices to use. If None, jax.devices() is used. |
| 54 | + reshard_check_period: The number of steps between reshard checks after a |
| 55 | + slice down event has occurred. |
| 56 | + snapshot_period: The number of steps between snapshots. |
| 57 | + max_elastic_down_event_count: The maximum number of elastic down events. |
| 58 | + If None, there is no limit. |
| 59 | + max_reshard_retry_count: The maximum number of consequetive reshard |
| 60 | + retries. If None, there is no limit. |
| 61 | + """ |
| 62 | + self._simulated_good_slice_indices = set(d.slice_index for d in devices) |
| 63 | + |
| 64 | + super().__init__( |
| 65 | + devices, |
| 66 | + snapshot_period, |
| 67 | + reshard_check_period, |
| 68 | + max_elastic_down_event_count, |
| 69 | + max_reshard_retry_count, |
| 70 | + ) |
| 71 | + |
| 72 | + def update_good_slice_indices(self, good_slice_indices: set[int]) -> None: |
| 73 | + """Sets the good slice indices. |
| 74 | +
|
| 75 | + Subsequent calls to `get_slice_availability` will return these indices. |
| 76 | +
|
| 77 | + Args: |
| 78 | + good_slice_indices: The simulated good slice indices. |
| 79 | + """ |
| 80 | + self._simulated_good_slice_indices = good_slice_indices |
| 81 | + _logger.debug( |
| 82 | + "Updated: simumlated_good_slice_indices=%s", |
| 83 | + self._simulated_good_slice_indices, |
| 84 | + ) |
| 85 | + |
| 86 | + @timing.timeit |
| 87 | + def get_slice_availability(self) -> set[int]: |
| 88 | + """Returns the set of good slice indices. |
| 89 | +
|
| 90 | + Returns: |
| 91 | + The set of good slice indices from the last call to |
| 92 | + update_good_slice_indices. Returns an empty set if |
| 93 | + update_good_slice_indices has not been called. |
| 94 | + """ |
| 95 | + good_slice_indices = self._simulated_good_slice_indices |
| 96 | + |
| 97 | + _logger.debug("good_slice_indices=%s", good_slice_indices) |
| 98 | + |
| 99 | + return good_slice_indices |
0 commit comments