Skip to content

Commit b82e065

Browse files
authored
Merge pull request #189 from astro-informatics/housekeeping/debug_183
Fix small bugs from issue 183
2 parents 6c997be + 16f0efc commit b82e065

18 files changed

+38
-41
lines changed

notebooks/custom_gradients.ipynb

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,13 @@
1919
"os.environ['CUDA_VISIBLE_DEVICES'] = ''\n",
2020
"os.environ['JAX_CHECK_TRACER_LEAKS'] = 'True'\n",
2121
"\n",
22-
"from jax.config import config\n",
23-
"config.update(\"jax_enable_x64\", True)\n",
22+
"import jax\n",
23+
"jax.config.update(\"jax_enable_x64\", True)\n",
2424
"\n",
2525
"# Check we're running on GPU\n",
2626
"from jax.lib import xla_bridge\n",
2727
"print(xla_bridge.get_backend().platform)\n",
2828
"\n",
29-
"import jax\n",
3029
"from jax import jit, grad \n",
3130
"import jax.numpy as jnp \n",
3231
"from jax.test_util import check_grads\n",
@@ -98,7 +97,7 @@
9897
"name": "python",
9998
"nbconvert_exporter": "python",
10099
"pygments_lexer": "ipython3",
101-
"version": "3.9.0"
100+
"version": "3.10.4"
102101
},
103102
"orig_nbformat": 4,
104103
"vscode": {

notebooks/spherical_harmonic_transform.ipynb

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@
2828
"metadata": {},
2929
"outputs": [],
3030
"source": [
31-
"from jax.config import config\n",
32-
"config.update(\"jax_enable_x64\", True)\n",
31+
"import jax\n",
32+
"jax.config.update(\"jax_enable_x64\", True)\n",
3333
"\n",
3434
"import numpy as np\n",
3535
"import s2fft \n",
@@ -199,7 +199,7 @@
199199
],
200200
"metadata": {
201201
"kernelspec": {
202-
"display_name": "Python 3.8.16 64-bit ('s2fft')",
202+
"display_name": "Python 3.10.4 ('s2fft')",
203203
"language": "python",
204204
"name": "python3"
205205
},
@@ -213,12 +213,12 @@
213213
"name": "python",
214214
"nbconvert_exporter": "python",
215215
"pygments_lexer": "ipython3",
216-
"version": "3.8.16"
216+
"version": "3.10.4"
217217
},
218218
"orig_nbformat": 4,
219219
"vscode": {
220220
"interpreter": {
221-
"hash": "d6019e21eb0d27eebd69283f1089b8b605b46cb058a452b887458f3af7017e46"
221+
"hash": "3425e24474cbe920550266ea26b478634978cc419579f9dbcf479231067df6a3"
222222
}
223223
}
224224
},

notebooks/spherical_rotation.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@
3030
"metadata": {},
3131
"outputs": [],
3232
"source": [
33-
"from jax.config import config\n",
34-
"config.update(\"jax_enable_x64\", True)\n",
33+
"import jax\n",
34+
"jax.config.update(\"jax_enable_x64\", True)\n",
3535
"\n",
3636
"import numpy as np\n",
3737
"import s2fft \n",

notebooks/wigner_transform.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@
2828
"metadata": {},
2929
"outputs": [],
3030
"source": [
31-
"from jax.config import config\n",
32-
"config.update(\"jax_enable_x64\", True) \n",
31+
"import jax\n",
32+
"jax.config.update(\"jax_enable_x64\", True)\n",
3333
"\n",
3434
"import numpy as np\n",
3535
"import s2fft \n",

s2fft/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@
1010
from .utils.rotation import rotate_flms, generate_rotate_dls
1111

1212
import logging
13-
from jax.config import config
13+
import jax
1414

15-
if config.read("jax_enable_x64") is False:
15+
if jax.config.read("jax_enable_x64") is False:
1616
logger = logging.getLogger("s2fft")
1717
logger.warning(
1818
"JAX is not using 64-bit precision. This will dramatically affect numerical precision at even moderate L."

s2fft/precompute_transforms/construct.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
from jax import config
1+
import jax
22

3-
config.update("jax_enable_x64", True)
3+
jax.config.update("jax_enable_x64", True)
44

55
import numpy as np
66
import jax.numpy as jnp

s2fft/recursions/risbo.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ def compute_full(dl: np.ndarray, beta: float, L: int, el: int) -> np.ndarray:
55
r"""Compute Wigner-d at argument :math:`\beta` for full plane using
66
Risbo recursion.
77
8-
The Wigner-d plane is computed by recursion over :math:`\ell` (`el`).
8+
The Wigner-d plane is computed by recursion over :math:`\ell`.
99
Thus, for :math:`\ell > 0` the plane must be computed already for
1010
:math:`\ell - 1`. At present, for :math:`\ell = 0` the recusion is initialised.
1111
@@ -19,7 +19,7 @@ def compute_full(dl: np.ndarray, beta: float, L: int, el: int) -> np.ndarray:
1919
el (int): Spherical harmonic degree :math:`\ell`.
2020
2121
Returns:
22-
np.ndarray: Plane of Wigner-d for `el` and `beta`, with full plane computed.
22+
np.ndarray: Plane of Wigner-d for :math:`\ell` and :math:`\beta`, with full plane computed.
2323
"""
2424

2525
_arg_checks(dl, beta, L, el)
@@ -103,7 +103,7 @@ def compute_full(dl: np.ndarray, beta: float, L: int, el: int) -> np.ndarray:
103103

104104

105105
def _arg_checks(dl: np.ndarray, beta: float, L: int, el: int):
106-
"""Check arguments of Risbo functions.
106+
r"""Check arguments of Risbo functions.
107107
108108
Args:
109109
dl (np.ndarray): Wigner-d plane of which to check shape.

s2fft/transforms/spherical.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,7 @@ def f_bwd(res, gtm):
296296
ftm = ftm.at[:, m_start_ind + m_offset :].multiply(phase_shifts)
297297

298298
# Perform longitundal Fast Fourier Transforms
299-
ftm *= (-1) ** spin
299+
ftm *= (-1) ** jnp.abs(spin)
300300
if reality:
301301
ftm = ftm.at[:, m_offset : L - 1 + m_offset].set(
302302
jnp.flip(jnp.conj(ftm[:, L - 1 + m_offset + 1 :]), axis=-1)
@@ -657,4 +657,4 @@ def f_bwd(res, glm):
657657
indices = jnp.repeat(jnp.expand_dims(jnp.arange(L), -1), 2 * L - 1, axis=-1)
658658
flm = jnp.where(indices < abs(spin), jnp.zeros_like(flm), flm[..., :])
659659

660-
return flm * (-1) ** spin
660+
return flm * (-1) ** jnp.abs(spin)

s2fft/transforms/wigner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ def inverse_jax(
259259
def spherical_loop(n, args):
260260
fban, flmn, lrenorm, vsign, spins = args
261261
fban = fban.at[n].add(
262-
(-1) ** spins[n]
262+
(-1) ** jnp.abs(spins[n])
263263
* s2fft.inverse_jax(
264264
flmn[n],
265265
L,

s2fft/utils/rotation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def rotate_flms(
3737
dl = (
3838
dl_array
3939
if dl_array != None
40-
else jnp.zeros((2 * L - 1, 2 * L - 1)).astype(jnp.complex128)
40+
else jnp.zeros((2 * L - 1, 2 * L - 1)).astype(jnp.float64)
4141
)
4242

4343
# Perform rotation

0 commit comments

Comments
 (0)