|
1 | 1 | import numpy as np |
2 | 2 | import jax.numpy as jnp |
3 | 3 | import jax.lax as lax |
4 | | -from jax import jit, pmap, local_device_count, custom_vjp |
| 4 | +from jax import jit, pmap, local_device_count |
5 | 5 | from functools import partial |
6 | 6 | from typing import List |
7 | 7 | from s2fft.sampling import s2_samples as samples |
@@ -379,16 +379,7 @@ def eval_recursion_step( |
379 | 379 | ).reshape(ntheta, ftm.shape[-1]) |
380 | 380 |
|
381 | 381 | else: |
382 | | - ( |
383 | | - ftm, |
384 | | - dl_entry, |
385 | | - dl_iter, |
386 | | - lrenorm, |
387 | | - indices, |
388 | | - omc, |
389 | | - c, |
390 | | - s, |
391 | | - ) = lax.fori_loop( |
| 382 | + (ftm, dl_entry, dl_iter, lrenorm, indices, omc, c, s,) = lax.fori_loop( |
392 | 383 | 2, |
393 | 384 | L - 1 + i, |
394 | 385 | pm_recursion_step, |
@@ -818,10 +809,7 @@ def eval_recursion_step( |
818 | 809 | opsdevice = int((L - L_lower) / ndevices) |
819 | 810 |
|
820 | 811 | flm = flm.at[L_lower:].set( |
821 | | - pmap( |
822 | | - eval_recursion_step, |
823 | | - in_axes=(0, 1, 2, 2, 1, 1, 1, 1), |
824 | | - )( |
| 812 | + pmap(eval_recursion_step, in_axes=(0, 1, 2, 2, 1, 1, 1, 1),)( |
825 | 813 | flm[L_lower:].reshape(ndevices, opsdevice, 2 * L - 1), |
826 | 814 | dl_entry.reshape(ntheta, ndevices, opsdevice), |
827 | 815 | dl_iter.reshape(2, ntheta, ndevices, opsdevice), |
|
0 commit comments