Skip to content

Commit f50eaee

Browse files
committed
Only pass precomps to forward transform when iterating
1 parent a2abd7c commit f50eaee

File tree

1 file changed

+7
-9
lines changed

1 file changed

+7
-9
lines changed

s2fft/transforms/spherical.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -421,20 +421,18 @@ def forward(
421421
"nside": nside,
422422
"sampling": sampling,
423423
"reality": reality,
424-
"precomps": precomps,
425424
"L_lower": L_lower,
426425
}
426+
forward_kwargs = {**common_kwargs, "precomps": precomps}
427+
inverse_kwargs = common_kwargs
427428
if method in {"jax", "cuda"}:
428-
forward_kwargs = {
429-
**common_kwargs,
430-
"spmd": spmd,
431-
"use_healpix_custom_primitive": method == "cuda",
432-
}
433-
inverse_kwargs = {**common_kwargs, "method": "jax"}
429+
forward_kwargs["spmd"] = spmd
430+
forward_kwargs["use_healpix_custom_primitive"] = method == "cuda"
431+
inverse_kwargs["method"] = "jax"
432+
inverse_kwargs["spmd"] = spmd
434433
forward_function = forward_jax
435434
else:
436-
forward_kwargs = common_kwargs
437-
inverse_kwargs = {**common_kwargs, "method": "numpy"}
435+
inverse_kwargs["method"] = "numpy"
438436
forward_function = forward_numpy
439437
return iterative_refinement.forward_with_iterative_refinement(
440438
f=f,

0 commit comments

Comments
 (0)