Skip to content

Switch the newton_raphson_solve_block to use jax_root_finding.root_newton_raphson, and JIT the entire newton_raphson_solve_block.#1412

Merged
copybara-service[bot] merged 1 commit intomainfrom
test_782947706
Aug 15, 2025
Merged

Switch the newton_raphson_solve_block to use jax_root_finding.root_newton_raphson, and JIT the entire newton_raphson_solve_block.#1412
copybara-service[bot] merged 1 commit intomainfrom
test_782947706

Conversation

@copybara-service
Copy link

@copybara-service copybara-service bot commented Aug 4, 2025

Switch the newton_raphson_solve_block to use jax_root_finding.root_newton_raphson, and JIT the entire newton_raphson_solve_block.

The jax and jaxlib versions are incremented to 0.7 in order to support the new --xla_backend_extra_options=xla_cpu_flatten_after_fusion flag, and Python incremented to 3.11 to support jaxlib 0.7.0. This flag treats all functions wrapped with a jax.jit within a larger jax.jit as independently compilable subfunctions (ie. prevents compiler inlining). It is critical for preventing blowups of compile-times (XLA:CPU compile-times scale roughly quadratically with the number of ops otherwise).

Use of jax.lax.custom_root is temporarily disabled until a workaround is found for it significantly increasing compile times.

This increases simulation speed by around 20%, but at the cost of ~10% increases in compile-time

@copybara-service copybara-service bot changed the title Switch the newton_raphson_solve_block to use jax_root_finding.root_newton_raphson. Switch the newton_raphson_solve_block to use jax_root_finding.root_newton_raphson, and JIT the entire newton_raphson_solve_block. Aug 12, 2025
@copybara-service copybara-service bot force-pushed the test_782947706 branch 7 times, most recently from 4452f16 to 18e5495 Compare August 15, 2025 10:56
…_newton_raphson`, and JIT the entire `newton_raphson_solve_block`.

The `jax` and `jaxlib` versions are incremented to `0.7` in order to support the new `--xla_backend_extra_options=xla_cpu_flatten_after_fusion` flag, and Python incremented to 3.11 to support jaxlib 0.7.0. This flag treats all functions wrapped with a `jax.jit` within a larger `jax.jit` as independently compilable subfunctions (ie. prevents compiler inlining). It is critical for preventing blowups of compile-times (XLA:CPU compile-times scale roughly quadratically with the number of ops otherwise).

Use of jax.lax.custom_root is temporarily disabled until a workaround is found for it significantly increasing compile times.

This increases simulation speed by around 20%, but at the cost of ~10% increases in compile-time

PiperOrigin-RevId: 795422861
@copybara-service copybara-service bot merged commit 0f206de into main Aug 15, 2025
@copybara-service copybara-service bot deleted the test_782947706 branch August 15, 2025 11:04
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant