Description
PyBaMM Version
25.12.2
Python Version
3.12
Describe the bug
First of all, thank you for the great work on PyBaMM. I’d like to propose an improvement regarding how JAX availability is handled.
Summary
Currently, has_jax() in util.py only checks whether jax is installed, but not which version is installed. This can lead to unexpected import-time failures when an incompatible JAX version is present in the environment.
If PyBaMM is installed with the pybamm[jax] extra, dependency resolvers (e.g., Poetry) correctly enforce the required JAX version constraints. However, when PyBaMM is installed without the JAX extra, both users and dependency resolvers assume that PyBaMM is completely independent of JAX.
In practice, this is not entirely true. If JAX is already installed in the environment (e.g., via a transitive dependency from another library or as a direct dependency), PyBaMM may still try to use it. If that JAX version is incompatible, import pybamm fails and breaks code that otherwise worked.
Minimal Reproduction
    conda create -n my_env python=3.12
    conda activate my_env
    pip install pybamm
    pip install jax
Then:
    import pybamm
This results in:
    AttributeError: module 'jax.lib' has no attribute 'xla_bridge'
This appears to happen because PyBaMM assumes an older JAX API (jax>=0.4.36,<0.7.0), while the JAX release installed by a plain pip install jax falls outside that range and no longer provides jax.lib.xla_bridge.
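To confirm that the failure is a version mismatch, the installed versions can be read from package metadata without importing either package (importing `pybamm` is exactly what fails). The sketch below uses only the standard library's `importlib.metadata`; the helpers `installed_version` and `declared_requirements` are illustrative names, not PyBaMM API. It prints the installed PyBaMM and JAX versions and the JAX requirement(s) PyBaMM declares in its metadata, which for a release like 25.12.2 should presumably include the range quoted above, attached to the `[jax]` extra.

```python
import importlib.metadata
import re
from typing import List, Optional


def installed_version(dist: str) -> Optional[str]:
    """Installed version of a distribution, or None if it is not installed."""
    try:
        return importlib.metadata.version(dist)
    except importlib.metadata.PackageNotFoundError:
        return None


def declared_requirements(dist: str, target: str) -> List[str]:
    """Requirement strings in `dist`'s metadata whose package name is `target`.

    Requirements attached to extras (e.g. `...; extra == "jax"`) are included.
    """
    try:
        reqs = importlib.metadata.requires(dist) or []
    except importlib.metadata.PackageNotFoundError:
        return []
    matches = []
    for req in reqs:
        name = re.match(r"\s*([A-Za-z0-9][A-Za-z0-9._-]*)", req)
        if name and name.group(1).lower() == target.lower():
            matches.append(req)
    return matches


if __name__ == "__main__":
    # Neither package is imported, so this runs even when `import pybamm` fails.
    print("pybamm:", installed_version("pybamm"))
    print("jax:", installed_version("jax"))
    for req in declared_requirements("pybamm", "jax"):
        print("pybamm declares:", req)
```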
Problem Description
- Users installing PyBaMM without [jax] reasonably expect no interaction with JAX.
- If JAX is present in the environment (even indirectly), PyBaMM may attempt to use it without verifying compatibility.
- The resulting error is non-obvious and can take time to diagnose, especially for new users unfamiliar with PyBaMM’s optional JAX integration.
- This creates unnecessary friction and can give the impression that PyBaMM is unstable, when the root cause is simply a version mismatch.
In my case, switching to the main branch resolved the issue because the JAX version requirement has been bumped there. However, I think it would still be beneficial to explicitly check the JAX version and gracefully ignore JAX support if the installed version is incompatible.
Proposal
Instead of only checking whether JAX is installed, has_jax() in util.py could additionally validate that the installed version satisfies PyBaMM’s supported version range.
If not compatible, PyBaMM could either:
- Disable JAX support (i.e. treat it as has_jax() == False), or
- Raise a clear, informative error that explains the version mismatch: it should state that a JAX installation was detected, that PyBaMM assumes a specific JAX version whenever JAX is present in the environment (even if PyBaMM was installed without the [jax] extra), and which version that is.
As a user, I would prefer the former option.
This would:
- Make the optional dependency behave more like a truly optional dependency
- Avoid hard import-time crashes
- Provide clearer guidance to users