Skip to content

Commit 94aa57c

Browse files
Add support for torch 2.11 and jax 0.9.1/0.9.2, remove support for jax 0.6.0/0.6.1 (#213)
1 parent ca3d7bd commit 94aa57c

File tree

4 files changed

+6
-6
lines changed

4 files changed

+6
-6
lines changed

.github/workflows/build-jax-wheels.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ jobs:
5151
runs-on: ${{ matrix.os }}
5252
strategy:
5353
matrix:
54-
jax-version: ["0.6.0", "0.6.1", "0.6.2", "0.7.0", "0.7.1", "0.7.2", "0.8.0", "0.8.1", "0.8.2", "0.8.3", "0.9.0"]
54+
jax-version: ["0.6.2", "0.7.0", "0.7.1", "0.7.2", "0.8.0", "0.8.1", "0.8.2", "0.8.3", "0.9.0", "0.9.1", "0.9.2"]
5555
os: [ubuntu-22.04, macos-14, ubuntu-22.04-arm]
5656
include:
5757
- name: x86_64 Linux

.github/workflows/build-torch-wheels.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ jobs:
5151
runs-on: ${{ matrix.os }}
5252
strategy:
5353
matrix:
54-
pytorch-version: ["2.3", "2.4", "2.5", "2.6", "2.7", "2.8", "2.9", "2.10"]
54+
pytorch-version: ["2.3", "2.4", "2.5", "2.6", "2.7", "2.8", "2.9", "2.10", "2.11"]
5555
os: [ubuntu-22.04, macos-14, ubuntu-22.04-arm]
5656
include:
5757
- os: ubuntu-22.04

docs/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ torch
88

99
# jax[cpu], because python -m pip install jax, which would be triggered
1010
# by the main package's dependencies, does not install jaxlib
11-
jax[cpu] <0.6
11+
jax[cpu]
1212

1313
# metatensor and metatensor-torch for the metatensor API
1414
metatensor-torch

tox.ini

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ commands =
9696

9797
# Install this one manually. Listing it in the deps list above does not install jaxlib.
9898
# Note: jax[cuda12] is not available on Windows and MacOS.
99-
bash -c 'command -v nvcc && python -m pip install -U "jax[cuda12]<0.7.2" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html || python -m pip install -U "jax[cpu]<0.7.2"'
99+
bash -c 'command -v nvcc && python -m pip install -U "jax[cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html || python -m pip install -U "jax[cpu]"'
100100

101101
pip install {[testenv]pip_install_flags} .
102102
pytest python {posargs}
@@ -112,7 +112,7 @@ deps =
112112
absl-py # jax uses this package but not installed by jax[cpu]
113113
pytest
114114
equinox
115-
jax[cpu]==0.6.0
115+
jax[cpu]==0.6.2
116116

117117
allowlist_externals =
118118
bash
@@ -140,7 +140,7 @@ deps =
140140
commands =
141141
# Install this one manually. Listing it in the deps list above does not install jaxlib.
142142
# Note: jax[cuda12] is not available on Windows and MacOS.
143-
bash -c 'command -v nvcc && python -m pip install -U "jax[cuda12]<0.7.2" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html || python -m pip install -U "jax[cpu]<0.7.2"'
143+
bash -c 'command -v nvcc && python -m pip install -U "jax[cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html || python -m pip install -U "jax[cpu]"'
144144

145145
pip install {[testenv]pip_install_flags} .
146146
pip install {[testenv]pip_install_flags} ./sphericart-torch

0 commit comments

Comments
 (0)