Skip to content

Commit 90aa74d

Browse files
committed
fix test
1 parent 703ac2d commit 90aa74d

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

tests/test_networks/conftest.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import pytest
22

3-
from bayesflow.networks.sequential import Sequential
43
from bayesflow.networks import MLP
54
from bayesflow.utils.tensor_utils import concatenate_valid
5+
import keras
66

77

88
@pytest.fixture()
@@ -17,9 +17,10 @@ def diffusion_model_edm_F():
1717
)
1818

1919

20-
class ConcatenateMLP(Sequential):
20+
class ConcatenateMLP(keras.Layer):
2121
def __init__(self, widths):
2222
super().__init__()
23+
self.widths = widths
2324
self.mlp = MLP(widths)
2425

2526
def call(self, x, t, conditions=None, training=False):
@@ -227,6 +228,7 @@ def inference_network_subnet(request):
227228

228229
@pytest.fixture(
229230
params=[
231+
pytest.param("diffusion_model_edm_F_subnet_concatenate"),
230232
"affine_coupling_flow",
231233
"spline_coupling_flow",
232234
"flow_matching",

0 commit comments

Comments
 (0)