Skip to content

Commit b9dab0d

Browse files
authored
Merge pull request #335 from prisae/dottest
Refactor dottest
2 parents c989a89 + d98a4cf commit b9dab0d

File tree

1 file changed

+14
-47
lines changed

1 file changed

+14
-47
lines changed

pylops/utils/dottest.py

Lines changed: 14 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -110,50 +110,17 @@ def dottest(
110110
# complex numbers in subsequent prints also when using cupy arrays.
111111
xx, yy = np.array([to_numpy(xx)])[0], np.array([to_numpy(yy)])[0]
112112

113-
# evaluate if dot test is passed
114-
if complexflag == 0:
115-
if np.abs((yy - xx) / ((yy + xx + 1e-15) / 2)) < tol:
116-
if verb:
117-
print("Dot test passed, v^T(Opu)=%f - u^T(Op^Tv)=%f" % (yy, xx))
118-
return True
119-
else:
120-
if raiseerror:
121-
raise ValueError(
122-
"Dot test failed, v^T(Opu)=%f - u^T(Op^Tv)=%f" % (yy, xx)
123-
)
124-
if verb:
125-
print("Dot test failed, v^T(Opu)=%f - u^T(Op^Tv)=%f" % (yy, xx))
126-
return False
127-
else:
128-
# Check both real and imag parts
129-
checkreal = (
130-
np.abs(
131-
(np.real(yy) - np.real(xx)) / ((np.real(yy) + np.real(xx) + 1e-15) / 2)
132-
)
133-
< tol
134-
)
135-
checkimag = (
136-
np.abs(
137-
(np.imag(yy) - np.imag(xx)) / ((np.imag(yy) + np.imag(xx) + 1e-15) / 2)
138-
)
139-
< tol
140-
)
141-
if checkreal and checkimag:
142-
if verb:
143-
print(
144-
"Dot test passed, v^T(Opu)=%f%+fi - u^T(Op^Tv)=%f%+fi"
145-
% (yy.real, yy.imag, xx.real, xx.imag)
146-
)
147-
return True
148-
else:
149-
if raiseerror:
150-
raise ValueError(
151-
"Dot test failed, v^H(Opu)=%f%+fi "
152-
"- u^H(Op^Hv)=%f%+fi" % (yy.real, yy.imag, xx.real, xx.imag)
153-
)
154-
if verb:
155-
print(
156-
"Dot test failed, v^H(Opu)=%f%+fi - u^H(Op^Hv)=%f%+fi"
157-
% (yy.real, yy.imag, xx.real, xx.imag)
158-
)
159-
return False
113+
def passes(xx, yy, tol=tol):
114+
"""True if xx and yy are the same within tolerance."""
115+
return abs((yy - xx) / ((yy + xx + 1e-15) / 2)) < tol
116+
117+
# evaluate if dot test is passed (both real and imag parts)
118+
passed = passes(np.real(xx), np.real(yy)) and passes(np.imag(xx), np.imag(yy))
119+
120+
if not passed and raiseerror:
121+
raise ValueError(f"Dot test failed, v^H(Opu)={yy} - u^H(Op^Hv)={xx}")
122+
elif verb:
123+
passed_status = "passed" if passed else "failed"
124+
print(f"Dot test {passed_status}, v^H(Opu)={yy} - u^H(Op^Hv)={xx}")
125+
126+
return passed

0 commit comments

Comments
 (0)