diff --git a/pathwaysutils/elastic/manager.py b/pathwaysutils/elastic/manager.py index 3986959..b4b2e59 100644 --- a/pathwaysutils/elastic/manager.py +++ b/pathwaysutils/elastic/manager.py @@ -44,6 +44,20 @@ _logger = logging.getLogger(__name__) +def _plus_one(x: jax.Array) -> jax.Array: + """Adds one to each element in the array. + + Used to test if a slice is available. + + Args: + x: The array to add one to. + + Returns: + The array with one added to each element. + """ + return x + 1 + + class ElasticRuntimeError(RuntimeError): """Error raised when elasticity cannot continue. @@ -234,7 +248,7 @@ def _simple_execution(self, devices: Sequence[jax.Device]) -> jax.Array: self._SIMPLE_EXECUTION_TEST_VALUE - 1 ) - return jax.pmap(lambda x: x + 1, devices=devices)(test_input) + return jax.pmap(_plus_one, devices=devices)(test_input) @timing.timeit def get_slice_availability(self) -> set[int]: