Skip to content

Commit 3be5b6d

Browse files
committed
Try pmap without a lambda to see if this allows for cache hits
1 parent b72729b commit 3be5b6d

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

pathwaysutils/elastic/manager.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,10 @@
4444
_logger = logging.getLogger(__name__)
4545

4646

47+
def _plus_one(x):
48+
return x + 1
49+
50+
4751
class ElasticRuntimeError(RuntimeError):
4852
"""Error raised when elasticity cannot continue.
4953
@@ -234,7 +238,7 @@ def _simple_execution(self, devices: Sequence[jax.Device]) -> jax.Array:
234238
self._SIMPLE_EXECUTION_TEST_VALUE - 1
235239
)
236240

237-
return jax.pmap(lambda x: x + 1, devices=devices)(test_input)
241+
return jax.pmap(_plus_one, devices=devices)(test_input)
238242

239243
@timing.timeit
240244
def get_slice_availability(self) -> set[int]:

0 commit comments

Comments
 (0)