@@ -255,43 +255,38 @@ def forward_transform(
255255 # the weights are conjugate but applied flipped and therefore are
256256 # equivalent. To avoid flipping here we simply conjugate the weights.
257257
258- # PRECOMPUTE TRANSFORM
259258 if precomps is not None :
260- # EXTRACT VARIOUS PRECOMPUTES
259+ # PRECOMPUTE TRANSFORM
261260 delta , quads = precomps
262-
263- # APPLY QUADRATURE
264- x = np .einsum ("nbm,b->nbm" , x , quads )
265-
266- # COMPUTE GMM BY FFT
267- x = np .fft .fft (x , axis = 1 , norm = "forward" )
268- x = np .fft .fftshift (x , axes = 1 )[:, L - 1 : 3 * L - 2 ]
269-
270- # CALCULATE flmn = i^(n-m)\sum_t delta^l_tm delta^l_tn G_mnt
271- x = np .einsum ("nam,lam,lan->nlm" , x , delta , delta [:, :, L - 1 + n ])
272-
273- # OTF TRANSFORM
274261 else :
262+ # OTF TRANSFORM
263+ delta = None
275264 # COMPUTE QUADRATURE WEIGHTS
276265 quads = np .zeros (4 * L - 3 , dtype = np .complex128 )
277266 for mm in range (- 2 * (L - 1 ), 2 * (L - 1 ) + 1 ):
278267 quads [mm + 2 * (L - 1 )] = quadrature .mw_weights (- mm )
279268 quads = np .fft .ifft (np .fft .ifftshift (quads ), norm = "forward" )
280269
281- # APPLY QUADRATURE
282- x = np .einsum ("nbm,b->nbm" , x , quads )
270+ # APPLY QUADRATURE
271+ x = np .einsum ("nbm,b->nbm" , x , quads )
283272
284- # COMPUTE GMM BY FFT
285- x = np .fft .fft (x , axis = 1 , norm = "forward" )
286- x = np .fft .fftshift (x , axes = 1 )[:, L - 1 : 3 * L - 2 ]
273+ # COMPUTE GMM BY FFT
274+ x = np .fft .fft (x , axis = 1 , norm = "forward" )
275+ x = np .fft .fftshift (x , axes = 1 )[:, L - 1 : 3 * L - 2 ]
287276
288- # CALCULATE flmn = i^(n-m)\sum_t delta^l_tm delta^l_tn G_mnt
277+ # CALCULATE flmn = i^(n-m)\sum_t delta^l_tm delta^l_tn G_mnt
278+ if delta is not None :
279+ # PRECOMPUTE TRANSFORM
280+ x = np .einsum ("nam,lam,lan->nlm" , x , delta , delta [:, :, L - 1 + n ])
281+ else :
282+ # OTF TRANSFORM
289283 delta_el = np .zeros ((2 * L - 1 , 2 * L - 1 ), dtype = np .float64 )
290284 xx = np .zeros ((x .shape [0 ], L , x .shape [- 1 ]), dtype = x .dtype )
291285 for el in range (L ):
292286 delta_el = recursions .risbo .compute_full (delta_el , np .pi / 2 , L , el )
293287 xx [:, el ] = np .einsum ("nam,am,an->nm" , x , delta_el , delta_el [:, L - 1 + n ])
294288 x = xx
289+
295290 x = np .einsum ("nbm,m,n->nbm" , x , 1j ** (m ), 1j ** (- n ))
296291
297292 # SYMMETRY REFLECT FOR N < 0
@@ -381,35 +376,31 @@ def forward_transform_jax(
381376 # the weights are conjugate but applied flipped and therefore are
382377 # equivalent. To avoid flipping here we simply conjugate the weights.
383378
384- # PRECOMPUTE TRANSFORM
385379 if precomps is not None :
386- # EXTRACT VARIOUS PRECOMPUTES
380+ # PRECOMPUTE TRANSFORM
387381 delta , quads = precomps
388-
389- # APPLY QUADRATURE
390- x = jnp .einsum ("nbm,b->nbm" , x , quads )
391-
392- # COMPUTE GMM BY FFT
393- x = jnp .fft .fft (x , axis = 1 , norm = "forward" )
394- x = jnp .fft .fftshift (x , axes = 1 )[:, L - 1 : 3 * L - 2 ]
395-
396- # Calculate flmn = i^(n-m)\sum_t delta^l_tm delta^l_tn G_mnt
397- x = jnp .einsum ("nam,lam,lan->nlm" , x , delta , delta [:, :, L - 1 + n ])
398-
399382 else :
383+ # OTF TRANSFORM
384+ delta = None
385+ # COMPUTE QUADRATURE WEIGHTS
400386 quads = jnp .zeros (4 * L - 3 , dtype = jnp .complex128 )
401387 for mm in range (- 2 * (L - 1 ), 2 * (L - 1 ) + 1 ):
402388 quads = quads .at [mm + 2 * (L - 1 )].set (quadrature_jax .mw_weights (- mm ))
403389 quads = jnp .fft .ifft (jnp .fft .ifftshift (quads ), norm = "forward" )
404390
405- # APPLY QUADRATURE
406- x = jnp .einsum ("nbm,b->nbm" , x , quads )
391+ # APPLY QUADRATURE
392+ x = jnp .einsum ("nbm,b->nbm" , x , quads )
407393
408- # COMPUTE GMM BY FFT
409- x = jnp .fft .fft (x , axis = 1 , norm = "forward" )
410- x = jnp .fft .fftshift (x , axes = 1 )[:, L - 1 : 3 * L - 2 ]
394+ # COMPUTE GMM BY FFT
395+ x = jnp .fft .fft (x , axis = 1 , norm = "forward" )
396+ x = jnp .fft .fftshift (x , axes = 1 )[:, L - 1 : 3 * L - 2 ]
411397
412- # CALCULATE flmn = i^(n-m)\sum_t delta^l_tm delta^l_tn G_mnt
398+ # Calculate flmn = i^(n-m)\sum_t delta^l_tm delta^l_tn G_mnt
399+ if delta is not None :
400+ # PRECOMPUTE TRANSFORM
401+ x = jnp .einsum ("nam,lam,lan->nlm" , x , delta , delta [:, :, L - 1 + n ])
402+ else :
403+ # OTF TRANSFORM
413404 delta_el = jnp .zeros ((2 * L - 1 , 2 * L - 1 ), dtype = jnp .float64 )
414405 xx = jnp .zeros ((x .shape [0 ], L , x .shape [- 1 ]), dtype = x .dtype )
415406 for el in range (L ):
0 commit comments