Skip to content

Commit e588fe3

Browse files
authored
Merge pull request #566 from mrava87/feat-unittaper
feat: added exponent to cosinetaper
2 parents a6d09c1 + 38fe5ee commit e588fe3

File tree

2 files changed

+37
-9
lines changed

2 files changed

+37
-9
lines changed

pylops/utils/tapers.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
"tapernd",
88
]
99

10-
from typing import Tuple, Union
10+
from typing import Optional, Tuple, Union
1111

1212
import numpy as np
1313
import numpy.typing as npt
@@ -59,6 +59,7 @@ def cosinetaper(
5959
nmask: int,
6060
ntap: int,
6161
square: bool = False,
62+
exponent: Optional[float] = None,
6263
) -> npt.ArrayLike:
6364
r"""1D Cosine or Cosine square taper
6465
@@ -71,8 +72,10 @@ def cosinetaper(
7172
Number of samples of mask
7273
ntap : :obj:`int`
7374
Number of samples of hanning tapering at edges
74-
square : :obj:`bool`
75-
Cosine square taper (``True``)or Cosine taper (``False``)
75+
square : :obj:`bool`, optional
76+
Cosine square taper (``True``) or Cosine taper (``False``)
77+
exponent : :obj:`float`, optional
78+
Exponent to apply to Cosine taper. If provided, takes precedence over ``square``
7679
7780
Returns
7881
-------
@@ -81,7 +84,8 @@ def cosinetaper(
8184
8285
"""
8386
ntap = 0 if ntap == 1 else ntap
84-
exponent = 1 if not square else 2
87+
if exponent is None:
88+
exponent = 1 if not square else 2
8589
cos_win = (
8690
0.5
8791
* (
@@ -123,7 +127,8 @@ def taper(
123127
ntap : :obj:`int`
124128
Number of samples of hanning tapering at edges
125129
tapertype : :obj:`str`, optional
126-
Type of taper (``hanning``, ``cosine``, ``cosinesquare`` or ``None``)
130+
Type of taper (``hanning``, ``cosine``,
131+
``cosinesquare``, ``cosinesqrt`` or ``None``)
127132
128133
Returns
129134
-------
@@ -137,6 +142,8 @@ def taper(
137142
tpr_1d = cosinetaper(nmask, ntap, False)
138143
elif tapertype == "cosinesquare":
139144
tpr_1d = cosinetaper(nmask, ntap, True)
145+
elif tapertype == "cosinesqrt":
146+
tpr_1d = cosinetaper(nmask, ntap, False, 0.5)
140147
else:
141148
tpr_1d = np.ones(nmask)
142149
return tpr_1d
@@ -214,7 +221,7 @@ def taper3d(
214221
Number of samples of tapering at edges of first and second dimensions
215222
tapertype : :obj:`int`
216223
Type of taper (``hanning``, ``cosine``,
217-
``cosinesquare`` or ``None``)
224+
``cosinesquare``, ``cosinesqrt`` or ``None``)
218225
219226
Returns
220227
-------
@@ -236,6 +243,9 @@ def taper3d(
236243
elif tapertype == "cosinesquare":
237244
tpr_y = cosinetaper(nmasky, ntapy, True)
238245
tpr_x = cosinetaper(nmaskx, ntapx, True)
246+
elif tapertype == "cosinesqrt":
247+
tpr_y = cosinetaper(nmasky, ntapy, False, 0.5)
248+
tpr_x = cosinetaper(nmaskx, ntapx, False, 0.5)
239249
else:
240250
tpr_y = np.ones(nmasky)
241251
tpr_x = np.ones(nmaskx)
@@ -266,7 +276,7 @@ def tapernd(
266276
Number of samples of tapering at edges of every dimension
267277
tapertype : :obj:`int`
268278
Type of taper (``hanning``, ``cosine``,
269-
``cosinesquare`` or ``None``)
279+
``cosinesquare``, ``cosinesqrt`` or ``None``)
270280
271281
Returns
272282
-------
@@ -282,6 +292,8 @@ def tapernd(
282292
tpr = [cosinetaper(nm, nt, False) for nm, nt in zip(nmask, ntap)]
283293
elif tapertype == "cosinesquare":
284294
tpr = [cosinetaper(nm, nt, True) for nm, nt in zip(nmask, ntap)]
295+
elif tapertype == "cosinesqrt":
296+
tpr = [cosinetaper(nm, nt, False, 0.5) for nm, nt in zip(nmask, ntap)]
285297
else:
286298
tpr = [np.ones(nm) for nm in nmask]
287299

pytests/test_tapers.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,23 @@
4040
"ntap": (4, 6),
4141
"tapertype": "cosinesquare",
4242
} # cosinesquare, even samples and taper
43+
par7 = {
44+
"nt": 21,
45+
"nspat": (11, 13),
46+
"ntap": (3, 5),
47+
"tapertype": "cosinesqrt",
48+
} # cosinesqrt, odd samples and taper
49+
par8 = {
50+
"nt": 20,
51+
"nspat": (12, 16),
52+
"ntap": (4, 6),
53+
"tapertype": "cosinesqrt",
54+
} # cosinesqrt, even samples and taper
4355

4456

45-
@pytest.mark.parametrize("par", [(par1), (par2), (par3), (par4), (par5), (par6)])
57+
@pytest.mark.parametrize(
58+
"par", [(par1), (par2), (par3), (par4), (par5), (par6), (par7), (par8)]
59+
)
4660
def test_taper2d(par):
4761
"""Create taper wavelet and check size and values"""
4862
tap = taper2d(par["nt"], par["nspat"][0], par["ntap"][0], par["tapertype"])
@@ -54,7 +68,9 @@ def test_taper2d(par):
5468
assert_array_equal(tap[par["nspat"][0] // 2], np.ones(par["nt"]))
5569

5670

57-
@pytest.mark.parametrize("par", [(par1), (par2), (par3), (par4), (par5), (par6)])
71+
@pytest.mark.parametrize(
72+
"par", [(par1), (par2), (par3), (par4), (par5), (par6), (par7), (par8)]
73+
)
5874
def test_taper3d(par):
5975
"""Create taper wavelet and check size and values"""
6076
tap = taper3d(par["nt"], par["nspat"], par["ntap"], par["tapertype"])

0 commit comments

Comments
 (0)