Skip to content

Commit 15890fe

Browse files
committed
Add interleaved test
1 parent f35f6a1 commit 15890fe

File tree

1 file changed

+20
-8
lines changed

1 file changed

+20
-8
lines changed

tests/test_inference_networks.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,21 +29,28 @@
2929
@pytest.mark.parametrize("input_shape", ["2d", "3d"])
3030
@pytest.mark.parametrize("use_soft_flow", [True, False])
3131
@pytest.mark.parametrize("permutation", ["learnable", "fixed"])
32-
@pytest.mark.parametrize("coupling_design", ["affine", "spline"])
33-
@pytest.mark.parametrize("num_coupling_layers", [2, 8])
32+
@pytest.mark.parametrize("coupling_design", ["affine", "spline", "interleaved"])
33+
@pytest.mark.parametrize("num_coupling_layers", [2, 7])
3434
def test_invertible_network(input_shape, use_soft_flow, permutation, coupling_design, num_coupling_layers):
3535
"""Tests the ``InvertibleNetwork`` core class using a couple of relevant configurations."""
3636

3737
# Randomize units and input dim
3838
units = np.random.randint(low=2, high=32)
3939
input_dim = np.random.randint(low=2, high=32)
4040

41-
# Create settings dictionaries and network
42-
coupling_settings = {
43-
"dense_args": dict(units=units, activation="elu"),
44-
"num_dense": 1,
45-
}
41+
# Create settings dictionaries
42+
if coupling_design in ["affine", "spline"]:
43+
coupling_settings = {
44+
"dense_args": dict(units=units, activation="elu"),
45+
"num_dense": 1,
46+
}
47+
else:
48+
coupling_settings = {
49+
"affine": dict(dense_args={"units": units, "activation": "selu"}, num_dense=1),
50+
"spline": dict(dense_args={"units": units, "activation": "relu"}, bins=8, num_dense=1),
51+
}
4652

53+
# Create invertible network with test settings
4754
network = InvertibleNetwork(
4855
num_params=input_dim,
4956
num_coupling_layers=num_coupling_layers,
@@ -72,7 +79,7 @@ def test_invertible_network(input_shape, use_soft_flow, permutation, coupling_de
7279
assert network.latent_dim == input_dim
7380
assert len(network.coupling_layers) == num_coupling_layers
7481
# Test layer attributes
75-
for l in network.coupling_layers:
82+
for idx, l in enumerate(network.coupling_layers):
7683
# Permutation
7784
if permutation == "fixed":
7885
assert isinstance(l.permutation, Permutation)
@@ -85,6 +92,11 @@ def test_invertible_network(input_shape, use_soft_flow, permutation, coupling_de
8592
assert isinstance(l.net1, AffineCoupling) and isinstance(l.net2, AffineCoupling)
8693
elif coupling_design == "spline":
8794
assert isinstance(l.net1, SplineCoupling) and isinstance(l.net2, SplineCoupling)
95+
elif coupling_design == "interleaved":
96+
if idx % 2 == 0:
97+
assert isinstance(l.net1, AffineCoupling) and isinstance(l.net2, AffineCoupling)
98+
else:
99+
assert isinstance(l.net1, SplineCoupling) and isinstance(l.net2, SplineCoupling)
88100

89101
if use_soft_flow:
90102
assert network.soft_flow is True

0 commit comments

Comments
 (0)