Skip to content

Commit fbda0d8

Browse files
committed
pin lowest jax dependencies
1 parent ef0afbc commit fbda0d8

File tree

3 files changed

+3
-3
lines changed

3 files changed

+3
-3
lines changed

examples/jax/identity.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def __call__(self, x):
6969
ckpt_path = work_dir / "ckpt"
7070
checkpointer = ocp.StandardCheckpointer()
7171
checkpointer.save(str(ckpt_path.resolve()), params)
72-
checkpointer.wait_until_finished()
72+
checkpointer.close()
7373

7474
# %% [markdown]
7575
# ## 3. Configure JaxModel and Predict

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ dependencies = [
1515

1616
[project.optional-dependencies]
1717
torch = ["torch>=2.4"]
18-
jax = ["jax>=0.4", "flax>=0.8", "orbax-checkpoint>=0.5"]
18+
jax = ["jax>=0.4.30", "flax>=0.8", "orbax-checkpoint>=0.5"]
1919
dev = [
2020
"pytest>=8.0",
2121
"pytest-cov>=5.0",

tests/test_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def test_jax_models(save_type, out_range, tmp_path):
110110
params_path = tmp_path / "ckpt"
111111
checkpointer = ocp.StandardCheckpointer()
112112
checkpointer.save(str(params_path), params)
113-
checkpointer.wait_until_finished()
113+
checkpointer.close()
114114
elif save_type == "msgpack":
115115
from flax.serialization import to_bytes
116116

0 commit comments

Comments
 (0)