|
3 | 3 |
|
4 | 4 | # Import the dispatch functions |
5 | 5 | from bayesflow.utils import find_network, find_permutation, find_pooling, find_recurrent_net |
| 6 | +from tests.utils import assert_allclose |
6 | 7 |
|
7 | 8 | # --- Tests for find_network.py --- |
8 | 9 |
|
@@ -118,23 +119,21 @@ def test_find_pooling_mean(): |
118 | 119 | # Check that a keras Lambda layer is returned |
119 | 120 | assert isinstance(pooling, keras.layers.Lambda) |
120 | 121 | # Test that the lambda function produces a mean when applied to a sample tensor. |
121 | | - import numpy as np |
122 | 122 |
|
123 | | - sample = np.array([[1, 2], [3, 4]]) |
| 123 | + sample = keras.ops.convert_to_tensor([[1, 2], [3, 4]]) |
124 | 124 | # Keras Lambda layers expect tensors via call(), here we simply call the layer's function. |
125 | 125 | result = pooling.call(sample) |
126 | | - np.testing.assert_allclose(result, sample.mean(axis=-2)) |
| 126 | + assert_allclose(result, keras.ops.mean(sample, axis=-2)) |
127 | 127 |
|
128 | 128 |
|
129 | 129 | @pytest.mark.parametrize("name,func", [("max", keras.ops.max), ("min", keras.ops.min)]) |
130 | 130 | def test_find_pooling_max_min(name, func): |
131 | 131 | pooling = find_pooling(name) |
132 | 132 | assert isinstance(pooling, keras.layers.Lambda) |
133 | | - import numpy as np |
134 | 133 |
|
135 | | - sample = np.array([[1, 2], [3, 4]]) |
| 134 | + sample = keras.ops.convert_to_tensor([[1, 2], [3, 4]]) |
136 | 135 | result = pooling.call(sample) |
137 | | - np.testing.assert_allclose(result, func(sample, axis=-2)) |
| 136 | + assert_allclose(result, func(sample, axis=-2)) |
138 | 137 |
|
139 | 138 |
|
140 | 139 | def test_find_pooling_learnable(monkeypatch): |
|
0 commit comments