Skip to content

Commit db3b339

Browse files
committed
Modified poststack to use pylops operator
Remove _PoststackLinearModelling and rely on pylops one. This is possible given changes in previous commit which allow using numpy dtypes in pylops-gpu operators.
1 parent 51630f8 commit db3b339

File tree

1 file changed

+16
-18
lines changed

1 file changed

+16
-18
lines changed

pylops_gpu/avo/poststack.py

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@
66
from pylops import MatrixMult, FirstDerivative
77
from pylops.utils.signalprocessing import convmtx, nonstationary_convmtx
88
from pylops.signalprocessing import Convolve1D
9-
#from pylops.avo.poststack import _PoststackLinearModelling
9+
from pylops.avo.poststack import _PoststackLinearModelling
1010

1111
from pylops_gpu.utils import dottest as Dottest
12+
from pylops_gpu.utils.torch2numpy import torchtype_from_numpytype
1213
from pylops_gpu import MatrixMult as gMatrixMult
1314
from pylops_gpu import FirstDerivative as gFirstDerivative
1415
from pylops_gpu import SecondDerivative as gSecondDerivative
@@ -19,24 +20,16 @@
1920

2021
logging.basicConfig(format='%(levelname)s: %(message)s', level=logging.WARNING)
2122

22-
23+
"""
2324
def _PoststackLinearModelling(wav, nt0, spatdims=None, explicit=False,
2425
sparse=False, _MatrixMult=MatrixMult,
2526
_Convolve1D=Convolve1D,
2627
_FirstDerivative=FirstDerivative,
2728
args_MatrixMult={}, args_Convolve1D={},
2829
args_FirstDerivative={}):
30+
# define dtype to be used (ensure wav.dtype rules that of operator)
31+
dtype = torchtype_from_numpytype(wav.dtype)
2932
30-
31-
"""Post-stack linearized seismic modelling operator.
32-
33-
Used to be able to provide operators from different libraries to
34-
PoststackLinearModelling. It operates in the same way as public method
35-
(PoststackLinearModelling) but has additional input parameters allowing
36-
passing a different operator and additional arguments to be passed to such
37-
operator.
38-
39-
"""
4033
if len(wav.shape) == 2 and wav.shape[0] != nt0:
4134
raise ValueError('Provide 1d wavelet or 2d wavelet composed of nt0 '
4235
'wavelets')
@@ -67,23 +60,27 @@ def _PoststackLinearModelling(wav, nt0, spatdims=None, explicit=False,
6760
M = np.dot(C, D)
6861
if sparse:
6962
M = csc_matrix(M)
70-
Pop = _MatrixMult(M, dims=spatdims, **args_MatrixMult)
63+
Pop = _MatrixMult(M, dims=spatdims, dtype=dtype, **args_MatrixMult)
7164
else:
7265
# Create wavelet operator
7366
if len(wav.shape) == 1:
7467
Cop = _Convolve1D(np.prod(np.array(dims)), h=wav,
7568
offset=len(wav) // 2, dir=0, dims=dims,
76-
**args_Convolve1D)
69+
dtype=dtype, **args_Convolve1D)
7770
else:
7871
Cop = _MatrixMult(nonstationary_convmtx(wav, nt0,
7972
hc=wav.shape[1] // 2,
8073
pad=(nt0, nt0)),
81-
dims=spatdims, **args_MatrixMult)
74+
dims=spatdims, dtype=dtype,
75+
**args_MatrixMult)
8276
# Create derivative operator
8377
Dop = _FirstDerivative(np.prod(np.array(dims)), dims=dims,
84-
dir=0, sampling=1., **args_FirstDerivative)
78+
dir=0, sampling=1., dtype=dtype,
79+
**args_FirstDerivative)
80+
8581
Pop = Cop * Dop
8682
return Pop
83+
"""
8784

8885

8986
def PoststackLinearModelling(wav, nt0, spatdims=None, explicit=False,
@@ -129,8 +126,9 @@ def PoststackLinearModelling(wav, nt0, spatdims=None, explicit=False,
129126
implementation details.
130127
131128
"""
132-
if not isinstance(wav, torch.Tensor) and not explicit:
133-
wav = torch.from_numpy(wav).to(device)
129+
# ensure wav is always numpy, it will be converted later back to torch
130+
if isinstance(wav, torch.Tensor):
131+
wav = wav.cpu().numpy()
134132
return _PoststackLinearModelling(wav, nt0, spatdims=spatdims,
135133
explicit=explicit, sparse=False,
136134
_MatrixMult=gMatrixMult,

0 commit comments

Comments
 (0)