Skip to content

Commit 3de0d8a

Browse files
Remove deprecated JAX funtions (#5203)
* Remove deprecated JAX funtions * nox changes * Try to temporarily use pybammsolvers from git branch * Try to make temporary fix for CI work * Revert non-JAX stuff, remove version pin * Try to make work again with old Jax versions on MacOS-13 * style: pre-commit fixes * Try with Jax 0.6 for now * Add unit test for split_list * style: pre-commit fixes --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 327948a commit 3de0d8a

File tree

4 files changed

+67
-11
lines changed

4 files changed

+67
-11
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ dev = [
113113
"hypothesis",
114114
]
115115
jax = [
116-
"jax>=0.4.36,<0.6.0",
116+
"jax>=0.4.36,<0.7.0",
117117
]
118118
# Contains all optional dependencies, except for jax, and dev dependencies
119119
all = [

src/pybamm/solvers/idaklu_jax.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,22 @@
1616
if pybamm.has_jax():
1717
import jax
1818
from jax import lax
19+
20+
try:
21+
from jax import ffi
22+
except ImportError:
23+
from jax.extend import ffi
1924
from jax import numpy as jnp
2025
from jax.interpreters import ad, batching, mlir
2126
from jax.interpreters.mlir import custom_call
22-
from jax.lib import xla_client
2327
from jax.tree_util import tree_flatten
2428

29+
# Handle JAX version compatibility for Primitive location
30+
try:
31+
from jax.core import Primitive
32+
except ImportError:
33+
from jax.extend.core import Primitive
34+
2535

2636
class IDAKLUJax:
2737
"""JAX wrapper for IDAKLU solver
@@ -600,15 +610,14 @@ def _jaxify(
600610
self._register_callbacks() # Register python methods as callbacks in IDAKLU-JAX
601611

602612
for _name, _value in idaklu.registrations().items():
603-
# todo: This has been removed from jax v0.6.0
604-
xla_client.register_custom_call_target(
605-
f"{_name}_{self._unique_name()}", _value, platform="cpu"
613+
ffi.register_ffi_target(
614+
f"{_name}_{self._unique_name()}", _value, platform="cpu", api_version=0
606615
)
607616

608617
# --- JAX PRIMITIVE DEFINITION ------------------------------------------------
609618

610619
logger.debug(f"Creating new primitive: {self._unique_name()}")
611-
f_p = jax.core.Primitive(f"f_{self._unique_name()}")
620+
f_p = Primitive(f"f_{self._unique_name()}")
612621
f_p.multiple_results = False # Returns a single multi-dimensional array
613622

614623
def f(t, inputs=None):
@@ -759,7 +768,7 @@ def make_zero(prim, tan):
759768

760769
ad.primitive_jvps[f_p] = f_jvp
761770

762-
f_jvp_p = jax.core.Primitive(f"f_jvp_{self._unique_name()}")
771+
f_jvp_p = Primitive(f"f_jvp_{self._unique_name()}")
763772

764773
@f_jvp_p.def_impl
765774
def f_jvp_eval(*args):
@@ -941,7 +950,7 @@ def f_jvp_lowering_cpu(ctx, *args):
941950

942951
# --- JAX PRIMITIVE VJP DEFINITION --------------------------------------------
943952

944-
f_vjp_p = jax.core.Primitive(f"f_vjp_{self._unique_name()}")
953+
f_vjp_p = Primitive(f"f_vjp_{self._unique_name()}")
945954

946955
def f_vjp(y_bar, invar, *primals):
947956
"""Main wrapper for the VJP function"""

src/pybamm/solvers/jax_bdf_solver.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
import pybamm
99

1010
if pybamm.has_jax():
11+
import functools
12+
1113
import jax
1214
import jax.numpy as jnp
1315
from jax import core, dtypes
@@ -16,7 +18,19 @@
1618
from jax.flatten_util import ravel_pytree
1719
from jax.interpreters import partial_eval as pe
1820
from jax.tree_util import tree_flatten, tree_map, tree_unflatten
19-
from jax.util import cache, safe_map, split_list
21+
22+
def split_list(lst, indices):
23+
"""Split a list at given indices."""
24+
if not indices:
25+
return [lst]
26+
27+
result = []
28+
start = 0
29+
for idx in indices:
30+
result.append(lst[start:idx])
31+
start = idx
32+
result.append(lst[start:])
33+
return result
2034

2135
platform = jax.lib.xla_bridge.get_backend().platform.casefold()
2236
if platform != "metal":
@@ -875,7 +889,7 @@ def scan_fun(carry, i):
875889

876890
_bdf_odeint.defvjp(_bdf_odeint_fwd, _bdf_odeint_rev)
877891

878-
@cache()
892+
@functools.cache
879893
def closure_convert(fun, in_tree, in_avals):
880894
wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
881895
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals)
@@ -975,7 +989,7 @@ def _check_arg(arg):
975989
raise TypeError(msg.format(arg))
976990

977991
flat_args, in_tree = tree_flatten((y0, t_eval[0], *args))
978-
in_avals = tuple(safe_map(abstractify, flat_args))
992+
in_avals = tuple(map(abstractify, flat_args))
979993
converted, consts = closure_convert(func, in_tree, in_avals)
980994
if mass is None:
981995
mass = onp.identity(y0.shape[0], dtype=y0.dtype)

tests/unit/test_solvers/test_jax_bdf_solver.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,3 +175,36 @@ def fun(y, t, inputs):
175175
)
176176

177177
np.testing.assert_allclose(y[:, 0].reshape(-1), np.exp(-0.1 * t_eval))
178+
179+
def test_split_list(self):
180+
"""Test the split_list utility function."""
181+
from pybamm.solvers.jax_bdf_solver import split_list
182+
183+
# Test case 1: Empty indices should return the original list
184+
original_list = [1, 2, 3, 4, 5]
185+
result = split_list(original_list, [])
186+
assert result == [[1, 2, 3, 4, 5]]
187+
188+
# Test case 2: Single index should split at that point
189+
result = split_list([1, 2, 3, 4, 5], [3])
190+
assert result == [[1, 2, 3], [4, 5]]
191+
192+
# Test case 3: Multiple indices should create multiple sublists
193+
result = split_list([1, 2, 3, 4, 5, 6, 7], [2, 5])
194+
assert result == [[1, 2], [3, 4, 5], [6, 7]]
195+
196+
# Test case 4: Index at the beginning
197+
result = split_list([1, 2, 3, 4], [0])
198+
assert result == [[], [1, 2, 3, 4]]
199+
200+
# Test case 5: Index at the end
201+
result = split_list([1, 2, 3, 4], [4])
202+
assert result == [[1, 2, 3, 4], []]
203+
204+
# Test case 6: Empty list
205+
result = split_list([], [])
206+
assert result == [[]]
207+
208+
# Test case 7: String list to test with different data types
209+
result = split_list(["a", "b", "c", "d"], [2])
210+
assert result == [["a", "b"], ["c", "d"]]

0 commit comments

Comments
 (0)