File tree Expand file tree Collapse file tree 1 file changed +7
-9
lines changed
Expand file tree Collapse file tree 1 file changed +7
-9
lines changed Original file line number Diff line number Diff line change @@ -421,20 +421,18 @@ def forward(
421421 "nside" : nside ,
422422 "sampling" : sampling ,
423423 "reality" : reality ,
424- "precomps" : precomps ,
425424 "L_lower" : L_lower ,
426425 }
426+ forward_kwargs = {** common_kwargs , "precomps" : precomps }
427+ inverse_kwargs = common_kwargs
427428 if method in {"jax" , "cuda" }:
428- forward_kwargs = {
429- ** common_kwargs ,
430- "spmd" : spmd ,
431- "use_healpix_custom_primitive" : method == "cuda" ,
432- }
433- inverse_kwargs = {** common_kwargs , "method" : "jax" }
429+ forward_kwargs ["spmd" ] = spmd
430+ forward_kwargs ["use_healpix_custom_primitive" ] = method == "cuda"
431+ inverse_kwargs ["method" ] = "jax"
432+ inverse_kwargs ["spmd" ] = spmd
434433 forward_function = forward_jax
435434 else :
436- forward_kwargs = common_kwargs
437- inverse_kwargs = {** common_kwargs , "method" : "numpy" }
435+ inverse_kwargs ["method" ] = "numpy"
438436 forward_function = forward_numpy
439437 return iterative_refinement .forward_with_iterative_refinement (
440438 f = f ,
You can’t perform that action at this time.
0 commit comments