Switch the newton_raphson_solve_block to use jax_root_finding.root_newton_raphson, and JIT the entire newton_raphson_solve_block.#1412
Merged
copybara-service[bot] merged 1 commit intomainfrom Aug 15, 2025
Conversation
82cbb31 to
495d635
Compare
495d635 to
716c20e
Compare
newton_raphson_solve_block to use jax_root_finding.root_newton_raphson, and JIT the entire newton_raphson_solve_block.
4452f16 to
18e5495
Compare
…_newton_raphson`, and JIT the entire `newton_raphson_solve_block`. The `jax` and `jaxlib` versions are incremented to `0.7` in order to support the new `--xla_backend_extra_options=xla_cpu_flatten_after_fusion` flag, and Python incremented to 3.11 to support jaxlib 0.7.0. This flag treats all functions wrapped with a `jax.jit` within a larger `jax.jit` as independently compilable subfunctions (ie. prevents compiler inlining). It is critical for preventing blowups of compile-times (XLA:CPU compile-times scale roughly quadratically with the number of ops otherwise). Use of jax.lax.custom_root is temporarily disabled until a workaround is found for it significantly increasing compile times. This increases simulation speed by around 20%, but at the cost of ~10% increases in compile-time PiperOrigin-RevId: 795422861
18e5495 to
0f206de
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Switch the
newton_raphson_solve_blockto usejax_root_finding.root_newton_raphson, and JIT the entirenewton_raphson_solve_block.The
jaxandjaxlibversions are incremented to0.7in order to support the new--xla_backend_extra_options=xla_cpu_flatten_after_fusionflag, and Python incremented to 3.11 to support jaxlib 0.7.0. This flag treats all functions wrapped with ajax.jitwithin a largerjax.jitas independently compilable subfunctions (ie. prevents compiler inlining). It is critical for preventing blowups of compile-times (XLA:CPU compile-times scale roughly quadratically with the number of ops otherwise).Use of jax.lax.custom_root is temporarily disabled until a workaround is found for it significantly increasing compile times.
This increases simulation speed by around 20%, but at the cost of ~10% increases in compile-time