Skip to content

Conversation

@copybara-service
Copy link

Extract lambda to a named function to ensure cache hits.

The lambda function used in jax.pmap within _test_slice_execution is moved into a private helper function _plus_one. lambda functions have a different hash and were not utilizing the compilation cache.

https://docs.jax.dev/en/latest/jit-compilation.html#jit-and-caching

FUTURE_COPYBARA_INTEGRATE_REVIEW=#122 from AI-Hypercomputer:prepare_for_version 53e7715

The lambda function used in `jax.pmap` within `_test_slice_execution` is moved into a private helper function `_plus_one`. lambda functions have a different hash and were not utilizing the compilation cache.

https://docs.jax.dev/en/latest/jit-compilation.html#jit-and-caching

PiperOrigin-RevId: 833979645
@copybara-service copybara-service bot merged commit 3b4cacd into main Nov 18, 2025
@copybara-service copybara-service bot deleted the test_833936348 branch November 18, 2025 22:49
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant