Skip to content

Commit a7e887e

Browse files
authored
Merge pull request #202 from astro-informatics/feature/notebook_plots
Feature/notebook plots
2 parents bc7cbd8 + 0ff8fb4 commit a7e887e

File tree

7 files changed

+98
-62
lines changed

7 files changed

+98
-62
lines changed

notebooks/JAX_HEALPix_frontend.ipynb

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,12 @@
1616
"metadata": {},
1717
"outputs": [],
1818
"source": [
19-
"# Install s2fft\n",
20-
"!pip install s2fft &> /dev/null"
19+
"import sys\n",
20+
"IN_COLAB = 'google.colab' in sys.modules\n",
21+
"\n",
22+
"# Install s2fft and data if running on google colab.\n",
23+
"if IN_COLAB:\n",
24+
" !pip install s2fft &> /dev/null"
2125
]
2226
},
2327
{
@@ -42,11 +46,12 @@
4246
"import numpy as np\n",
4347
"import s2fft \n",
4448
"\n",
45-
"L = 1024\n",
46-
"nside = 512\n",
49+
"L = 128\n",
50+
"nside = 64\n",
4751
"method = \"jax_healpy\"\n",
4852
"sampling = \"healpix\"\n",
49-
"flm = np.random.randn(L, 2*L-1) + 1j*np.random.randn(L, 2*L-1)\n",
53+
"rng = np.random.default_rng(23457801234570)\n",
54+
"flm = s2fft.utils.signal_generator.generate_flm(rng, L)\n",
5055
"f = s2fft.inverse(flm, L, nside=nside, sampling=sampling, method=method)"
5156
]
5257
},
@@ -183,7 +188,7 @@
183188
"name": "python",
184189
"nbconvert_exporter": "python",
185190
"pygments_lexer": "ipython3",
186-
"version": "3.11.8"
191+
"version": "3.10.0"
187192
},
188193
"orig_nbformat": 4,
189194
"vscode": {

notebooks/JAX_SSHT_frontend.ipynb

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,12 @@
1616
"metadata": {},
1717
"outputs": [],
1818
"source": [
19-
"# Install s2fft\n",
20-
"!pip install s2fft &> /dev/null"
19+
"import sys\n",
20+
"IN_COLAB = 'google.colab' in sys.modules\n",
21+
"\n",
22+
"# Install s2fft and data if running on google colab.\n",
23+
"if IN_COLAB:\n",
24+
" !pip install s2fft &> /dev/null"
2125
]
2226
},
2327
{
@@ -42,9 +46,10 @@
4246
"import numpy as np\n",
4347
"import s2fft \n",
4448
"\n",
45-
"L = 1024\n",
49+
"L = 128\n",
4650
"method = \"jax_ssht\"\n",
47-
"flm = np.random.randn(L, 2*L-1) + 1j*np.random.randn(L, 2*L-1)\n",
51+
"rng = np.random.default_rng(23457801234570)\n",
52+
"flm = s2fft.utils.signal_generator.generate_flm(rng, L)\n",
4853
"f = s2fft.inverse(flm, L, method=method)"
4954
]
5055
},
@@ -107,7 +112,7 @@
107112
"name": "stdout",
108113
"output_type": "stream",
109114
"text": [
110-
"Mean absolute error = 4.909423754134027e-11\n"
115+
"Mean absolute error = 7.784372519411174e-13\n"
111116
]
112117
}
113118
],
@@ -181,7 +186,7 @@
181186
"name": "python",
182187
"nbconvert_exporter": "python",
183188
"pygments_lexer": "ipython3",
184-
"version": "3.11.8"
189+
"version": "3.10.0"
185190
},
186191
"orig_nbformat": 4,
187192
"vscode": {

notebooks/spherical_harmonic_transform.ipynb

Lines changed: 48 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,21 @@
1212
},
1313
{
1414
"cell_type": "code",
15-
"execution_count": 1,
15+
"execution_count": null,
1616
"metadata": {},
1717
"outputs": [],
1818
"source": [
19-
"# Install s2fft\n",
20-
"!pip install s2fft &> /dev/null"
19+
"import sys\n",
20+
"IN_COLAB = 'google.colab' in sys.modules\n",
21+
"\n",
22+
"# Install a spherical plotting package.\n",
23+
"!pip install cartopy &> /dev/null\n",
24+
"\n",
25+
"# Install s2fft and data if running on google colab.\n",
26+
"if IN_COLAB:\n",
27+
" !pip install s2fft &> /dev/null\n",
28+
" !mkdir data/\n",
29+
" !wget https://github.com/astro-informatics/s2fft/raw/main/notebooks/data/Gaia_EDR3_flux.npy -P data/ &> /dev/null"
2130
]
2231
},
2332
{
@@ -32,20 +41,41 @@
3241
},
3342
{
3443
"cell_type": "code",
35-
"execution_count": 2,
44+
"execution_count": null,
3645
"metadata": {},
3746
"outputs": [],
3847
"source": [
3948
"import jax\n",
4049
"jax.config.update(\"jax_enable_x64\", True)\n",
4150
"\n",
4251
"import numpy as np\n",
52+
"from matplotlib import pyplot as plt \n",
53+
"import cartopy.crs as ccrs \n",
4354
"import s2fft \n",
4455
"\n",
45-
"L = 256\n",
4656
"sampling = \"mw\"\n",
47-
"flm = np.random.randn(L, 2*L-1) + 1j*np.random.randn(L, 2*L-1)\n",
48-
"f = s2fft.inverse_jax(flm, L)"
57+
"f = np.load('data/Gaia_EDR3_flux.npy')\n",
58+
"L = f.shape[0]"
59+
]
60+
},
61+
{
62+
"cell_type": "markdown",
63+
"metadata": {},
64+
"source": [
65+
"Lets look at the input signal"
66+
]
67+
},
68+
{
69+
"cell_type": "code",
70+
"execution_count": null,
71+
"metadata": {},
72+
"outputs": [],
73+
"source": [
74+
"plt.figure(figsize=(10,5))\n",
75+
"ax = plt.axes(projection=ccrs.Mollweide())\n",
76+
"im = ax.imshow(f, transform=ccrs.PlateCarree(), cmap='magma')\n",
77+
"plt.axis('off')\n",
78+
"plt.show()"
4979
]
5080
},
5181
{
@@ -62,7 +92,7 @@
6292
},
6393
{
6494
"cell_type": "code",
65-
"execution_count": 3,
95+
"execution_count": null,
6696
"metadata": {},
6797
"outputs": [],
6898
"source": [
@@ -81,7 +111,7 @@
81111
},
82112
{
83113
"cell_type": "code",
84-
"execution_count": 4,
114+
"execution_count": null,
85115
"metadata": {},
86116
"outputs": [],
87117
"source": [
@@ -103,7 +133,7 @@
103133
},
104134
{
105135
"cell_type": "code",
106-
"execution_count": 5,
136+
"execution_count": null,
107137
"metadata": {},
108138
"outputs": [],
109139
"source": [
@@ -122,7 +152,7 @@
122152
},
123153
{
124154
"cell_type": "code",
125-
"execution_count": 6,
155+
"execution_count": null,
126156
"metadata": {},
127157
"outputs": [],
128158
"source": [
@@ -144,34 +174,18 @@
144174
},
145175
{
146176
"cell_type": "code",
147-
"execution_count": 7,
148-
"metadata": {},
149-
"outputs": [
150-
{
151-
"name": "stdout",
152-
"output_type": "stream",
153-
"text": [
154-
"Mean absolute error = 8.478196507592078e-11\n"
155-
]
156-
}
157-
],
177+
"execution_count": null,
178+
"metadata": {},
179+
"outputs": [],
158180
"source": [
159181
"print(f\"Mean absolute error = {np.nanmean(np.abs(f_recov - f))}\")"
160182
]
161183
},
162184
{
163185
"cell_type": "code",
164-
"execution_count": 8,
165-
"metadata": {},
166-
"outputs": [
167-
{
168-
"name": "stdout",
169-
"output_type": "stream",
170-
"text": [
171-
"Mean absolute error using precomputes = 8.478196507592078e-11\n"
172-
]
173-
}
174-
],
186+
"execution_count": null,
187+
"metadata": {},
188+
"outputs": [],
175189
"source": [
176190
"print(f\"Mean absolute error using precomputes = {np.nanmean(np.abs(f_recov_pre - f))}\")"
177191
]
@@ -193,7 +207,7 @@
193207
"name": "python",
194208
"nbconvert_exporter": "python",
195209
"pygments_lexer": "ipython3",
196-
"version": "3.11.8"
210+
"version": "3.10.0"
197211
},
198212
"orig_nbformat": 4,
199213
"vscode": {

notebooks/spherical_rotation.ipynb

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,12 @@
1616
"metadata": {},
1717
"outputs": [],
1818
"source": [
19-
"# Install s2fft\n",
20-
"!pip install s2fft &> /dev/null"
19+
"import sys\n",
20+
"IN_COLAB = 'google.colab' in sys.modules\n",
21+
"\n",
22+
"# Install s2fft and data if running on google colab.\n",
23+
"if IN_COLAB:\n",
24+
" !pip install s2fft &> /dev/null"
2125
]
2226
},
2327
{

notebooks/torch_frontend.ipynb

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,12 @@
1616
"metadata": {},
1717
"outputs": [],
1818
"source": [
19-
"# Install s2fft\n",
20-
"!pip install s2fft &> /dev/null"
19+
"import sys\n",
20+
"IN_COLAB = 'google.colab' in sys.modules\n",
21+
"\n",
22+
"# Install s2fft and data if running on google colab.\n",
23+
"if IN_COLAB:\n",
24+
" !pip install s2fft &> /dev/null"
2125
]
2226
},
2327
{

notebooks/wigner_transform.ipynb

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,12 @@
1616
"metadata": {},
1717
"outputs": [],
1818
"source": [
19-
"# Install s2fft\n",
20-
"!pip install s2fft &> /dev/null"
19+
"import sys\n",
20+
"IN_COLAB = 'google.colab' in sys.modules\n",
21+
"\n",
22+
"# Install s2fft and data if running on google colab.\n",
23+
"if IN_COLAB:\n",
24+
" !pip install s2fft &> /dev/null"
2125
]
2226
},
2327
{
@@ -47,7 +51,7 @@
4751
"L = 128\n",
4852
"N = 3\n",
4953
"reality = True\n",
50-
"rng = np.random.default_rng(0)\n",
54+
"rng = np.random.default_rng(83459)\n",
5155
"flmn = s2fft.utils.signal_generator.generate_flmn(rng, L, N, reality=reality)"
5256
]
5357
},
@@ -190,7 +194,7 @@
190194
"name": "python",
191195
"nbconvert_exporter": "python",
192196
"pygments_lexer": "ipython3",
193-
"version": "3.9.13"
197+
"version": "3.10.0"
194198
},
195199
"orig_nbformat": 4,
196200
"vscode": {

s2fft/utils/signal_generator.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -38,16 +38,16 @@ def generate_flm(
3838

3939
for el in range(max(L_lower, abs(spin)), L):
4040
if reality:
41-
flm[el, 0 + L - 1] = rng.uniform()
41+
flm[el, 0 + L - 1] = rng.normal()
4242
else:
43-
flm[el, 0 + L - 1] = rng.uniform() + 1j * rng.uniform()
43+
flm[el, 0 + L - 1] = rng.normal() + 1j * rng.normal()
4444

4545
for m in range(1, el + 1):
46-
flm[el, m + L - 1] = rng.uniform() + 1j * rng.uniform()
46+
flm[el, m + L - 1] = rng.normal() + 1j * rng.normal()
4747
if reality:
4848
flm[el, -m + L - 1] = (-1) ** m * np.conj(flm[el, m + L - 1])
4949
else:
50-
flm[el, -m + L - 1] = rng.uniform() + 1j * rng.uniform()
50+
flm[el, -m + L - 1] = rng.normal() + 1j * rng.normal()
5151

5252
return torch.from_numpy(flm) if using_torch else flm
5353

@@ -86,22 +86,22 @@ def generate_flmn(
8686
for n in range(-N + 1, N):
8787
for el in range(max(L_lower, abs(n)), L):
8888
if reality:
89-
flmn[N - 1 + n, el, 0 + L - 1] = rng.uniform()
89+
flmn[N - 1 + n, el, 0 + L - 1] = rng.normal()
9090
flmn[N - 1 - n, el, 0 + L - 1] = (-1) ** n * flmn[
9191
N - 1 + n,
9292
el,
9393
0 + L - 1,
9494
]
9595
else:
96-
flmn[N - 1 + n, el, 0 + L - 1] = rng.uniform() + 1j * rng.uniform()
96+
flmn[N - 1 + n, el, 0 + L - 1] = rng.normal() + 1j * rng.normal()
9797

9898
for m in range(1, el + 1):
99-
flmn[N - 1 + n, el, m + L - 1] = rng.uniform() + 1j * rng.uniform()
99+
flmn[N - 1 + n, el, m + L - 1] = rng.normal() + 1j * rng.normal()
100100
if reality:
101101
flmn[N - 1 - n, el, -m + L - 1] = (-1) ** (m + n) * np.conj(
102102
flmn[N - 1 + n, el, m + L - 1]
103103
)
104104
else:
105-
flmn[N - 1 + n, el, -m + L - 1] = rng.uniform() + 1j * rng.uniform()
105+
flmn[N - 1 + n, el, -m + L - 1] = rng.normal() + 1j * rng.normal()
106106

107107
return torch.from_numpy(flmn) if using_torch else flmn

0 commit comments

Comments
 (0)