Skip to content

Commit 292d268

Browse files
authored
Merge pull request #44 from astro-informatics/python_harmonic
Python harmonic
2 parents 853538d + b185675 commit 292d268

File tree

3 files changed

+260
-3
lines changed

3 files changed

+260
-3
lines changed

src/main/pys2let/pys2let.pyx

Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,9 @@ cdef extern from "s2let/s2let.h":
6262
int s2let_n_scal(const s2let_parameters_t *parameters)
6363
int s2let_n_wav(const s2let_parameters_t *parameters)
6464

65+
int s2let_n_lm_scal(const s2let_parameters_t *parameters)
66+
int s2let_n_lmn_wav(const s2let_parameters_t *parameters)
67+
6568
int s2let_n_wav_j(int j, const s2let_parameters_t *parameters)
6669

6770
void s2let_mw_alm2map(double complex * f, const double complex * flm, int L, int spin)
@@ -205,6 +208,39 @@ cdef extern from "s2let/s2let.h":
205208
double complex *f_scal,
206209
const double complex *f,
207210
const s2let_parameters_t *parameters)
211+
212+
void s2let_analysis_lm2lmn(
213+
double complex *f_wav_lmn,
214+
double complex *f_scal_lm,
215+
const double complex *flm,
216+
const double complex *wav_lm,
217+
const double *scal_l,
218+
const s2let_parameters_t *parameters)
219+
220+
void s2let_analysis_adjoint_lmn2lm(
221+
double complex *flm,
222+
const double complex *f_wav_lmn,
223+
const double complex *f_scal_lm,
224+
const double complex *wav_lm,
225+
const double *scal_l,
226+
const s2let_parameters_t *parameters)
227+
228+
void s2let_synthesis_lmn2lm(
229+
double complex *flm,
230+
const double complex *f_wav_lmn,
231+
const double complex *f_scal_lm,
232+
const double complex *wav_lm,
233+
const double *scal_l,
234+
const s2let_parameters_t *parameters)
235+
236+
void s2let_synthesis_adjoint_lm2lmn(
237+
double complex *f_wav_lmn,
238+
double complex *f_scal_lm,
239+
const double complex *flm,
240+
const double complex *wav_lm,
241+
const double *scal_l,
242+
const s2let_parameters_t *parameters)
243+
208244
#----------------------------------------------------------------------------------------------------#
209245

210246
cdef extern from "stdlib.h":
@@ -744,6 +780,157 @@ def synthesis_adjoint_px2wav(
744780

745781
#----------------------------------------------------------------------------------------------------#
746782

783+
def analysis_lm2lmn(np.ndarray[double complex, ndim=1, mode="c"] flm not None,
784+
B, L, J_min, N, spin, upsample, spin_lowered=False, original_spin=0):
785+
786+
cdef s2let_parameters_t parameters = {}
787+
parameters.B = B
788+
parameters.L = L
789+
parameters.J_min = J_min
790+
parameters.N = N
791+
parameters.spin = spin
792+
parameters.upsample = upsample
793+
parameters.sampling_scheme = S2LET_SAMPLING_MW
794+
parameters.original_spin = original_spin
795+
parameters.dl_method = SSHT_DL_RISBO
796+
parameters.reality = 0
797+
parameters.verbosity = 0
798+
J = s2let_j_max(&parameters)
799+
800+
scal_l = np.zeros([L,])
801+
wav_lm = np.zeros([(J + 1) * L * L,], dtype=complex)
802+
s2let_tiling_wavelet(<double complex*> np.PyArray_DATA(wav_lm),
803+
<double *> np.PyArray_DATA(scal_l),
804+
&parameters)
805+
806+
f_wav_lmn = np.zeros([s2let_n_lmn_wav(&parameters),], dtype=complex)
807+
f_scal_lm = np.zeros([s2let_n_lm_scal(&parameters),], dtype=complex)
808+
809+
s2let_analysis_lm2lmn(
810+
<double complex*> np.PyArray_DATA(f_wav_lmn),
811+
<double complex*> np.PyArray_DATA(f_scal_lm),
812+
<const double complex*> np.PyArray_DATA(flm),
813+
<const double complex*> np.PyArray_DATA(wav_lm),
814+
<const double *> np.PyArray_DATA(scal_l),
815+
&parameters)
816+
return f_wav_lmn, f_scal_lm
817+
818+
#----------------------------------------------------------------------------------------------------#
819+
820+
def analysis_adjoint_lmn2lm(
821+
np.ndarray[double complex, ndim=1, mode="c"] f_wav_lmn not None,
822+
np.ndarray[double complex, ndim=1, mode="c"] f_scal_lm not None,
823+
B, L, J_min, N, spin, upsample, spin_lowered=False, original_spin=0):
824+
825+
cdef s2let_parameters_t parameters = {}
826+
parameters.B = B
827+
parameters.L = L
828+
parameters.J_min = J_min
829+
parameters.N = N
830+
parameters.spin = spin
831+
parameters.upsample = upsample
832+
parameters.sampling_scheme = S2LET_SAMPLING_MW
833+
parameters.original_spin = original_spin
834+
parameters.dl_method = SSHT_DL_RISBO
835+
parameters.reality = 0
836+
parameters.verbosity = 0
837+
J = s2let_j_max(&parameters)
838+
839+
scal_l = np.zeros([L,])
840+
wav_lm = np.zeros([(J + 1) * L * L,], dtype=complex)
841+
s2let_tiling_wavelet(<double complex*> np.PyArray_DATA(wav_lm),
842+
<double *> np.PyArray_DATA(scal_l),
843+
&parameters)
844+
845+
flm = np.zeros([L * L,], dtype=complex)
846+
s2let_analysis_adjoint_lmn2lm(
847+
<double complex*> np.PyArray_DATA(flm),
848+
<const double complex*> np.PyArray_DATA(f_wav_lmn),
849+
<const double complex*> np.PyArray_DATA(f_scal_lm),
850+
<const double complex*> np.PyArray_DATA(wav_lm),
851+
<const double *> np.PyArray_DATA(scal_l),
852+
&parameters)
853+
854+
return flm
855+
856+
#----------------------------------------------------------------------------------------------------#
857+
858+
def synthesis_lmn2lm(
859+
np.ndarray[double complex, ndim=1, mode="c"] f_wav_lmn not None,
860+
np.ndarray[double complex, ndim=1, mode="c"] f_scal_lm not None,
861+
B, L, J_min, N, spin, upsample, spin_lowered=False, original_spin=0):
862+
863+
cdef s2let_parameters_t parameters = {}
864+
parameters.B = B
865+
parameters.L = L
866+
parameters.J_min = J_min
867+
parameters.N = N
868+
parameters.spin = spin
869+
parameters.upsample = upsample
870+
parameters.sampling_scheme = S2LET_SAMPLING_MW
871+
parameters.original_spin = original_spin
872+
parameters.dl_method = SSHT_DL_RISBO
873+
parameters.reality = 0
874+
parameters.verbosity = 0
875+
J = s2let_j_max(&parameters)
876+
877+
scal_l = np.zeros([L,])
878+
wav_lm = np.zeros([(J + 1) * L * L,], dtype=complex)
879+
s2let_tiling_wavelet(<double complex*> np.PyArray_DATA(wav_lm),
880+
<double *> np.PyArray_DATA(scal_l),
881+
&parameters)
882+
883+
flm = np.zeros([L * L,], dtype=complex)
884+
s2let_synthesis_lmn2lm(
885+
<double complex*> np.PyArray_DATA(flm),
886+
<const double complex*> np.PyArray_DATA(f_wav_lmn),
887+
<const double complex*> np.PyArray_DATA(f_scal_lm),
888+
<const double complex*> np.PyArray_DATA(wav_lm),
889+
<const double *> np.PyArray_DATA(scal_l),
890+
&parameters)
891+
892+
return flm
893+
894+
#----------------------------------------------------------------------------------------------------#
895+
896+
def synthesis_adjoint_lm2lmn(np.ndarray[double complex, ndim=1, mode="c"] flm not None,
897+
B, L, J_min, N, spin, upsample, spin_lowered=False, original_spin=0):
898+
899+
cdef s2let_parameters_t parameters = {}
900+
parameters.B = B
901+
parameters.L = L
902+
parameters.J_min = J_min
903+
parameters.N = N
904+
parameters.spin = spin
905+
parameters.upsample = upsample
906+
parameters.sampling_scheme = S2LET_SAMPLING_MW
907+
parameters.original_spin = original_spin
908+
parameters.dl_method = SSHT_DL_RISBO
909+
parameters.reality = 0
910+
parameters.verbosity = 0
911+
J = s2let_j_max(&parameters)
912+
913+
scal_l = np.zeros([L,])
914+
wav_lm = np.zeros([(J + 1) * L * L,], dtype=complex)
915+
s2let_tiling_wavelet(<double complex*> np.PyArray_DATA(wav_lm),
916+
<double *> np.PyArray_DATA(scal_l),
917+
&parameters)
918+
919+
f_wav_lmn = np.zeros([s2let_n_lmn_wav(&parameters),], dtype=complex)
920+
f_scal_lm = np.zeros([s2let_n_lm_scal(&parameters),], dtype=complex)
921+
922+
s2let_synthesis_adjoint_lm2lmn(
923+
<double complex*> np.PyArray_DATA(f_wav_lmn),
924+
<double complex*> np.PyArray_DATA(f_scal_lm),
925+
<const double complex*> np.PyArray_DATA(flm),
926+
<const double complex*> np.PyArray_DATA(wav_lm),
927+
<const double *> np.PyArray_DATA(scal_l),
928+
&parameters)
929+
930+
return f_wav_lmn, f_scal_lm
931+
932+
#----------------------------------------------------------------------------------------------------#
933+
747934
def mw_size(L):
748935
return L*(2*L-1)
749936

src/test/python/test_axisym_adjoints.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
synthesis_adjoint_px2wav,
1313
synthesis_axisym_wav_mw,
1414
synthesis_wav2px,
15+
pys2let_j_max,
1516
)
1617

1718

@@ -61,8 +62,6 @@ def random_wavlet_maps(rng, L, spin, nwvlts):
6162
def test_axisym_adjoint(
6263
px2wav, wav2px, spin, rng: np.random.Generator, L=10, B=2, J_min=2
6364
):
64-
from pys2let import pys2let_j_max
65-
6665
nwvlts = pys2let_j_max(B, L, J_min) - J_min + 1
6766

6867
x = random_mw_map(rng, L, spin)
@@ -71,4 +70,9 @@ def test_axisym_adjoint(
7170
y = wav2px(y_wav, y_scal, B, L, J_min)
7271
x_wav, x_scal = px2wav(x, B, L, J_min)
7372

74-
assert y_wav.conj().T @ x_wav + y_scal.conj() @ x_scal == approx(y.conj().T @ x)
73+
# y'Ax
74+
yAx = y_wav.conj().T @ x_wav + y_scal.conj() @ x_scal
75+
# (A'y)'x
76+
Ayx = approx(y.conj().T @ x)
77+
78+
assert yAx == Ayx
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
from functools import partial
2+
3+
import numpy as np
4+
from pytest import approx, fixture, mark
5+
6+
from pys2let import (
7+
analysis_lm2lmn,
8+
analysis_adjoint_lmn2lm,
9+
synthesis_lmn2lm,
10+
synthesis_adjoint_lm2lmn,
11+
pys2let_j_max,
12+
)
13+
14+
15+
@fixture
16+
def rng(request):
17+
return np.random.default_rng(getattr(request.config.option, "randomly_seed", None))
18+
19+
20+
def random_lms(rng, L):
21+
return rng.uniform(size=(L * L, 2)) @ [1, 1j]
22+
23+
24+
def random_wavlet_lms(rng, L, nwvlts):
25+
lms = [random_lms(rng, L) for _ in range(nwvlts + 1)]
26+
return lms[0], np.concatenate(lms[1:])
27+
28+
29+
@mark.parametrize(
30+
"px2wav,wav2px",
31+
[
32+
(
33+
partial(analysis_lm2lmn, spin=0, upsample=1, N=1),
34+
partial(analysis_adjoint_lmn2lm, spin=0, upsample=1, N=1),
35+
),
36+
(
37+
partial(analysis_lm2lmn, spin=2, upsample=1, N=1),
38+
partial(analysis_adjoint_lmn2lm, spin=2, upsample=1, N=1),
39+
),
40+
(
41+
partial(synthesis_adjoint_lm2lmn, spin=0, upsample=1, N=1),
42+
partial(synthesis_lmn2lm, spin=0, upsample=1, N=1),
43+
),
44+
(
45+
partial(synthesis_adjoint_lm2lmn, spin=2, upsample=1, N=1),
46+
partial(synthesis_lmn2lm, spin=2, upsample=1, N=1),
47+
),
48+
],
49+
)
50+
def test_axisym_adjoint(
51+
px2wav, wav2px, rng: np.random.Generator, L=10, B=2, J_min=2
52+
):
53+
nwvlts = pys2let_j_max(B, L, J_min) - J_min + 1
54+
55+
x = random_lms(rng, L)
56+
y_scal, y_wav = random_wavlet_lms(rng, L, nwvlts)
57+
58+
y = wav2px(y_wav, y_scal, B, L, J_min)
59+
x_wav, x_scal = px2wav(x, B, L, J_min)
60+
61+
# y'Ax
62+
yAx = y_wav.conj().T @ x_wav + y_scal.conj() @ x_scal
63+
# (A'y)'x
64+
Ayx = approx(y.conj().T @ x)
65+
66+
assert yAx == Ayx

0 commit comments

Comments
 (0)