diff --git a/bayesflow/networks/coupling_flow/coupling_flow.py b/bayesflow/networks/coupling_flow/coupling_flow.py index c7b528987..ee78f180e 100644 --- a/bayesflow/networks/coupling_flow/coupling_flow.py +++ b/bayesflow/networks/coupling_flow/coupling_flow.py @@ -77,7 +77,7 @@ def __init__( The type of transformation used in the coupling layers, such as "affine". Default is "affine". permutation : str or None, optional - The type of permutation applied between layers. Can be "random" or None + The type of permutation applied between layers. Can be "orthogonal", "random", "swap", or None (no permutation). Default is "random". use_actnorm : bool, optional Whether to apply ActNorm before each coupling layer. Default is True. diff --git a/tests/test_networks/test_coupling_flow/test_permutations.py b/tests/test_networks/test_coupling_flow/test_permutations.py new file mode 100644 index 000000000..63c50ae2c --- /dev/null +++ b/tests/test_networks/test_coupling_flow/test_permutations.py @@ -0,0 +1,117 @@ +import pytest +import keras +import numpy as np + +from bayesflow.networks.coupling_flow.permutations import ( + FixedPermutation, + OrthogonalPermutation, + RandomPermutation, + Swap, +) + + +@pytest.fixture(params=[FixedPermutation, OrthogonalPermutation, RandomPermutation, Swap]) +def permutation_class(request): + return request.param + + +@pytest.fixture +def input_tensor(): + return keras.random.normal((2, 5)) + + +def test_fixed_permutation_build_and_call(): + # Since FixedPermutation is abstract, create a subclass for testing build. + class TestPerm(FixedPermutation): + def build(self, xz_shape, **kwargs): + length = xz_shape[-1] + self.forward_indices = keras.ops.arange(length - 1, -1, -1) + self.inverse_indices = keras.ops.arange(length - 1, -1, -1) + + layer = TestPerm() + input_shape = (2, 4) + layer.build(input_shape) + + x = keras.ops.convert_to_tensor(np.arange(8).reshape(input_shape).astype("float32")) + z, log_det = layer(x, inverse=False) + x_inv, log_det_inv = layer(z, inverse=True) + + # Check shape preservation + assert z.shape == x.shape + assert x_inv.shape == x.shape + # Forward then inverse recovers input + np.testing.assert_allclose(keras.ops.convert_to_numpy(x_inv), keras.ops.convert_to_numpy(x), atol=1e-5) + # log_det values should be zero tensors with the correct shape + assert tuple(log_det.shape) == input_shape[:-1] + assert tuple(log_det_inv.shape) == input_shape[:-1] + + +def test_orthogonal_permutation_build_and_call(input_tensor): + layer = OrthogonalPermutation() + input_shape = keras.ops.shape(input_tensor) + layer.build(input_shape) + + z, log_det = layer(input_tensor) + x_inv, log_det_inv = layer(z, inverse=True) + + # Check output shapes + assert z.shape == input_tensor.shape + assert x_inv.shape == input_tensor.shape + + # Forward + inverse should approximately recover input (allow some numeric tolerance) + np.testing.assert_allclose( + keras.ops.convert_to_numpy(x_inv), keras.ops.convert_to_numpy(input_tensor), rtol=1e-5, atol=1e-5 + ) + + # log_det should be scalar or batched scalar + if len(log_det.shape) > 0: + assert log_det.shape[0] == input_tensor.shape[0] # batch dim + else: + assert log_det.shape == () + + # log_det_inv should be negative of log_det (det(inv) = 1/det) + log_det_np = keras.ops.convert_to_numpy(log_det) + log_det_inv_np = keras.ops.convert_to_numpy(log_det_inv) + np.testing.assert_allclose(log_det_inv_np, -log_det_np, rtol=1e-5, atol=1e-5) + + +def test_random_permutation_build_and_call(input_tensor): + layer = RandomPermutation() + input_shape = keras.ops.shape(input_tensor) + layer.build(input_shape) + + # Assert forward_indices and inverse_indices are set and consistent + fwd = keras.ops.convert_to_numpy(layer.forward_indices) + inv = keras.ops.convert_to_numpy(layer.inverse_indices) + # Applying inv on fwd must yield ordered indices + reordered = fwd[inv] + np.testing.assert_array_equal(np.arange(len(fwd)), reordered) + + z, log_det = layer(input_tensor) + x_inv, log_det_inv = layer(z, inverse=True) + + assert z.shape == input_tensor.shape + assert x_inv.shape == input_tensor.shape + np.testing.assert_allclose(keras.ops.convert_to_numpy(x_inv), keras.ops.convert_to_numpy(input_tensor), atol=1e-5) + assert tuple(log_det.shape) == input_shape[:-1] + assert tuple(log_det_inv.shape) == input_shape[:-1] + + +def test_swap_build_and_call(input_tensor): + layer = Swap() + input_shape = keras.ops.shape(input_tensor) + layer.build(input_shape) + + fwd = keras.ops.convert_to_numpy(layer.forward_indices) + inv = keras.ops.convert_to_numpy(layer.inverse_indices) + reordered = fwd[inv] + np.testing.assert_array_equal(np.arange(len(fwd)), reordered) + + z, log_det = layer(input_tensor) + x_inv, log_det_inv = layer(z, inverse=True) + + assert z.shape == input_tensor.shape + assert x_inv.shape == input_tensor.shape + np.testing.assert_allclose(keras.ops.convert_to_numpy(x_inv), keras.ops.convert_to_numpy(input_tensor), atol=1e-5) + assert tuple(log_det.shape) == input_shape[:-1] + assert tuple(log_det_inv.shape) == input_shape[:-1] diff --git a/tests/test_networks/test_embeddings.py b/tests/test_networks/test_embeddings.py new file mode 100644 index 000000000..7385d94c0 --- /dev/null +++ b/tests/test_networks/test_embeddings.py @@ -0,0 +1,85 @@ +import pytest +import keras + +from bayesflow.networks.embeddings import ( + FourierEmbedding, + RecurrentEmbedding, + Time2Vec, +) + + +def test_fourier_embedding_output_shape_and_type(): + embed_dim = 8 + batch_size = 4 + + emb_layer = FourierEmbedding(embed_dim=embed_dim, include_identity=True) + # use keras.ops.zeros with shape (batch_size, 1) and float32 dtype + t = keras.ops.zeros((batch_size, 1), dtype="float32") + + emb = emb_layer(t) + # Expected shape is (batch_size, embed_dim + 1) if include_identity else (batch_size, embed_dim) + expected_dim = embed_dim + 1 + assert emb.shape[0] == batch_size + assert emb.shape[1] == expected_dim + # Check type - it should be a Keras tensor, convert to numpy for checking + np_emb = keras.ops.convert_to_numpy(emb) + assert np_emb.shape == (batch_size, expected_dim) + + +def test_fourier_embedding_without_identity(): + embed_dim = 8 + batch_size = 3 + + emb_layer = FourierEmbedding(embed_dim=embed_dim, include_identity=False) + t = keras.ops.zeros((batch_size, 1), dtype="float32") + + emb = emb_layer(t) + expected_dim = embed_dim + assert emb.shape[0] == batch_size + assert emb.shape[1] == expected_dim + + +def test_fourier_embedding_raises_for_odd_embed_dim(): + with pytest.raises(ValueError): + FourierEmbedding(embed_dim=7) + + +def test_recurrent_embedding_lstm_and_gru_shapes(): + batch_size = 2 + seq_len = 5 + dim = 3 + embed_dim = 6 + + # Dummy input + x = keras.ops.zeros((batch_size, seq_len, dim), dtype="float32") + + # lstm + lstm_layer = RecurrentEmbedding(embed_dim=embed_dim, embedding="lstm") + emb_lstm = lstm_layer(x) + # Check the concatenated shape: last dimension = original dim + embed_dim + assert emb_lstm.shape == (batch_size, seq_len, dim + embed_dim) + + # gru + gru_layer = RecurrentEmbedding(embed_dim=embed_dim, embedding="gru") + emb_gru = gru_layer(x) + assert emb_gru.shape == (batch_size, seq_len, dim + embed_dim) + + +def test_recurrent_embedding_raises_unknown_embedding(): + with pytest.raises(ValueError): + RecurrentEmbedding(embed_dim=4, embedding="unknown") + + +def test_time2vec_shapes_and_output(): + batch_size = 3 + seq_len = 7 + dim = 2 + num_periodic_features = 4 + + x = keras.ops.zeros((batch_size, seq_len, dim), dtype="float32") + time2vec_layer = Time2Vec(num_periodic_features=num_periodic_features) + + emb = time2vec_layer(x) + # The last dimension should be dim + num_periodic_features + 1 (trend + periodic) + expected_dim = dim + num_periodic_features + 1 + assert emb.shape == (batch_size, seq_len, expected_dim)