@@ -110,50 +110,17 @@ def dottest(
110110 # complex numbers in subsequent prints also when using cupy arrays.
111111 xx , yy = np .array ([to_numpy (xx )])[0 ], np .array ([to_numpy (yy )])[0 ]
112112
113- # evaluate if dot test is passed
114- if complexflag == 0 :
115- if np .abs ((yy - xx ) / ((yy + xx + 1e-15 ) / 2 )) < tol :
116- if verb :
117- print ("Dot test passed, v^T(Opu)=%f - u^T(Op^Tv)=%f" % (yy , xx ))
118- return True
119- else :
120- if raiseerror :
121- raise ValueError (
122- "Dot test failed, v^T(Opu)=%f - u^T(Op^Tv)=%f" % (yy , xx )
123- )
124- if verb :
125- print ("Dot test failed, v^T(Opu)=%f - u^T(Op^Tv)=%f" % (yy , xx ))
126- return False
127- else :
128- # Check both real and imag parts
129- checkreal = (
130- np .abs (
131- (np .real (yy ) - np .real (xx )) / ((np .real (yy ) + np .real (xx ) + 1e-15 ) / 2 )
132- )
133- < tol
134- )
135- checkimag = (
136- np .abs (
137- (np .imag (yy ) - np .imag (xx )) / ((np .imag (yy ) + np .imag (xx ) + 1e-15 ) / 2 )
138- )
139- < tol
140- )
141- if checkreal and checkimag :
142- if verb :
143- print (
144- "Dot test passed, v^T(Opu)=%f%+fi - u^T(Op^Tv)=%f%+fi"
145- % (yy .real , yy .imag , xx .real , xx .imag )
146- )
147- return True
148- else :
149- if raiseerror :
150- raise ValueError (
151- "Dot test failed, v^H(Opu)=%f%+fi "
152- "- u^H(Op^Hv)=%f%+fi" % (yy .real , yy .imag , xx .real , xx .imag )
153- )
154- if verb :
155- print (
156- "Dot test failed, v^H(Opu)=%f%+fi - u^H(Op^Hv)=%f%+fi"
157- % (yy .real , yy .imag , xx .real , xx .imag )
158- )
159- return False
113+ def passes (xx , yy , tol = tol ):
114+ """True if xx and yy are the same within tolerance."""
115+ return abs ((yy - xx ) / ((yy + xx + 1e-15 ) / 2 )) < tol
116+
117+ # evaluate if dot test is passed (both real and imag parts)
118+ passed = passes (np .real (xx ), np .real (yy )) and passes (np .imag (xx ), np .imag (yy ))
119+
120+ if not passed and raiseerror :
121+ raise ValueError (f"Dot test failed, v^H(Opu)={ yy } - u^H(Op^Hv)={ xx } " )
122+ elif verb :
123+ passed_status = "passed" if passed else "failed"
124+ print (f"Dot test { passed_status } , v^H(Opu)={ yy } - u^H(Op^Hv)={ xx } " )
125+
126+ return passed
0 commit comments