|
6 | 6 | from pylops import MatrixMult, FirstDerivative |
7 | 7 | from pylops.utils.signalprocessing import convmtx, nonstationary_convmtx |
8 | 8 | from pylops.signalprocessing import Convolve1D |
9 | | -#from pylops.avo.poststack import _PoststackLinearModelling |
| 9 | +from pylops.avo.poststack import _PoststackLinearModelling |
10 | 10 |
|
11 | 11 | from pylops_gpu.utils import dottest as Dottest |
| 12 | +from pylops_gpu.utils.torch2numpy import torchtype_from_numpytype |
12 | 13 | from pylops_gpu import MatrixMult as gMatrixMult |
13 | 14 | from pylops_gpu import FirstDerivative as gFirstDerivative |
14 | 15 | from pylops_gpu import SecondDerivative as gSecondDerivative |
|
19 | 20 |
|
20 | 21 | logging.basicConfig(format='%(levelname)s: %(message)s', level=logging.WARNING) |
21 | 22 |
|
22 | | - |
| 23 | +""" |
23 | 24 | def _PoststackLinearModelling(wav, nt0, spatdims=None, explicit=False, |
24 | 25 | sparse=False, _MatrixMult=MatrixMult, |
25 | 26 | _Convolve1D=Convolve1D, |
26 | 27 | _FirstDerivative=FirstDerivative, |
27 | 28 | args_MatrixMult={}, args_Convolve1D={}, |
28 | 29 | args_FirstDerivative={}): |
| 30 | + # define dtype to be used (ensure wav.dtype rules that of operator) |
| 31 | + dtype = torchtype_from_numpytype(wav.dtype) |
29 | 32 |
|
30 | | - |
31 | | - """Post-stack linearized seismic modelling operator. |
32 | | -
|
33 | | - Used to be able to provide operators from different libraries to |
34 | | - PoststackLinearModelling. It operates in the same way as public method |
35 | | - (PoststackLinearModelling) but has additional input parameters allowing |
36 | | - passing a different operator and additional arguments to be passed to such |
37 | | - operator. |
38 | | -
|
39 | | - """ |
40 | 33 | if len(wav.shape) == 2 and wav.shape[0] != nt0: |
41 | 34 | raise ValueError('Provide 1d wavelet or 2d wavelet composed of nt0 ' |
42 | 35 | 'wavelets') |
@@ -67,23 +60,27 @@ def _PoststackLinearModelling(wav, nt0, spatdims=None, explicit=False, |
67 | 60 | M = np.dot(C, D) |
68 | 61 | if sparse: |
69 | 62 | M = csc_matrix(M) |
70 | | - Pop = _MatrixMult(M, dims=spatdims, **args_MatrixMult) |
| 63 | + Pop = _MatrixMult(M, dims=spatdims, dtype=dtype, **args_MatrixMult) |
71 | 64 | else: |
72 | 65 | # Create wavelet operator |
73 | 66 | if len(wav.shape) == 1: |
74 | 67 | Cop = _Convolve1D(np.prod(np.array(dims)), h=wav, |
75 | 68 | offset=len(wav) // 2, dir=0, dims=dims, |
76 | | - **args_Convolve1D) |
| 69 | + dtype=dtype, **args_Convolve1D) |
77 | 70 | else: |
78 | 71 | Cop = _MatrixMult(nonstationary_convmtx(wav, nt0, |
79 | 72 | hc=wav.shape[1] // 2, |
80 | 73 | pad=(nt0, nt0)), |
81 | | - dims=spatdims, **args_MatrixMult) |
| 74 | + dims=spatdims, dtype=dtype, |
| 75 | + **args_MatrixMult) |
82 | 76 | # Create derivative operator |
83 | 77 | Dop = _FirstDerivative(np.prod(np.array(dims)), dims=dims, |
84 | | - dir=0, sampling=1., **args_FirstDerivative) |
| 78 | + dir=0, sampling=1., dtype=dtype, |
| 79 | + **args_FirstDerivative) |
| 80 | +
|
85 | 81 | Pop = Cop * Dop |
86 | 82 | return Pop |
| 83 | +""" |
87 | 84 |
|
88 | 85 |
|
89 | 86 | def PoststackLinearModelling(wav, nt0, spatdims=None, explicit=False, |
@@ -129,8 +126,9 @@ def PoststackLinearModelling(wav, nt0, spatdims=None, explicit=False, |
129 | 126 | implementation details. |
130 | 127 |
|
131 | 128 | """ |
132 | | - if not isinstance(wav, torch.Tensor) and not explicit: |
133 | | - wav = torch.from_numpy(wav).to(device) |
| 129 | + # ensure wav is always numpy, it will be converted later back to torch |
| 130 | + if isinstance(wav, torch.Tensor): |
| 131 | + wav = wav.cpu().numpy() |
134 | 132 | return _PoststackLinearModelling(wav, nt0, spatdims=spatdims, |
135 | 133 | explicit=explicit, sparse=False, |
136 | 134 | _MatrixMult=gMatrixMult, |
|
0 commit comments