|
1 | 1 | __all__ = ["AcousticWave2D"] |
2 | 2 |
|
3 | | -from typing import Tuple |
| 3 | +from typing import Any, NewType, Tuple |
4 | 4 |
|
5 | 5 | import numpy as np |
6 | 6 |
|
|
14 | 14 | if devito_message is None: |
15 | 15 | from examples.seismic import AcquisitionGeometry, Model |
16 | 16 | from examples.seismic.acoustic import AcousticWaveSolver |
17 | | - from examples.seismic.utils import PointSource |
18 | 17 |
|
| 18 | + from ._twoway import _CustomSource |
| 19 | +else: |
| 20 | + AcousticWaveSolver = Any |
19 | 21 |
|
20 | | -class _CustomSource(PointSource): |
21 | | - """Custom source |
22 | | -
|
23 | | - This class creates a Devito symbolic object that encapsulates a set of |
24 | | - sources with a user defined source signal wavelet ``wav`` |
25 | | -
|
26 | | - Parameters |
27 | | - ---------- |
28 | | - name : :obj:`str` |
29 | | - Name for the resulting symbol. |
30 | | - grid : :obj:`devito.types.grid.Grid` |
31 | | - The computational domain. |
32 | | - time_range : :obj:`examples.seismic.source.TimeAxis` |
33 | | - TimeAxis(start, step, num) object. |
34 | | - wav : :obj:`numpy.ndarray` |
35 | | - Wavelet of size |
36 | | -
|
37 | | - """ |
38 | | - |
39 | | - __rkwargs__ = PointSource.__rkwargs__ + ["wav"] |
40 | | - |
41 | | - @classmethod |
42 | | - def __args_setup__(cls, *args, **kwargs): |
43 | | - kwargs.setdefault("npoint", 1) |
44 | | - |
45 | | - return super().__args_setup__(*args, **kwargs) |
46 | | - |
47 | | - def __init_finalize__(self, *args, **kwargs): |
48 | | - super().__init_finalize__(*args, **kwargs) |
49 | | - |
50 | | - self.wav = kwargs.get("wav") |
51 | | - |
52 | | - if not self.alias: |
53 | | - for p in range(kwargs["npoint"]): |
54 | | - self.data[:, p] = self.wavelet |
55 | | - |
56 | | - @property |
57 | | - def wavelet(self): |
58 | | - """Return user-provided wavelet""" |
59 | | - return self.wav |
| 22 | +AcousticWaveSolverType = NewType("AcousticWaveSolver", AcousticWaveSolver) |
60 | 23 |
|
61 | 24 |
|
62 | 25 | class AcousticWave2D(LinearOperator): |
@@ -327,7 +290,7 @@ def srcillumination_allshots(self, savewav: bool = False) -> None: |
327 | 290 | self.src_wavefield.append(src_wav) |
328 | 291 | self.src_illumination += src_ill |
329 | 292 |
|
330 | | - def _born_oneshot(self, solver: AcousticWaveSolver, dm: NDArray) -> NDArray: |
| 293 | + def _born_oneshot(self, solver: AcousticWaveSolverType, dm: NDArray) -> NDArray: |
331 | 294 | """Born modelling for one shot |
332 | 295 |
|
333 | 296 | Parameters |
|
0 commit comments