Commit e8adf9b
Extract lambda to a named function to ensure cache hits.
The lambda function used in `jax.pmap` within `_test_slice_execution` is moved into a private helper function `_plus_one`. lambda functions have a different hash and were not utilizing the compilation cache.
https://docs.jax.dev/en/latest/jit-compilation.html#jit-and-caching
FUTURE_COPYBARA_INTEGRATE_REVIEW=#122 from AI-Hypercomputer:prepare_for_version 53e7715
PiperOrigin-RevId: 8339363481 parent 2e79e31 commit e8adf9b
1 file changed
+15
-1
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
44 | 44 | | |
45 | 45 | | |
46 | 46 | | |
| 47 | + | |
| 48 | + | |
| 49 | + | |
| 50 | + | |
| 51 | + | |
| 52 | + | |
| 53 | + | |
| 54 | + | |
| 55 | + | |
| 56 | + | |
| 57 | + | |
| 58 | + | |
| 59 | + | |
| 60 | + | |
47 | 61 | | |
48 | 62 | | |
49 | 63 | | |
| |||
234 | 248 | | |
235 | 249 | | |
236 | 250 | | |
237 | | - | |
| 251 | + | |
238 | 252 | | |
239 | 253 | | |
240 | 254 | | |
| |||
0 commit comments