Skip to content

Commit 16f0efc

Browse files
committed
update jax.config after 0.4.26 deprecation
1 parent 64e8b4d commit 16f0efc

14 files changed

+31
-34
lines changed

notebooks/custom_gradients.ipynb

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,13 @@
1919
"os.environ['CUDA_VISIBLE_DEVICES'] = ''\n",
2020
"os.environ['JAX_CHECK_TRACER_LEAKS'] = 'True'\n",
2121
"\n",
22-
"from jax.config import config\n",
23-
"config.update(\"jax_enable_x64\", True)\n",
22+
"import jax\n",
23+
"jax.config.update(\"jax_enable_x64\", True)\n",
2424
"\n",
2525
"# Check we're running on GPU\n",
2626
"from jax.lib import xla_bridge\n",
2727
"print(xla_bridge.get_backend().platform)\n",
2828
"\n",
29-
"import jax\n",
3029
"from jax import jit, grad \n",
3130
"import jax.numpy as jnp \n",
3231
"from jax.test_util import check_grads\n",
@@ -98,7 +97,7 @@
9897
"name": "python",
9998
"nbconvert_exporter": "python",
10099
"pygments_lexer": "ipython3",
101-
"version": "3.9.0"
100+
"version": "3.10.4"
102101
},
103102
"orig_nbformat": 4,
104103
"vscode": {

notebooks/spherical_harmonic_transform.ipynb

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@
2828
"metadata": {},
2929
"outputs": [],
3030
"source": [
31-
"from jax.config import config\n",
32-
"config.update(\"jax_enable_x64\", True)\n",
31+
"import jax\n",
32+
"jax.config.update(\"jax_enable_x64\", True)\n",
3333
"\n",
3434
"import numpy as np\n",
3535
"import s2fft \n",
@@ -199,7 +199,7 @@
199199
],
200200
"metadata": {
201201
"kernelspec": {
202-
"display_name": "Python 3.8.16 64-bit ('s2fft')",
202+
"display_name": "Python 3.10.4 ('s2fft')",
203203
"language": "python",
204204
"name": "python3"
205205
},
@@ -213,12 +213,12 @@
213213
"name": "python",
214214
"nbconvert_exporter": "python",
215215
"pygments_lexer": "ipython3",
216-
"version": "3.8.16"
216+
"version": "3.10.4"
217217
},
218218
"orig_nbformat": 4,
219219
"vscode": {
220220
"interpreter": {
221-
"hash": "d6019e21eb0d27eebd69283f1089b8b605b46cb058a452b887458f3af7017e46"
221+
"hash": "3425e24474cbe920550266ea26b478634978cc419579f9dbcf479231067df6a3"
222222
}
223223
}
224224
},

notebooks/spherical_rotation.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@
3030
"metadata": {},
3131
"outputs": [],
3232
"source": [
33-
"from jax.config import config\n",
34-
"config.update(\"jax_enable_x64\", True)\n",
33+
"import jax\n",
34+
"jax.config.update(\"jax_enable_x64\", True)\n",
3535
"\n",
3636
"import numpy as np\n",
3737
"import s2fft \n",

notebooks/wigner_transform.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@
2828
"metadata": {},
2929
"outputs": [],
3030
"source": [
31-
"from jax.config import config\n",
32-
"config.update(\"jax_enable_x64\", True) \n",
31+
"import jax\n",
32+
"jax.config.update(\"jax_enable_x64\", True)\n",
3333
"\n",
3434
"import numpy as np\n",
3535
"import s2fft \n",

s2fft/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@
1010
from .utils.rotation import rotate_flms, generate_rotate_dls
1111

1212
import logging
13-
from jax.config import config
13+
import jax
1414

15-
if config.read("jax_enable_x64") is False:
15+
if jax.config.read("jax_enable_x64") is False:
1616
logger = logging.getLogger("s2fft")
1717
logger.warning(
1818
"JAX is not using 64-bit precision. This will dramatically affect numerical precision at even moderate L."

s2fft/precompute_transforms/construct.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
from jax import config
1+
import jax
22

3-
config.update("jax_enable_x64", True)
3+
jax.config.update("jax_enable_x64", True)
44

55
import numpy as np
66
import jax.numpy as jnp

tests/test_healpix_ffts.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import numpy as np
22
import healpy as hp
33
import pytest
4-
from jax import config
4+
import jax
5+
6+
jax.config.update("jax_enable_x64", True)
57
from s2fft.sampling import s2_samples as samples
68
from s2fft.utils.healpix_ffts import (
79
healpix_fft_jax,
@@ -11,9 +13,6 @@
1113
)
1214

1315

14-
config.update("jax_enable_x64", True)
15-
16-
1716
nside_to_test = [4, 5]
1817
reality_to_test = [False, True]
1918

tests/test_spherical_custom_grads.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
from jax import config
1+
import jax
22

3-
config.update("jax_enable_x64", True)
3+
jax.config.update("jax_enable_x64", True)
44
import pytest
55
import jax.numpy as jnp
66
from jax.test_util import check_grads

tests/test_spherical_transform.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
from jax import config
1+
import jax
22

3-
config.update("jax_enable_x64", True)
3+
jax.config.update("jax_enable_x64", True)
44
import pytest
55
import pyssht as ssht
66
import numpy as np

tests/test_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
from jax.config import config
1+
import jax
22

3-
config.update("jax_enable_x64", True)
3+
jax.config.update("jax_enable_x64", True)
44
import pytest
55
import pyssht as ssht
66
import numpy as np

0 commit comments

Comments
 (0)