You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Switch the newton_raphson_solve_block to use jax_root_finding.root_newton_raphson, and JIT the entire newton_raphson_solve_block.
The `jax` and `jaxlib` versions are incremented to `0.7` in order to support the new `--xla_backend_extra_options=xla_cpu_flatten_after_fusion` flag, and Python incremented to 3.11 to support jaxlib 0.7.0. This flag treats all functions wrapped with a `jax.jit` within a larger `jax.jit` as independently compilable subfunctions (ie. prevents compiler inlining). It is critical for preventing blowups of compile-times (XLA:CPU compile-times scale roughly quadratically with the number of ops otherwise).
Use of jax.lax.custom_root is temporarily disabled until a workaround is found for it significantly increasing compile times.
This increases simulation speed by around 20%, but at the cost of ~10% increases in compile-time
PiperOrigin-RevId: 782947706
0 commit comments