Skip to content

Commit 1cadd79

Browse files
authored
Merge branch 'main' into features/pytorch_precompute_transforms
2 parents 76bc05e + b82e065 commit 1cadd79

20 files changed

+62
-41
lines changed

.all-contributorsrc

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,24 @@
7575
"contributions": [
7676
"doc"
7777
]
78+
},
79+
{
80+
"login": "kmulderdas",
81+
"name": "Kevin Mulder",
82+
"avatar_url": "https://avatars.githubusercontent.com/u/33317219?v=4",
83+
"profile": "https://github.com/kmulderdas",
84+
"contributions": [
85+
"bug"
86+
]
87+
},
88+
{
89+
"login": "PhilippMisofCH",
90+
"name": "Philipp Misof",
91+
"avatar_url": "https://avatars.githubusercontent.com/u/142883157?v=4",
92+
"profile": "https://github.com/PhilippMisofCH",
93+
"contributions": [
94+
"bug"
95+
]
7896
}
7997
],
8098
"contributorsPerLine": 7,

README.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
[![image](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
44
[![image](https://badge.fury.io/py/s2fft.svg)](https://badge.fury.io/py/s2fft)
55
[![image](http://img.shields.io/badge/arXiv-2311.14670-orange.svg?style=flat)](https://arxiv.org/abs/2311.14670)<!-- ALL-CONTRIBUTORS-BADGE:START - Do not remove or modify this section -->
6-
[![All Contributors](https://img.shields.io/badge/all_contributors-7-orange.svg?style=flat-square)](#contributors-)<!-- ALL-CONTRIBUTORS-BADGE:END -->
6+
[![All Contributors](https://img.shields.io/badge/all_contributors-9-orange.svg?style=flat-square)](#contributors-)<!-- ALL-CONTRIBUTORS-BADGE:END -->
77
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1YmJ2ljsF8HBvhPmD4hrYPlyAKc4WPUgq?usp=sharing)
88
<!-- [![image](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) -->
99

@@ -174,6 +174,10 @@ key](https://allcontributors.org/docs/en/emoji-key)):
174174
<td align="center" valign="top" width="14.28%"><a href="http://flanusse.net"><img src="https://avatars.githubusercontent.com/u/861591?v=4?s=100" width="100px;" alt="Francois Lanusse"/><br /><sub><b>Francois Lanusse</b></sub></a><br /><a href="https://github.com/astro-informatics/s2fft/commits?author=EiffL" title="Code">💻</a> <a href="https://github.com/astro-informatics/s2fft/issues?q=author%3AEiffL" title="Bug reports">🐛</a></td>
175175
<td align="center" valign="top" width="14.28%"><a href="https://github.com/eltociear"><img src="https://avatars.githubusercontent.com/u/22633385?v=4?s=100" width="100px;" alt="Ikko Eltociear Ashimine"/><br /><sub><b>Ikko Eltociear Ashimine</b></sub></a><br /><a href="https://github.com/astro-informatics/s2fft/commits?author=eltociear" title="Documentation">📖</a></td>
176176
</tr>
177+
<tr>
178+
<td align="center" valign="top" width="14.28%"><a href="https://github.com/kmulderdas"><img src="https://avatars.githubusercontent.com/u/33317219?v=4?s=100" width="100px;" alt="Kevin Mulder"/><br /><sub><b>Kevin Mulder</b></sub></a><br /><a href="https://github.com/astro-informatics/s2fft/issues?q=author%3Akmulderdas" title="Bug reports">🐛</a></td>
179+
<td align="center" valign="top" width="14.28%"><a href="https://github.com/PhilippMisofCH"><img src="https://avatars.githubusercontent.com/u/142883157?v=4?s=100" width="100px;" alt="Philipp Misof"/><br /><sub><b>Philipp Misof</b></sub></a><br /><a href="https://github.com/astro-informatics/s2fft/issues?q=author%3APhilippMisofCH" title="Bug reports">🐛</a></td>
180+
</tr>
177181
</tbody>
178182
</table>
179183

notebooks/custom_gradients.ipynb

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,13 @@
1919
"os.environ['CUDA_VISIBLE_DEVICES'] = ''\n",
2020
"os.environ['JAX_CHECK_TRACER_LEAKS'] = 'True'\n",
2121
"\n",
22-
"from jax.config import config\n",
23-
"config.update(\"jax_enable_x64\", True)\n",
22+
"import jax\n",
23+
"jax.config.update(\"jax_enable_x64\", True)\n",
2424
"\n",
2525
"# Check we're running on GPU\n",
2626
"from jax.lib import xla_bridge\n",
2727
"print(xla_bridge.get_backend().platform)\n",
2828
"\n",
29-
"import jax\n",
3029
"from jax import jit, grad \n",
3130
"import jax.numpy as jnp \n",
3231
"from jax.test_util import check_grads\n",
@@ -98,7 +97,7 @@
9897
"name": "python",
9998
"nbconvert_exporter": "python",
10099
"pygments_lexer": "ipython3",
101-
"version": "3.9.0"
100+
"version": "3.10.4"
102101
},
103102
"orig_nbformat": 4,
104103
"vscode": {

notebooks/spherical_harmonic_transform.ipynb

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@
2828
"metadata": {},
2929
"outputs": [],
3030
"source": [
31-
"from jax.config import config\n",
32-
"config.update(\"jax_enable_x64\", True)\n",
31+
"import jax\n",
32+
"jax.config.update(\"jax_enable_x64\", True)\n",
3333
"\n",
3434
"import numpy as np\n",
3535
"import s2fft \n",
@@ -199,7 +199,7 @@
199199
],
200200
"metadata": {
201201
"kernelspec": {
202-
"display_name": "Python 3.8.16 64-bit ('s2fft')",
202+
"display_name": "Python 3.10.4 ('s2fft')",
203203
"language": "python",
204204
"name": "python3"
205205
},
@@ -213,12 +213,12 @@
213213
"name": "python",
214214
"nbconvert_exporter": "python",
215215
"pygments_lexer": "ipython3",
216-
"version": "3.8.16"
216+
"version": "3.10.4"
217217
},
218218
"orig_nbformat": 4,
219219
"vscode": {
220220
"interpreter": {
221-
"hash": "d6019e21eb0d27eebd69283f1089b8b605b46cb058a452b887458f3af7017e46"
221+
"hash": "3425e24474cbe920550266ea26b478634978cc419579f9dbcf479231067df6a3"
222222
}
223223
}
224224
},

notebooks/spherical_rotation.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@
3030
"metadata": {},
3131
"outputs": [],
3232
"source": [
33-
"from jax.config import config\n",
34-
"config.update(\"jax_enable_x64\", True)\n",
33+
"import jax\n",
34+
"jax.config.update(\"jax_enable_x64\", True)\n",
3535
"\n",
3636
"import numpy as np\n",
3737
"import s2fft \n",

notebooks/wigner_transform.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@
2828
"metadata": {},
2929
"outputs": [],
3030
"source": [
31-
"from jax.config import config\n",
32-
"config.update(\"jax_enable_x64\", True) \n",
31+
"import jax\n",
32+
"jax.config.update(\"jax_enable_x64\", True)\n",
3333
"\n",
3434
"import numpy as np\n",
3535
"import s2fft \n",

s2fft/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@
1010
from .utils.rotation import rotate_flms, generate_rotate_dls
1111

1212
import logging
13-
from jax.config import config
13+
import jax
1414

15-
if config.read("jax_enable_x64") is False:
15+
if jax.config.read("jax_enable_x64") is False:
1616
logger = logging.getLogger("s2fft")
1717
logger.warning(
1818
"JAX is not using 64-bit precision. This will dramatically affect numerical precision at even moderate L."

s2fft/precompute_transforms/construct.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
from jax import config
1+
import jax
22

3-
config.update("jax_enable_x64", True)
3+
jax.config.update("jax_enable_x64", True)
44

55
import numpy as np
66
import jax.numpy as jnp

s2fft/recursions/risbo.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ def compute_full(dl: np.ndarray, beta: float, L: int, el: int) -> np.ndarray:
55
r"""Compute Wigner-d at argument :math:`\beta` for full plane using
66
Risbo recursion.
77
8-
The Wigner-d plane is computed by recursion over :math:`\ell` (`el`).
8+
The Wigner-d plane is computed by recursion over :math:`\ell`.
99
Thus, for :math:`\ell > 0` the plane must be computed already for
1010
:math:`\ell - 1`. At present, for :math:`\ell = 0` the recusion is initialised.
1111
@@ -19,7 +19,7 @@ def compute_full(dl: np.ndarray, beta: float, L: int, el: int) -> np.ndarray:
1919
el (int): Spherical harmonic degree :math:`\ell`.
2020
2121
Returns:
22-
np.ndarray: Plane of Wigner-d for `el` and `beta`, with full plane computed.
22+
np.ndarray: Plane of Wigner-d for :math:`\ell` and :math:`\beta`, with full plane computed.
2323
"""
2424

2525
_arg_checks(dl, beta, L, el)
@@ -103,7 +103,7 @@ def compute_full(dl: np.ndarray, beta: float, L: int, el: int) -> np.ndarray:
103103

104104

105105
def _arg_checks(dl: np.ndarray, beta: float, L: int, el: int):
106-
"""Check arguments of Risbo functions.
106+
r"""Check arguments of Risbo functions.
107107
108108
Args:
109109
dl (np.ndarray): Wigner-d plane of which to check shape.

s2fft/transforms/spherical.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,7 @@ def f_bwd(res, gtm):
296296
ftm = ftm.at[:, m_start_ind + m_offset :].multiply(phase_shifts)
297297

298298
# Perform longitundal Fast Fourier Transforms
299-
ftm *= (-1) ** spin
299+
ftm *= (-1) ** jnp.abs(spin)
300300
if reality:
301301
ftm = ftm.at[:, m_offset : L - 1 + m_offset].set(
302302
jnp.flip(jnp.conj(ftm[:, L - 1 + m_offset + 1 :]), axis=-1)
@@ -657,4 +657,4 @@ def f_bwd(res, glm):
657657
indices = jnp.repeat(jnp.expand_dims(jnp.arange(L), -1), 2 * L - 1, axis=-1)
658658
flm = jnp.where(indices < abs(spin), jnp.zeros_like(flm), flm[..., :])
659659

660-
return flm * (-1) ** spin
660+
return flm * (-1) ** jnp.abs(spin)

0 commit comments

Comments
 (0)