Skip to content

Commit d62f4f4

Browse files
authored
Merge pull request #659 from mrava87/fix-mddcupy
bug: fix mdd to run with cupy when twosided=True and add_negative=True
2 parents a6d2c60 + 7a1b4c6 commit d62f4f4

File tree

1 file changed

+16
-14
lines changed

1 file changed

+16
-14
lines changed

pylops/waveeqprocessing/mdd.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -405,17 +405,19 @@ def MDD(
405405

406406
# Add negative part to data and model
407407
if twosided and add_negative:
408-
G = np.concatenate((ncp.zeros((ns, nr, nt - 1)), G), axis=-1)
409-
d = np.concatenate((np.squeeze(np.zeros((ns, nv, nt - 1))), d), axis=-1)
408+
G = ncp.concatenate((ncp.zeros((ns, nr, nt - 1), dtype=G.dtype), G), axis=-1)
409+
d = ncp.concatenate(
410+
(ncp.squeeze(ncp.zeros((ns, nv, nt - 1), dtype=d.dtype)), d), axis=-1
411+
)
410412
# Bring kernel to frequency domain
411-
Gfft = np.fft.rfft(G, nt2, axis=-1)
413+
Gfft = ncp.fft.rfft(G, nt2, axis=-1)
412414
Gfft = Gfft[..., :nfmax]
413415

414416
# Bring frequency/time to first dimension
415-
Gfft = np.moveaxis(Gfft, -1, 0)
416-
d = np.moveaxis(d, -1, 0)
417+
Gfft = ncp.moveaxis(Gfft, -1, 0)
418+
d = ncp.moveaxis(d, -1, 0)
417419
if psf:
418-
G = np.moveaxis(G, -1, 0)
420+
G = ncp.moveaxis(G, -1, 0)
419421

420422
# Define MDC linear operator
421423
MDCop = MDC(
@@ -455,12 +457,12 @@ def MDD(
455457
# Adjoint
456458
if adjoint:
457459
madj = MDCop.H * d.ravel()
458-
madj = np.squeeze(madj.reshape(nt2, nr, nv))
459-
madj = np.moveaxis(madj, 0, -1)
460+
madj = ncp.squeeze(madj.reshape(nt2, nr, nv))
461+
madj = ncp.moveaxis(madj, 0, -1)
460462
if psf:
461463
psfadj = PSFop.H * G.ravel()
462-
psfadj = np.squeeze(psfadj.reshape(nt2, nr, nr))
463-
psfadj = np.moveaxis(psfadj, 0, -1)
464+
psfadj = ncp.squeeze(psfadj.reshape(nt2, nr, nr))
465+
psfadj = ncp.moveaxis(psfadj, 0, -1)
464466

465467
# Inverse
466468
if twosided and causality_precond:
@@ -481,8 +483,8 @@ def MDD(
481483
ncp.zeros(int(MDCop.shape[1]), dtype=MDCop.dtype),
482484
**kwargs_solver
483485
)[0]
484-
minv = np.squeeze(minv.reshape(nt2, nr, nv))
485-
minv = np.moveaxis(minv, 0, -1)
486+
minv = ncp.squeeze(minv.reshape(nt2, nr, nv))
487+
minv = ncp.moveaxis(minv, 0, -1)
486488

487489
if wav is not None:
488490
wav1 = wav.copy()
@@ -500,8 +502,8 @@ def MDD(
500502
ncp.zeros(int(PSFop.shape[1]), dtype=PSFop.dtype),
501503
**kwargs_solver
502504
)[0]
503-
psfinv = np.squeeze(psfinv.reshape(nt2, nr, nr))
504-
psfinv = np.moveaxis(psfinv, 0, -1)
505+
psfinv = ncp.squeeze(psfinv.reshape(nt2, nr, nr))
506+
psfinv = ncp.moveaxis(psfinv, 0, -1)
505507
if wav is not None:
506508
wav1 = wav.copy()
507509
for _ in range(psfinv.ndim - 1):

0 commit comments

Comments
 (0)