Skip to content

Commit 671e656

Browse files
committed
improve assert_allclose
1 parent fc5b705 commit 671e656

File tree

1 file changed

+25
-7
lines changed

1 file changed

+25
-7
lines changed

tests/utils/ops.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,35 @@
11
import keras
2-
3-
4-
def isclose(x1, x2, rtol=1e-5, atol=1e-5):
5-
return keras.ops.abs(x1 - x2) <= atol + rtol * keras.ops.abs(x2)
2+
import numpy as np
63

74

85
def allclose(x1, x2, rtol=1e-5, atol=1e-5):
9-
return keras.ops.all(isclose(x1, x2, rtol, atol))
6+
return keras.ops.all(keras.ops.isclose(x1, x2, rtol, atol))
107

118

129
def assert_allclose(x1, x2, rtol=1e-5, atol=1e-8, msg=""):
13-
mse = keras.ops.mean(keras.ops.square(x1 - x2))
14-
assert allclose(x1, x2, rtol, atol), f"{msg} - mse={mse}"
10+
x1 = keras.ops.convert_to_numpy(x1)
11+
x2 = keras.ops.convert_to_numpy(x2)
12+
13+
assert x1.shape == x2.shape, "Input shapes do not match."
14+
15+
mse = np.mean(np.square(x1 - x2)).item()
16+
largest_deviation = np.max(np.abs(x1 - x2)).item()
17+
largest_deviation_index = np.unravel_index(np.argmax(np.abs(x1 - x2)), x1.shape)
18+
largest_deviation_value1 = x1[largest_deviation_index].item()
19+
largest_deviation_value2 = x2[largest_deviation_index].item()
20+
21+
if msg:
22+
msg = f"{msg}\n"
23+
else:
24+
msg = "Inputs significantly differ:\n"
25+
26+
msg += "Largest Deviation:\n"
27+
msg += f"|{largest_deviation_value1:.02e} - {largest_deviation_value2:.02e}| = {largest_deviation:.02e}\n"
28+
msg += "\n"
29+
msg += "MSE:\n"
30+
msg += f"{mse:.02e}"
31+
32+
assert allclose(x1, x2, rtol, atol), msg
1533

1634

1735
def max_mean_discrepancy(x, y):

0 commit comments

Comments
 (0)