Skip to content

Commit e8adf9b

Browse files
lukebaumanncopybara-github
authored andcommitted
Extract lambda to a named function to ensure cache hits.
The lambda function used in `jax.pmap` within `_test_slice_execution` is moved into a private helper function `_plus_one`. lambda functions have a different hash and were not utilizing the compilation cache. https://docs.jax.dev/en/latest/jit-compilation.html#jit-and-caching FUTURE_COPYBARA_INTEGRATE_REVIEW=#122 from AI-Hypercomputer:prepare_for_version 53e7715 PiperOrigin-RevId: 833936348
1 parent 2e79e31 commit e8adf9b

File tree

1 file changed

+15
-1
lines changed

1 file changed

+15
-1
lines changed

pathwaysutils/elastic/manager.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,20 @@
4444
_logger = logging.getLogger(__name__)
4545

4646

47+
def _plus_one(x: jax.Array) -> jax.Array:
48+
"""Adds one to each element in the array.
49+
50+
Used to test if a slice is available.
51+
52+
Args:
53+
x: The array to add one to.
54+
55+
Returns:
56+
The array with one added to each element.
57+
"""
58+
return x + 1
59+
60+
4761
class ElasticRuntimeError(RuntimeError):
4862
"""Error raised when elasticity cannot continue.
4963
@@ -234,7 +248,7 @@ def _simple_execution(self, devices: Sequence[jax.Device]) -> jax.Array:
234248
self._SIMPLE_EXECUTION_TEST_VALUE - 1
235249
)
236250

237-
return jax.pmap(lambda x: x + 1, devices=devices)(test_input)
251+
return jax.pmap(_plus_one, devices=devices)(test_input)
238252

239253
@timing.timeit
240254
def get_slice_availability(self) -> set[int]:

0 commit comments

Comments
 (0)