Skip to content

Commit 8a36889

Browse files
authored
Merge pull request #336 from prisae/dottest-allclose
Check passed with np.isclose
2 parents 5bf12b0 + abc0982 commit 8a36889

File tree

11 files changed

+137
-138
lines changed

11 files changed

+137
-138
lines changed

MIGRATION_V1_V2.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,5 @@ should be used as a checklist when converting a piece of code using PyLops from
88
- XX
99
- XX
1010
- XX
11+
12+
- `utils.dottest`: The relative tolerance is new set via `rtol` (before `tol`), and absolute tolerance is new supported via the keyword `atol`. When calling it with purely positional arguments, note that after `rtol` comes now first `atol` before `complexflag`. When using `raiseerror=True` it now emits an `AttributeError` instead of a `ValueError`.

examples/plot_nmo.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,7 @@ def _rmatvec(self, y):
296296
# and adjoint transforms truly are adjoints of each other.
297297

298298
NMOOp = NMO(t, x, vel_t)
299-
dottest(NMOOp, *NMOOp.shape, tol=1e-4)
299+
dottest(NMOOp, *NMOOp.shape, rtol=1e-4)
300300

301301
###############################################################################
302302
# NMO using :py:class:`pylops.Spread`
@@ -370,7 +370,7 @@ def create_tables(taxis, haxis, vels_rms):
370370
dtable=nmo_dtable, # Table of weights for linear interpolation
371371
engine="numba", # numba or numpy
372372
).H # To perform NMO *correction*, we need the adjoint
373-
dottest(SpreadNMO, *SpreadNMO.shape, tol=1e-4)
373+
dottest(SpreadNMO, *SpreadNMO.shape, rtol=1e-4)
374374

375375
###############################################################################
376376
# We see it passes the dot test, but are the results right? Let's find out.

pylops/utils/dottest.py

Lines changed: 27 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ def dottest(
77
Op,
88
nr=None,
99
nc=None,
10-
tol=1e-6,
10+
rtol=1e-6,
11+
atol=1e-21,
1112
complexflag=0,
1213
raiseerror=True,
1314
verb=False,
@@ -28,8 +29,11 @@ def dottest(
2829
Number of rows of operator (i.e., elements in data)
2930
nc : :obj:`int`
3031
Number of columns of operator (i.e., elements in model)
31-
tol : :obj:`float`, optional
32-
Dottest tolerance
32+
rtol : :obj:`float`, optional
33+
Relative dottest tolerance
34+
atol : :obj:`float`, optional
35+
Absolute dottest tolerance
36+
.. versionadded:: 2.0.0
3337
complexflag : :obj:`bool`, optional
3438
Generate random vectors with
3539
@@ -50,8 +54,8 @@ def dottest(
5054
5155
Raises
5256
------
53-
ValueError
54-
If dot-test is not verified within chosen tolerance.
57+
AssertionError
58+
If dot-test is not verified within chosen tolerances.
5559
5660
Notes
5761
-----
@@ -74,25 +78,19 @@ def dottest(
7478
if nc is None:
7579
nc = Op.shape[1]
7680

77-
assert (nr, nc) == Op.shape, "Provided nr and nc do not match operator shape"
81+
if (nr, nc) != Op.shape:
82+
raise AssertionError("Provided nr and nc do not match operator shape")
7883

7984
# make u and v vectors
80-
if complexflag != 0:
81-
rdtype = np.real(np.ones(1, Op.dtype)).dtype
85+
rdtype = np.ones(1, Op.dtype).real.dtype
8286

83-
if complexflag in (0, 2):
84-
u = ncp.random.randn(nc).astype(Op.dtype)
85-
else:
86-
u = ncp.random.randn(nc).astype(rdtype) + 1j * ncp.random.randn(nc).astype(
87-
rdtype
88-
)
87+
u = ncp.random.randn(nc).astype(rdtype)
88+
if complexflag not in (0, 2):
89+
u = u + 1j * ncp.random.randn(nc).astype(rdtype)
8990

90-
if complexflag in (0, 1):
91-
v = ncp.random.randn(nr).astype(Op.dtype)
92-
else:
93-
v = ncp.random.randn(nr).astype(rdtype) + 1j * ncp.random.randn(nr).astype(
94-
rdtype
95-
)
91+
v = ncp.random.randn(nr).astype(rdtype)
92+
if complexflag not in (0, 1):
93+
v = v + 1j * ncp.random.randn(nr).astype(rdtype)
9694

9795
y = Op.matvec(u) # Op * u
9896
x = Op.rmatvec(v) # Op'* v
@@ -110,17 +108,16 @@ def dottest(
110108
# complex numbers in subsequent prints also when using cupy arrays.
111109
xx, yy = np.array([to_numpy(xx)])[0], np.array([to_numpy(yy)])[0]
112110

113-
def passes(xx, yy, tol=tol):
114-
"""True if xx and yy are the same within tolerance."""
115-
return abs((yy - xx) / ((yy + xx + 1e-15) / 2)) < tol
116-
117-
# evaluate if dot test is passed (both real and imag parts)
118-
passed = passes(np.real(xx), np.real(yy)) and passes(np.imag(xx), np.imag(yy))
111+
# evaluate if dot test passed
112+
passed = np.isclose(xx, yy, rtol, atol)
119113

120-
if not passed and raiseerror:
121-
raise ValueError(f"Dot test failed, v^H(Opu)={yy} - u^H(Op^Hv)={xx}")
122-
elif verb:
114+
# verbosity or error raising
115+
if (not passed and raiseerror) or verb:
123116
passed_status = "passed" if passed else "failed"
124-
print(f"Dot test {passed_status}, v^H(Opu)={yy} - u^H(Op^Hv)={xx}")
117+
msg = f"Dot test {passed_status}, v^H(Opu)={yy} - u^H(Op^Hv)={xx}"
118+
if not passed and raiseerror:
119+
raise AssertionError(msg)
120+
else:
121+
print(msg)
125122

126123
return passed

0 commit comments

Comments
 (0)