
[Bug]: Proposal: Check JAX Version Compatibility Before Calling JAX Functionalities #5381

@thegialeo

Description

PyBaMM Version

25.12.2

Python Version

3.12

Describe the bug

First of all, thank you for the great work on PyBaMM. I’d like to propose an improvement regarding how JAX availability is handled.

Summary

Currently, has_jax() in util.py only checks whether jax is installed, but not which version is installed. This can lead to unexpected import-time failures when an incompatible JAX version is present in the environment.
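The Python standard library can answer both "is jax installed?" and "which version is installed?" without importing jax. Because neither check imports jax, neither can trigger an import-time failure. Below is a minimal sketch. The function names are hypothetical, not PyBaMM's.

```python
import importlib.metadata
import importlib.util
from typing import Optional


def jax_is_installed() -> bool:
    """Presence-only check: True if a top-level 'jax' module can be found.

    find_spec locates the module without executing it, so jax is not imported.
    """
    return importlib.util.find_spec("jax") is not None


def installed_jax_version() -> Optional[str]:
    """Return the installed jax distribution version, or None if absent.

    Only package metadata is read; jax itself is not imported.
    """
    try:
        return importlib.metadata.version("jax")
    except importlib.metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    print("jax installed:", jax_is_installed())
    print("jax version:", installed_jax_version())
```

The first function is a presence-only check of the kind described above. It reports True for whatever JAX release pip happened to install. The second function supplies the missing piece of information.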

If PyBaMM is installed with the pybamm[jax] extra, dependency resolvers (e.g., pip or Poetry) correctly enforce the required JAX version constraints. However, when PyBaMM is installed without the JAX extra, both users and dependency resolvers assume that PyBaMM is completely independent of JAX.

In practice, this is not entirely true. JAX may already be installed in the environment, either as a transitive dependency of another library or as a direct dependency. In that case PyBaMM may still call into JAX, and if that JAX version is incompatible, it fails as soon as pybamm is imported, breaking code that previously worked.

Minimal Reproduction

conda create -n my_env python=3.12
conda activate my_env
pip install pybamm
pip install jax

Then:

import pybamm

This results in:

AttributeError: module 'jax.lib' has no attribute 'xla_bridge'

This appears to happen because PyBaMM assumes the API of an older JAX range (jax>=0.4.36,<0.7.0). The newer JAX release installed by pip install jax no longer provides jax.lib.xla_bridge, which is accessed when pybamm is imported.

Problem Description

  • Users installing PyBaMM without [jax] reasonably expect no interaction with JAX.
  • If JAX is present in the environment (even indirectly), PyBaMM may attempt to use it without verifying compatibility.
  • The resulting error is non-obvious and can take time to diagnose, especially for new users unfamiliar with PyBaMM’s optional JAX integration.
  • This creates unnecessary friction and can give the impression that PyBaMM is unstable, when the root cause is simply a version mismatch.

In my case, switching to the main branch resolved the issue because the JAX version requirement has been bumped there. However, I think it would still be beneficial to check the JAX version explicitly and to gracefully disable JAX support if the installed version is incompatible.

Proposal

Instead of only checking whether JAX is installed, has_jax() in util.py could additionally validate that the installed version satisfies PyBaMM’s supported version range.

If not compatible, PyBaMM could either:

  • Disable JAX support (i.e., have has_jax() return False)
  • Raise a clear, informative error explaining the version mismatch. The message would state that a JAX installation was detected and that PyBaMM assumes a specific JAX version range. It would also state that this requirement applies whenever JAX is present in the environment, even if PyBaMM was installed without the [jax] extra.

As a user, I would prefer the former option.
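Both options in the list above can share one code path. The helper below is hypothetical: its name, signature, and message wording are illustrative assumptions, not PyBaMM code. By default it follows the preferred option and disables JAX support. It also emits a warning so the mismatch stays visible, which is an addition, since the issue only asks for JAX support to be disabled. With strict=True it instead raises an ImportError carrying the informative message described in the second option.

```python
import warnings
from typing import Optional

# Hypothetical: the supported range quoted in this issue, shown in messages.
SUPPORTED_JAX_RANGE = "jax>=0.4.36,<0.7.0"


def resolve_jax_support(installed_version: Optional[str],
                        compatible: bool,
                        strict: bool = False) -> bool:
    """Return True if JAX support should be enabled.

    installed_version: detected jax version, or None if jax is absent.
    compatible: whether installed_version lies in SUPPORTED_JAX_RANGE.
    strict: raise ImportError instead of warning (the second option).
    """
    if installed_version is None:
        return False  # jax not installed: nothing to report
    if compatible:
        return True
    message = (
        f"A JAX installation was detected (jax {installed_version}), but "
        f"PyBaMM's JAX features require {SUPPORTED_JAX_RANGE}. This applies "
        "whenever JAX is present, even if PyBaMM was installed without the "
        "[jax] extra."
    )
    if strict:
        raise ImportError(message)
    # Default (the option preferred in this issue): disable, but say so.
    warnings.warn(message + " JAX support has been disabled.", stacklevel=2)
    return False
```

For example, resolve_jax_support("0.7.1", compatible=False) returns False after emitting a UserWarning. Passing strict=True raises ImportError instead.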

This would:

  • Make the optional dependency behave more like a truly optional dependency
  • Avoid hard import-time crashes
  • Provide clearer guidance to users
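The point about import-time crashes can also be addressed where JAX is actually used. If the import is deferred into the function that needs it, importing the enclosing module never touches jax. Below is a hypothetical sketch. The function name is illustrative. The call jax.default_backend() merely stands in for whatever JAX call PyBaMM makes, since the issue does not identify the failing code beyond the missing jax.lib.xla_bridge.

```python
from typing import Optional


def jax_backend_or_none() -> Optional[str]:
    """Return JAX's default backend name (e.g. 'cpu'), or None if unusable.

    jax is imported inside the function rather than at module level, so
    importing the module that defines this function never imports jax.
    """
    try:
        import jax
        return jax.default_backend()
    except (ImportError, AttributeError):
        # ImportError: jax is missing or fails to import.
        # AttributeError: an expected JAX API is absent, which is the
        # failure mode reported in this issue.
        return None
```

Catching AttributeError is deliberately broad here, because that is the error reported above. In practice, a version check in has_jax() would be the primary guard, and this kind of fallback would only be a last line of defence.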

Metadata

Assignees

No one assigned

Labels

  • bug (Something isn't working)
  • difficulty: easy (A good issue for someone new. Can be done in a few hours)
  • good first issue (Issues suitable for newcomers)