Skip to content

Commit a796dfe

Browse files
authored
Merge pull request #162 from astro-informatics/feature/custom_grads
Feature/custom grads
2 parents ba9d48b + 9d520f4 commit a796dfe

17 files changed

+564
-88
lines changed

.github/workflows/tests.yml

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,14 @@ jobs:
2525
with:
2626
python-version: ${{ matrix.python-version }}
2727

28-
# - name: Install dependencies
29-
# run: |
30-
# python -m pip install --upgrade pip
31-
# pip install -r requirements/requirements-tests.txt
32-
# pip install -r requirements/requirements-core.txt
33-
# pip install .
28+
- name: Install dependencies
29+
run: |
30+
python -m pip install --upgrade pip
31+
pip install -r requirements/requirements-tests.txt
32+
pip install -r requirements/requirements-core.txt
33+
pip install .
3434
35-
# - name: Run tests
36-
# run: |
37-
# pytest --cov-report term --cov=s2wav --cov-config=.coveragerc
38-
# codecov --token 298dc7ee-bb9f-4221-b31f-3576cc6cb702
35+
- name: Run tests
36+
run: |
37+
pytest --cov-report term --cov=s2wav --cov-config=.coveragerc
38+
codecov --token 298dc7ee-bb9f-4221-b31f-3576cc6cb702

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ pixels of equal areas, which has many practical advantages.
6060

6161
The Python dependencies for the `S2FFT` package are listed in the file
6262
`requirements/requirements-core.txt` and will be automatically installed
63-
into the active python environment by [pip]{.title-ref} when running
63+
into the active python environment by [pip](https://pypi.org) when running
6464

6565
``` bash
6666
pip install .
@@ -75,7 +75,7 @@ tox -e py38 # for tox
7575
```
7676

7777
In the very near future one will be able to install `S2FFT` directly
78-
from [PyPi]{.title-ref} by `pip install s2fft` but this is not yet
78+
from [PyPi](https://pypi.org) by `pip install s2fft` but this is not yet
7979
supported. Note that to run `JAX` on NVIDIA GPUs you will need to follow
8080
the [guide](https://github.com/google/jax#installation) outlined by
8181
Google.

notebooks/custom_gradients.ipynb

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 1,
6+
"metadata": {},
7+
"outputs": [
8+
{
9+
"name": "stdout",
10+
"output_type": "stream",
11+
"text": [
12+
"cpu\n"
13+
]
14+
}
15+
],
16+
"source": [
17+
"# Specify CUDA device\n",
18+
"import os\n",
19+
"os.environ['CUDA_VISIBLE_DEVICES'] = ''\n",
20+
"os.environ['JAX_CHECK_TRACER_LEAKS'] = 'True'\n",
21+
"\n",
22+
"from jax.config import config\n",
23+
"config.update(\"jax_enable_x64\", True)\n",
24+
"\n",
25+
"# Check we're running on GPU\n",
26+
"from jax.lib import xla_bridge\n",
27+
"print(xla_bridge.get_backend().platform)\n",
28+
"\n",
29+
"import jax\n",
30+
"from jax import jit, grad \n",
31+
"import jax.numpy as jnp \n",
32+
"from jax.test_util import check_grads\n",
33+
"import numpy as np \n",
34+
"\n",
35+
"import s2fft "
36+
]
37+
},
38+
{
39+
"cell_type": "code",
40+
"execution_count": 2,
41+
"metadata": {},
42+
"outputs": [],
43+
"source": [
44+
"L = 16\n",
45+
"sampling = \"mw\"\n",
46+
"np.random.seed(1911851)\n",
47+
"f_target = np.random.randn(2*L, 2*L-1)+1j*np.random.randn(2*L, 2*L-1)\n",
48+
"flm_target = s2fft.forward_jax(f_target, L, sampling=sampling)\n",
49+
"f_target = s2fft.inverse_jax(flm_target, L, sampling=sampling)\n",
50+
"precomps = s2fft.generate_precomputes_jax(L, forward=True, sampling=sampling)"
51+
]
52+
},
53+
{
54+
"cell_type": "code",
55+
"execution_count": 3,
56+
"metadata": {},
57+
"outputs": [],
58+
"source": [
59+
"np.random.seed(130672510)\n",
60+
"f = np.random.randn(2*L, 2*L-1) + 1j*np.random.randn(2*L, 2*L-1)"
61+
]
62+
},
63+
{
64+
"cell_type": "code",
65+
"execution_count": 4,
66+
"metadata": {},
67+
"outputs": [],
68+
"source": [
69+
"def func(f):\n",
70+
" flm = s2fft.forward_jax(f, L, reality=False, precomps=precomps,sampling=sampling)\n",
71+
" return jnp.sum(jnp.abs(flm-flm_target)**2)\n",
72+
"grad_func = grad(func)"
73+
]
74+
},
75+
{
76+
"cell_type": "code",
77+
"execution_count": 5,
78+
"metadata": {},
79+
"outputs": [],
80+
"source": [
81+
"check_grads(func, (f,), order=1, modes=('rev'))"
82+
]
83+
}
84+
],
85+
"metadata": {
86+
"kernelspec": {
87+
"display_name": "Python 3.9.0 ('s2fft')",
88+
"language": "python",
89+
"name": "python3"
90+
},
91+
"language_info": {
92+
"codemirror_mode": {
93+
"name": "ipython",
94+
"version": 3
95+
},
96+
"file_extension": ".py",
97+
"mimetype": "text/x-python",
98+
"name": "python",
99+
"nbconvert_exporter": "python",
100+
"pygments_lexer": "ipython3",
101+
"version": "3.9.0"
102+
},
103+
"orig_nbformat": 4,
104+
"vscode": {
105+
"interpreter": {
106+
"hash": "3425e24474cbe920550266ea26b478634978cc419579f9dbcf479231067df6a3"
107+
}
108+
}
109+
},
110+
"nbformat": 4,
111+
"nbformat_minor": 2
112+
}

