From 3b4cacdc8f52483bef4e8e9862d86b526d4d33e9 Mon Sep 17 00:00:00 2001 From: Luke Baumann Date: Tue, 18 Nov 2025 14:49:15 -0800 Subject: [PATCH] Extract lambda to a named function to ensure cache hits. The lambda function used in `jax.pmap` within `_test_slice_execution` is moved into a private helper function `_plus_one`. lambda functions have a different hash and were not utilizing the compilation cache. https://docs.jax.dev/en/latest/jit-compilation.html#jit-and-caching PiperOrigin-RevId: 833979645 --- pathwaysutils/elastic/manager.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/pathwaysutils/elastic/manager.py b/pathwaysutils/elastic/manager.py index 3986959..b4b2e59 100644 --- a/pathwaysutils/elastic/manager.py +++ b/pathwaysutils/elastic/manager.py @@ -44,6 +44,20 @@ _logger = logging.getLogger(__name__) +def _plus_one(x: jax.Array) -> jax.Array: + """Adds one to each element in the array. + + Used to test if a slice is available. + + Args: + x: The array to add one to. + + Returns: + The array with one added to each element. + """ + return x + 1 + + class ElasticRuntimeError(RuntimeError): """Error raised when elasticity cannot continue. @@ -234,7 +248,7 @@ def _simple_execution(self, devices: Sequence[jax.Device]) -> jax.Array: self._SIMPLE_EXECUTION_TEST_VALUE - 1 ) - return jax.pmap(lambda x: x + 1, devices=devices)(test_input) + return jax.pmap(_plus_one, devices=devices)(test_input) @timing.timeit def get_slice_availability(self) -> set[int]: