Skip to content

Commit 480696a

Browse files
mreineckmdavezac
authored andcommitted
Feature: Support negative spins in ducc backend
1 parent fbeee01 commit 480696a

File tree

3 files changed

+76
-52
lines changed

3 files changed

+76
-52
lines changed

src/pyssht/ducc_demo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def test_rot(L, nthreads=1):
8282
for Method in ["MW", "MWSS", "GL", "DH"]:
8383
for L in L_list:
8484
for Reality in [False, True]:
85-
for Spin in [0] if Reality else [0, 1]:
85+
for Spin in [0] if Reality else [-1, 0, 1]:
8686
res = test_SHT(L, Method, Reality, Spin, nthreads=nthreads)
8787
print(
8888
"{:4}, L={:4}, Reality={:5}, Spin={:1}: L2 error={:e}, speedup factor={:f}".format(

src/pyssht/ducc_interface.pyx

Lines changed: 72 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -57,50 +57,70 @@ def _build_real_flm(alm, Py_ssize_t L):
5757
ofs += L-m
5858
return res
5959

60-
cdef _extract_complex_alm(flm, Py_ssize_t L):
60+
cdef _extract_complex_alm(flm, Py_ssize_t L, Py_ssize_t Spin):
6161
res = np.empty((2, _nalm(L-1, L-1),), dtype=np.complex128)
6262
cdef Py_ssize_t ofs=0, m, i
63-
cdef double mfac
63+
cdef double mfac, sfac=(-1)**abs(Spin)
6464
cdef complex[:,:] myres=res
6565
cdef complex[:] myflm=flm
6666
cdef Py_ssize_t[:] lidx = _get_lidx(L)
6767
cdef complex fp, fm
68-
for m in range(L):
69-
mfac = (-1)**m
70-
for i in range(m,L):
71-
fp = myflm[lidx[i]+m]
72-
fm = mfac * (myflm[lidx[i]-m].real - 1j*myflm[lidx[i]-m].imag)
73-
myres[0, ofs-m+i] = 0.5*(fp+fm)
74-
myres[1, ofs-m+i] = -0.5j*(fp-fm)
75-
ofs += L-m
68+
if Spin >= 0:
69+
for m in range(L):
70+
mfac = (-1)**m
71+
for i in range(m,L):
72+
fp = myflm[lidx[i]+m]
73+
fm = mfac * (myflm[lidx[i]-m].real - 1j*myflm[lidx[i]-m].imag)
74+
myres[0, ofs-m+i] = 0.5*(fp+fm)
75+
myres[1, ofs-m+i] = -0.5j*(fp-fm)
76+
ofs += L-m
77+
else:
78+
for m in range(L):
79+
mfac = (-1)**m
80+
for i in range(m,L):
81+
fp = mfac*sfac*(myflm[lidx[i]-m].real - 1j*myflm[lidx[i]-m].imag)
82+
fm = sfac*myflm[lidx[i]+m]
83+
myres[0, ofs-m+i] = 0.5*(fp+fm)
84+
myres[1, ofs-m+i] = -0.5j*(fp-fm)
85+
ofs += L-m
7686
return res
7787

78-
cdef _build_complex_flm(alm, Py_ssize_t L):
88+
cdef _build_complex_flm(alm, Py_ssize_t L, Py_ssize_t Spin):
7989
res = np.empty((L*L), dtype=np.complex128)
8090
cdef Py_ssize_t ofs=0, m, i
8191
cdef complex fp, fm
8292
cdef complex[:] myres=res
8393
cdef complex[:,:] myalm=alm
8494
cdef Py_ssize_t[:] lidx = _get_lidx(L)
85-
cdef double mfac
86-
for m in range(L):
87-
mfac = (-1)**m
88-
for i in range(m,L):
89-
fp = myalm[0, ofs-m+i] + 1j*myalm[1, ofs-m+i]
90-
fm = myalm[0, ofs-m+i] - 1j*myalm[1, ofs-m+i]
91-
myres[lidx[i]+m] = fp
92-
myres[lidx[i]-m] = mfac*(fm.real - 1j*fm.imag)
93-
ofs += L-m
95+
cdef double mfac, sfac=(-1)**abs(Spin)
96+
if Spin >= 0:
97+
for m in range(L):
98+
mfac = (-1)**m
99+
for i in range(m,L):
100+
fp = myalm[0, ofs-m+i] + 1j*myalm[1, ofs-m+i]
101+
fm = myalm[0, ofs-m+i] - 1j*myalm[1, ofs-m+i]
102+
myres[lidx[i]+m] = fp
103+
myres[lidx[i]-m] = mfac*(fm.real - 1j*fm.imag)
104+
ofs += L-m
105+
else:
106+
for m in range(L):
107+
mfac = (-1)**m
108+
for i in range(m,L):
109+
fp = myalm[0, ofs-m+i] + 1j*myalm[1, ofs-m+i]
110+
fm = myalm[0, ofs-m+i] - 1j*myalm[1, ofs-m+i]
111+
myres[lidx[i]+m] = sfac*fm
112+
myres[lidx[i]-m] = sfac*mfac*(fp.real -1j *fp.imag)
113+
ofs += L-m
94114
return res
95115

96116

97117
def rotate_flms(flm, alpha, beta, gamma, L, int nthreads = 1):
98118
ducc0 = import_ducc0()
99-
alm = _extract_complex_alm(flm, L)
119+
alm = _extract_complex_alm(flm, L, 0)
100120
for i in range(2):
101121
alm[i] = ducc0.sht.rotate_alm(
102122
alm[i], L-1, gamma, beta, alpha, nthreads=nthreads)
103-
return _build_complex_flm(alm, L)
123+
return _build_complex_flm(alm, L, 0)
104124

105125

106126
def inverse(np.ndarray flm, Py_ssize_t L, Py_ssize_t Spin, str Method, bint Reality, int nthreads = 1):
@@ -121,20 +141,20 @@ def inverse(np.ndarray flm, Py_ssize_t L, Py_ssize_t Spin, str Method, bint Real
121141
spin=0,
122142
geometry=gdict[Method])[0]
123143
elif Spin == 0:
124-
alm = _extract_complex_alm(flm, L)
144+
alm = _extract_complex_alm(flm, L,0)
125145
flmr = _build_real_flm(alm[0], L)
126146
flmi = _build_real_flm(alm[1], L)
127147
return inverse(flmr, L, 0, Method, True) + 1j*inverse(flmi, L, 0, Method, True)
128148
else:
129149
tmp=ducc0.sht.experimental.synthesis_2d(
130-
alm=_extract_complex_alm(flm, L),
150+
alm=_extract_complex_alm(flm, L, Spin),
131151
ntheta=ntheta,
132152
nphi=nphi,
133153
lmax=L-1,
134154
nthreads=nthreads,
135-
spin=Spin,
155+
spin=abs(Spin),
136156
geometry=gdict[Method])
137-
res = -1j*tmp[1]
157+
res = -1j*tmp[1] if Spin >=0 else 1j*tmp[1]
138158
res -= tmp[0]
139159
return res
140160

@@ -155,22 +175,24 @@ def inverse_adjoint(f, L, Spin, Method, Reality, int nthreads = 1):
155175
spin=0,
156176
geometry=gdict[Method])[0], L)
157177
elif Spin == 0:
158-
flmr = inverse_adjoint(f.real, L, Spin, Method, True)
159-
flmi = inverse_adjoint(f.imag, L, Spin, Method, True)
160-
alm = np.empty((2,_nalm(L-1, L-1)), dtype=np.complex128)
161-
alm[0] = _extract_real_alm(flmr, L)
162-
alm[1] = _extract_real_alm(flmi, L)
163-
return _build_complex_flm(alm, L)
178+
flmr = inverse_adjoint(f.real, L, Spin, Method, True)
179+
flmi = inverse_adjoint(f.imag, L, Spin, Method, True)
180+
alm = np.empty((2,_nalm(L-1, L-1)), dtype=np.complex128)
181+
alm[0] = _extract_real_alm(flmr, L)
182+
alm[1] = _extract_real_alm(flmi, L)
183+
return _build_complex_flm(alm, L, 0)
164184
else:
165-
map = f.astype(np.complex128).view(dtype=np.float64).reshape((f.shape[0],f.shape[1],2)).transpose((2,0,1))
166-
res = _build_complex_flm(ducc0.sht.experimental.adjoint_synthesis_2d(
167-
map=map,
168-
lmax=L-1,
169-
nthreads=nthreads,
170-
spin=Spin,
171-
geometry=gdict[Method]), L)
172-
res *= -1
173-
return res
185+
map = f.astype(np.complex128).view(dtype=np.float64).reshape((f.shape[0],f.shape[1],2)).transpose((2,0,1))
186+
if Spin < 0:
187+
map[1]*=-1
188+
res = _build_complex_flm(ducc0.sht.experimental.adjoint_synthesis_2d(
189+
map=map,
190+
lmax=L-1,
191+
nthreads=nthreads,
192+
spin=abs(Spin),
193+
geometry=gdict[Method]), L, Spin)
194+
res *= -1
195+
return res
174196