s2fft/recursions/price_mcewen.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ def generate_precomputes_jax(
122122
nside: int = None,
123123
forward: bool = False,
124124
L_lower: int = 0,
125+
betas: jnp.ndarray = None,
125126
) -> List[jnp.ndarray]:
126127
r"""Compute recursion coefficients with :math:`\mathcal{O}(L^2)` memory overhead.
127128
In practice one could compute these on-the-fly but the memory overhead is
@@ -145,17 +146,22 @@ def generate_precomputes_jax(
145146
L_lower (int, optional): Harmonic lower-bound. Transform will only be computed
146147
for :math:`\texttt{L_lower} \leq \ell < \texttt{L}`. Defaults to 0.
147148
149+
beta (jnp.ndarray): Array of polar angles in radians.
150+
148151
Returns:
149152
List[jnp.ndarray]: List of precomputed coefficient arrays.
150153
"""
151154
mm = -spin
152155
L0 = L_lower
153156
# Correct for mw to mwss conversion
154-
if forward and sampling.lower() in ["mw", "mwss"]:
155-
sampling = "mwss"
156-
beta = samples.thetas(2 * L, "mwss")[1:-1]
157+
if betas is None:
158+
if forward and sampling.lower() in ["mw", "mwss"]:
159+
sampling = "mwss"
160+
beta = samples.thetas(2 * L, "mwss")[1:-1]
161+
else:
162+
beta = samples.thetas(L, sampling, nside)
157163
else:
158-
beta = samples.thetas(L, sampling, nside)
164+
beta = betas
159165

160166
ntheta = len(beta) # Number of theta samples
161167
el = jnp.arange(L0, L)
@@ -243,7 +249,7 @@ def renorm_m_loop(i, args):
243249

244250
# Remove redundant nans:
245251
# - in forward pass these are not accessed, so are irrelevant.
246-
# - in backward pass the adjoint computation otherwise accumulates these
252+
# - in backward pass the adjoint computation otherwise accumulates these
247253
# nans into grads if not set to zero.
248254
lrenorm = jnp.nan_to_num(lrenorm, nan=0.0, posinf=0.0, neginf=0.0)
249255
cpi = jnp.nan_to_num(cpi, nan=0.0, posinf=0.0, neginf=0.0)

s2fft/transforms/otf_recursions.py

