@@ -405,17 +405,19 @@ def MDD(
405405
406406 # Add negative part to data and model
407407 if twosided and add_negative :
408- G = np .concatenate ((ncp .zeros ((ns , nr , nt - 1 )), G ), axis = - 1 )
409- d = np .concatenate ((np .squeeze (np .zeros ((ns , nv , nt - 1 ))), d ), axis = - 1 )
408+ G = ncp .concatenate ((ncp .zeros ((ns , nr , nt - 1 ), dtype = G .dtype ), G ), axis = - 1 )
409+ d = ncp .concatenate (
410+ (ncp .squeeze (ncp .zeros ((ns , nv , nt - 1 ), dtype = d .dtype )), d ), axis = - 1
411+ )
410412 # Bring kernel to frequency domain
411- Gfft = np .fft .rfft (G , nt2 , axis = - 1 )
413+ Gfft = ncp .fft .rfft (G , nt2 , axis = - 1 )
412414 Gfft = Gfft [..., :nfmax ]
413415
414416 # Bring frequency/time to first dimension
415- Gfft = np .moveaxis (Gfft , - 1 , 0 )
416- d = np .moveaxis (d , - 1 , 0 )
417+ Gfft = ncp .moveaxis (Gfft , - 1 , 0 )
418+ d = ncp .moveaxis (d , - 1 , 0 )
417419 if psf :
418- G = np .moveaxis (G , - 1 , 0 )
420+ G = ncp .moveaxis (G , - 1 , 0 )
419421
420422 # Define MDC linear operator
421423 MDCop = MDC (
@@ -455,12 +457,12 @@ def MDD(
455457 # Adjoint
456458 if adjoint :
457459 madj = MDCop .H * d .ravel ()
458- madj = np .squeeze (madj .reshape (nt2 , nr , nv ))
459- madj = np .moveaxis (madj , 0 , - 1 )
460+ madj = ncp .squeeze (madj .reshape (nt2 , nr , nv ))
461+ madj = ncp .moveaxis (madj , 0 , - 1 )
460462 if psf :
461463 psfadj = PSFop .H * G .ravel ()
462- psfadj = np .squeeze (psfadj .reshape (nt2 , nr , nr ))
463- psfadj = np .moveaxis (psfadj , 0 , - 1 )
464+ psfadj = ncp .squeeze (psfadj .reshape (nt2 , nr , nr ))
465+ psfadj = ncp .moveaxis (psfadj , 0 , - 1 )
464466
465467 # Inverse
466468 if twosided and causality_precond :
@@ -481,8 +483,8 @@ def MDD(
481483 ncp .zeros (int (MDCop .shape [1 ]), dtype = MDCop .dtype ),
482484 ** kwargs_solver
483485 )[0 ]
484- minv = np .squeeze (minv .reshape (nt2 , nr , nv ))
485- minv = np .moveaxis (minv , 0 , - 1 )
486+ minv = ncp .squeeze (minv .reshape (nt2 , nr , nv ))
487+ minv = ncp .moveaxis (minv , 0 , - 1 )
486488
487489 if wav is not None :
488490 wav1 = wav .copy ()
@@ -500,8 +502,8 @@ def MDD(
500502 ncp .zeros (int (PSFop .shape [1 ]), dtype = PSFop .dtype ),
501503 ** kwargs_solver
502504 )[0 ]
503- psfinv = np .squeeze (psfinv .reshape (nt2 , nr , nr ))
504- psfinv = np .moveaxis (psfinv , 0 , - 1 )
505+ psfinv = ncp .squeeze (psfinv .reshape (nt2 , nr , nr ))
506+ psfinv = ncp .moveaxis (psfinv , 0 , - 1 )
505507 if wav is not None :
506508 wav1 = wav .copy ()
507509 for _ in range (psfinv .ndim - 1 ):
0 commit comments