Skip to content

Commit e051a48

Browse files
committed
fix numerical accuracy in adapter test for simple transforms
1 parent 1bfa9a5 commit e051a48

File tree

1 file changed

+7
-8
lines changed

1 file changed

+7
-8
lines changed

tests/test_adapters/test_adapters.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -97,17 +97,16 @@ def test_simple_transforms(random_data):
9797

9898
result = ad(random_data)
9999

100-
assert np.array_equal(result["p2"], np.log(random_data["p2"]))
101-
assert np.array_equal(result["t2"], np.log(random_data["t2"]))
102-
assert np.array_equal(result["t1"], np.log1p(random_data["t1"]))
103-
assert np.array_equal(result["p1"], np.sqrt(random_data["p1"]))
100+
assert np.allclose(result["p2"], np.log(random_data["p2"]))
101+
assert np.allclose(result["t2"], np.log(random_data["t2"]))
102+
assert np.allclose(result["t1"], np.log1p(random_data["t1"]))
103+
assert np.allclose(result["p1"], np.sqrt(random_data["p1"]))
104104

105105
# inverse results should match the original input
106106
inverse = ad(result, inverse=True)
107107

108-
assert np.array_equal(inverse["p2"], random_data["p2"])
109-
assert np.array_equal(inverse["t2"], random_data["t2"])
110-
assert np.array_equal(inverse["t1"], random_data["t1"])
108+
assert np.allclose(inverse["p2"], random_data["p2"])
109+
assert np.allclose(inverse["t2"], random_data["t2"])
110+
assert np.allclose(inverse["t1"], random_data["t1"])
111111

112-
# numerical inaccuracies prevent np.array_equal to work here
113112
assert np.allclose(inverse["p1"], random_data["p1"])

0 commit comments

Comments
 (0)