diff --git a/tests/test_networks/conftest.py b/tests/test_networks/conftest.py index 84c011812..01ea5ad70 100644 --- a/tests/test_networks/conftest.py +++ b/tests/test_networks/conftest.py @@ -85,6 +85,7 @@ def typical_point_inference_network_subnet(): "spline_coupling_flow", "flow_matching", "free_form_flow", + "consistency_model", ], scope="function", ) @@ -106,7 +107,8 @@ def inference_network_subnet(request): @pytest.fixture( - params=["affine_coupling_flow", "spline_coupling_flow", "flow_matching", "free_form_flow"], scope="function" + params=["affine_coupling_flow", "spline_coupling_flow", "flow_matching", "free_form_flow", "consistency_model"], + scope="function", ) def generative_inference_network(request): return request.getfixturevalue(request.param) diff --git a/tests/test_networks/test_inference_networks.py b/tests/test_networks/test_inference_networks.py index 880a4d082..7766b29f9 100644 --- a/tests/test_networks/test_inference_networks.py +++ b/tests/test_networks/test_inference_networks.py @@ -36,13 +36,21 @@ def test_variable_batch_size(inference_network, random_samples, random_condition else: new_conditions = keras.ops.zeros((bs,) + keras.ops.shape(random_conditions)[1:]) - inference_network(new_input, conditions=new_conditions) + try: + inference_network(new_input, conditions=new_conditions) + except NotImplementedError: + # network is not invertible + pass inference_network(new_input, conditions=new_conditions, inverse=True) @pytest.mark.parametrize("density", [True, False]) def test_output_structure(density, generative_inference_network, random_samples, random_conditions): - output = generative_inference_network(random_samples, conditions=random_conditions, density=density) + try: + output = generative_inference_network(random_samples, conditions=random_conditions, density=density) + except NotImplementedError: + # network not invertible + return if density: assert isinstance(output, tuple) @@ -57,9 +65,13 @@ def test_output_structure(density, generative_inference_network, random_samples, def test_output_shape(generative_inference_network, random_samples, random_conditions): - forward_output, forward_log_density = generative_inference_network( - random_samples, conditions=random_conditions, density=True - ) + try: + forward_output, forward_log_density = generative_inference_network( + random_samples, conditions=random_conditions, density=True + ) + except NotImplementedError: + # network is not invertible, not forward function available + return assert keras.ops.shape(forward_output) == keras.ops.shape(random_samples) assert keras.ops.shape(forward_log_density) == (keras.ops.shape(random_samples)[0],) @@ -74,9 +86,13 @@ def test_output_shape(generative_inference_network, random_samples, random_condi def test_cycle_consistency(generative_inference_network, random_samples, random_conditions): # cycle-consistency means the forward and inverse methods are inverses of each other - forward_output, forward_log_density = generative_inference_network( - random_samples, conditions=random_conditions, density=True - ) + try: + forward_output, forward_log_density = generative_inference_network( + random_samples, conditions=random_conditions, density=True + ) + except NotImplementedError: + # network is not invertible, cycle consistency cannot be tested. + return inverse_output, inverse_log_density = generative_inference_network( forward_output, conditions=random_conditions, density=True, inverse=True ) @@ -88,7 +104,11 @@ def test_cycle_consistency(generative_inference_network, random_samples, random_ def test_density_numerically(generative_inference_network, random_samples, random_conditions): from bayesflow.utils import jacobian - output, log_density = generative_inference_network(random_samples, conditions=random_conditions, density=True) + try: + output, log_density = generative_inference_network(random_samples, conditions=random_conditions, density=True) + except NotImplementedError: + # network does not support density estimation + return def f(x): return generative_inference_network(x, conditions=random_conditions)