|
1 | 1 | import keras |
2 | | - |
3 | | - |
4 | | -def isclose(x1, x2, rtol=1e-5, atol=1e-5): |
5 | | - return keras.ops.abs(x1 - x2) <= atol + rtol * keras.ops.abs(x2) |
| 2 | +import numpy as np |
6 | 3 |
|
7 | 4 |
|
8 | 5 | def allclose(x1, x2, rtol=1e-5, atol=1e-5): |
9 | | - return keras.ops.all(isclose(x1, x2, rtol, atol)) |
| 6 | + return keras.ops.all(keras.ops.isclose(x1, x2, rtol, atol)) |
10 | 7 |
|
11 | 8 |
|
12 | 9 | def assert_allclose(x1, x2, rtol=1e-5, atol=1e-8, msg=""): |
13 | | - mse = keras.ops.mean(keras.ops.square(x1 - x2)) |
14 | | - assert allclose(x1, x2, rtol, atol), f"{msg} - mse={mse}" |
| 10 | + x1 = keras.ops.convert_to_numpy(x1) |
| 11 | + x2 = keras.ops.convert_to_numpy(x2) |
| 12 | + |
| 13 | + assert x1.shape == x2.shape, "Input shapes do not match." |
| 14 | + |
| 15 | + mse = np.mean(np.square(x1 - x2)).item() |
| 16 | + largest_deviation = np.max(np.abs(x1 - x2)).item() |
| 17 | + largest_deviation_index = np.unravel_index(np.argmax(np.abs(x1 - x2)), x1.shape) |
| 18 | + largest_deviation_value1 = x1[largest_deviation_index].item() |
| 19 | + largest_deviation_value2 = x2[largest_deviation_index].item() |
| 20 | + |
| 21 | + if msg: |
| 22 | + msg = f"{msg}\n" |
| 23 | + else: |
| 24 | + msg = "Inputs significantly differ:\n" |
| 25 | + |
| 26 | + msg += "Largest Deviation:\n" |
| 27 | + msg += f"|{largest_deviation_value1:.02e} - {largest_deviation_value2:.02e}| = {largest_deviation:.02e}\n" |
| 28 | + msg += "\n" |
| 29 | + msg += "MSE:\n" |
| 30 | + msg += f"{mse:.02e}" |
| 31 | + |
| 32 | + assert allclose(x1, x2, rtol, atol), msg |
15 | 33 |
|
16 | 34 |
|
17 | 35 | def max_mean_discrepancy(x, y): |
|
0 commit comments