Skip to content

Commit 0d4698f

Browse files
committed
Explicitly cast kernels in einsum ops to avoid ComplexWarning causing test fails
1 parent aff48ac commit 0d4698f

File tree

3 files changed

+8
-6
lines changed

3 files changed

+8
-6
lines changed

s2fft/precompute_transforms/spherical.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ def inverse_transform_jax(
193193
ftm = ftm.at[:, m_start_ind + m_offset :].add(
194194
jnp.einsum(
195195
"...tlm, ...lm -> ...tm",
196-
kernel,
196+
kernel.astype(ftm.dtype),
197197
flm[:, m_start_ind:],
198198
optimize=True,
199199
)
@@ -442,7 +442,9 @@ def forward_transform_jax(
442442

443443
flm = jnp.zeros(samples.flm_shape(L), dtype=jnp.complex128)
444444
flm = flm.at[:, m_start_ind:].set(
445-
jnp.einsum("...tlm, ...tm -> ...lm", kernel, ftm, optimize=True)
445+
jnp.einsum(
446+
"...tlm, ...tm -> ...lm", kernel.astype(flm.dtype), ftm, optimize=True
447+
)
446448
)
447449

448450
if reality:

s2fft/precompute_transforms/wigner.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ def inverse_transform_jax(
181181
fnab = fnab.at[n_start_ind:, :, m_offset:].set(
182182
jnp.einsum(
183183
"...ntlm, ...nlm -> ...ntm",
184-
kernel,
184+
kernel.astype(fnab.dtype),
185185
flmn[n_start_ind:, :, :],
186186
optimize=True,
187187
)
@@ -439,7 +439,9 @@ def forward_transform_jax(
439439

440440
flmn = jnp.zeros(samples.flmn_shape(L, N), dtype=jnp.complex128)
441441
flmn = flmn.at[n_start_ind:].set(
442-
jnp.einsum("...ntlm, ...ntm -> ...nlm", kernel, fban, optimize=True)
442+
jnp.einsum(
443+
"...ntlm, ...ntm -> ...nlm", kernel.astype(flmn.dtype), fban, optimize=True
444+
)
443445
)
444446
if reality:
445447
flmn = flmn.at[:n_start_ind].set(

tests/test_spherical_precompute.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
@pytest.mark.parametrize("reality", reality_to_test)
3333
@pytest.mark.parametrize("method", methods_to_test)
3434
@pytest.mark.parametrize("recursion", recursions_to_test)
35-
@pytest.mark.filterwarnings("ignore:Casting complex values")
3635
def test_transform_inverse(
3736
flm_generator,
3837
L: int,
@@ -160,7 +159,6 @@ def test_transform_inverse_healpix(
160159
@pytest.mark.parametrize("reality", reality_to_test)
161160
@pytest.mark.parametrize("method", methods_to_test)
162161
@pytest.mark.parametrize("recursion", recursions_to_test)
163-
@pytest.mark.filterwarnings("ignore:Casting complex values")
164162
def test_transform_forward(
165163
flm_generator,
166164
L: int,

0 commit comments

Comments
 (0)