Skip to content

Commit 0e54bf7

Browse files
committed
fix numpy-keras interop in tests
1 parent 988f6a4 commit 0e54bf7

File tree

1 file changed

+5
-6
lines changed

1 file changed

+5
-6
lines changed

tests/test_utils/test_dispatch.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
# Import the dispatch functions
55
from bayesflow.utils import find_network, find_permutation, find_pooling, find_recurrent_net
6+
from tests.utils import assert_allclose
67

78
# --- Tests for find_network.py ---
89

@@ -118,23 +119,21 @@ def test_find_pooling_mean():
118119
# Check that a keras Lambda layer is returned
119120
assert isinstance(pooling, keras.layers.Lambda)
120121
# Test that the lambda function produces a mean when applied to a sample tensor.
121-
import numpy as np
122122

123-
sample = np.array([[1, 2], [3, 4]])
123+
sample = keras.ops.convert_to_tensor([[1, 2], [3, 4]])
124124
# Keras Lambda layers expect tensors via call(), here we simply call the layer's function.
125125
result = pooling.call(sample)
126-
np.testing.assert_allclose(result, sample.mean(axis=-2))
126+
assert_allclose(result, keras.ops.mean(sample, axis=-2))
127127

128128

129129
@pytest.mark.parametrize("name,func", [("max", keras.ops.max), ("min", keras.ops.min)])
130130
def test_find_pooling_max_min(name, func):
131131
pooling = find_pooling(name)
132132
assert isinstance(pooling, keras.layers.Lambda)
133-
import numpy as np
134133

135-
sample = np.array([[1, 2], [3, 4]])
134+
sample = keras.ops.convert_to_tensor([[1, 2], [3, 4]])
136135
result = pooling.call(sample)
137-
np.testing.assert_allclose(result, func(sample, axis=-2))
136+
assert_allclose(result, func(sample, axis=-2))
138137

139138

140139
def test_find_pooling_learnable(monkeypatch):

0 commit comments

Comments
 (0)