175197

176198
def forward(f, L, Spin, Method, Reality, int nthreads = 1):
@@ -194,15 +216,17 @@ def forward(f, L, Spin, Method, Reality, int nthreads = 1):
194216
alm = np.empty((2,_nalm(L-1, L-1)), dtype=np.complex128)
195217
alm[0] = _extract_real_alm(flmr, L)
196218
alm[1] = _extract_real_alm(flmi, L)
197-
return _build_complex_flm(alm, L)
219+
return _build_complex_flm(alm, L, 0)
198220
else:
199221
map = f.astype(np.complex128).view(dtype=np.float64).reshape((f.shape[0],f.shape[1],2)).transpose((2,0,1))
222+
if Spin < 0:
223+
map[1]*=-1
200224
res = _build_complex_flm(ducc0.sht.experimental.analysis_2d(
201225
map=map,
202226
lmax=L-1,
203227
nthreads=nthreads,
204-
spin=Spin,
205-
geometry=gdict[Method]), L)
228+
spin=abs(Spin),
229+
geometry=gdict[Method]), L, Spin)
206230
res *= -1
207231
return res
208232

@@ -225,19 +249,19 @@ def forward_adjoint(np.ndarray flm, Py_ssize_t L, Py_ssize_t Spin, str Method, b
225249
spin=0,
226250
geometry=gdict[Method])[0]
227251
elif Spin == 0:
228-
alm = _extract_complex_alm(flm, L)
252+
alm = _extract_complex_alm(flm, L, 0)
229253
flmr = _build_real_flm(alm[0], L)
230254
flmi = _build_real_flm(alm[1], L)
231255
return forward_adjoint(flmr, L, 0, Method, True) + 1j*forward_adjoint(flmi, L, 0, Method, True)
232256
else:
233257
tmp=ducc0.sht.experimental.adjoint_analysis_2d(
234-
alm=_extract_complex_alm(flm, L),
258+
alm=_extract_complex_alm(flm, L, Spin),
235259
ntheta=ntheta,
236260
nphi=nphi,
237261
lmax=L-1,
238262
nthreads=nthreads,
239-
spin=Spin,
263+
spin=abs(Spin),
240264
geometry=gdict[Method])
241-
res = -1j*tmp[1]
265+
res = -1j*tmp[1] if Spin >=0 else 1j*tmp[1]
242266
res -= tmp[0]
243267
return res

tests/test_ducc0.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def method(request):
2121
return request.param
2222

2323

24-
@fixture(params=[0, 1, 2])
24+
@fixture(params=[-2, -1, 0, 1, 2])
2525
def spin(request):
2626
return request.param
2727

@@ -139,7 +139,7 @@ def test_complex_inverse_adjoint_ssht_vs_ducc0(
139139

140140
try:
141141
ssht_adj_coeffs = ssht.inverse_adjoint(
142-
complex_image, order, Reality=False, Method=method, Spin=0
142+
complex_image, order, Reality=False, Method=method, Spin=spin
143143
)
144144
except ssht_input_error:
145145
assert method not in ("MW", "MWSS")
@@ -150,7 +150,7 @@ def test_complex_inverse_adjoint_ssht_vs_ducc0(
150150
order,
151151
Reality=False,
152152
Method=method,
153-
Spin=0,
153+
Spin=spin,
154154
backend="ducc",
155155
nthreads=nthreads,
156156
)

0 commit comments

Comments
 (0)