Lines changed: 49 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import numpy as np
22
import jax.numpy as jnp
33
import jax.lax as lax
4-
from jax import jit, pmap, local_device_count
4+
from jax import jit, pmap, local_device_count, custom_vjp
55
from functools import partial
66
from typing import List
77
from s2fft.sampling import s2_samples as samples
@@ -225,7 +225,8 @@ def inverse_latitudinal_step_jax(
225225

226226
mm = -spin # switch to match convention
227227
ntheta = len(beta) # Number of theta samples
228-
ftm = jnp.zeros(samples.ftm_shape(L, sampling, nside), dtype=jnp.complex128)
228+
m_count = 2 * L if sampling.lower() in ["mwss", "healpix"] else 2 * L - 1
229+
ftm = jnp.zeros((ntheta, m_count), dtype=jnp.complex128)
229230
el = jnp.arange(L_lower, L)
230231

231232
# Trigonometric constant adopted throughout
@@ -239,7 +240,7 @@ def inverse_latitudinal_step_jax(
239240

240241
if precomps is None:
241242
precomps = generate_precomputes_jax(
242-
L, -mm, sampling, nside, L_lower=L_lower
243+
L, -mm, sampling, nside, L_lower=L_lower, betas=beta
243244
)
244245
lrenorm, vsign, cpi, cp2, indices = precomps
245246

@@ -401,6 +402,25 @@ def eval_recursion_step(
401402
pm_recursion_step,
402403
(ftm, dl_entry, dl_iter, lrenorm, indices, omc, c, s),
403404
)
405+
406+
# Remove south pole singularity
407+
m_offset = 1 if sampling.lower() in ["mwss", "healpix"] else 0
408+
if sampling.lower() in ["mw", "mwss"]:
409+
ftm = ftm.at[-1].set(0)
410+
ftm = ftm.at[-1, L - 1 + spin + m_offset].set(
411+
jnp.nansum(
412+
(-1) ** abs(jnp.arange(L_lower, L) - spin)
413+
* flm[L_lower:, L - 1 + spin]
414+
)
415+
)
416+
417+
# Remove north pole singularity
418+
if sampling.lower() == "mwss":
419+
ftm = ftm.at[0].set(0)
420+
ftm = ftm.at[0, L - 1 - spin + m_offset].set(
421+
jnp.nansum(flm[L_lower:, L - 1 - spin])
422+
)
423+
404424
return ftm
405425

406426

@@ -562,8 +582,8 @@ def forward_latitudinal_step(
562582

563583
@partial(jit, static_argnums=(2, 4, 5, 6, 8, 9))
564584
def forward_latitudinal_step_jax(
565-
ftm: jnp.ndarray,
566-
beta: jnp.ndarray,
585+
ftm_in: jnp.ndarray,
586+
beta_in: jnp.ndarray,
567587
L: int,
568588
spin: int,
569589
nside: int,
@@ -622,6 +642,16 @@ def forward_latitudinal_step_jax(
622642
between devices is noticable, however as L increases one will asymptotically
623643
recover acceleration by the number of devices.
624644
"""
645+
# Avoid pole-singularities for MWSS sampling
646+
if sampling.lower() == "mwss":
647+
ftm = ftm_in[1:-1]
648+
beta = beta_in[1:-1]
649+
elif sampling.lower() == "mw":
650+
ftm = ftm_in[:-1]
651+
beta = beta_in[:-1]
652+
else:
653+
ftm = ftm_in
654+
beta = beta_in
625655

626656
mm = -spin # switch to match convention
627657
ntheta = len(beta) # Number of theta samples
@@ -639,7 +669,7 @@ def forward_latitudinal_step_jax(
639669

640670
if precomps is None:
641671
precomps = generate_precomputes_jax(
642-
L, -mm, sampling, nside, True, L_lower
672+
L, -mm, sampling, nside, True, L_lower, betas=beta
643673
)
644674
lrenorm, vsign, cpi, cp2, indices = precomps
645675

@@ -840,4 +870,17 @@ def eval_recursion_step(
840870
),
841871
)[0]
842872
)
873+
874+
# Include both pole singularities explicitly
875+
m_offset = 1 if sampling.lower() in ["mwss", "healpix"] else 0
876+
if sampling.lower() in ["mw", "mwss"]:
877+
flm = flm.at[L_lower:, L - 1 + spin].add(
878+
(-1) ** abs(jnp.arange(L_lower, L) - spin)
879+
* ftm_in[-1, L - 1 + spin + m_offset]
880+
)
881+
882+
if sampling.lower() == "mwss":
883+
flm = flm.at[L_lower:, L - 1 - spin].add(
884+
ftm_in[0, L - 1 - spin + m_offset]
885+
)
843886
return flm

0 commit comments

Comments
 (0)