6 | 6 | serialize_keras_object as serialize, |
7 | 7 | ) |
8 | 8 |
9 | | -from tests.utils import allclose, assert_layers_equal |
| 9 | +from tests.utils import assert_allclose, assert_layers_equal |
10 | 10 |
11 | 11 |
12 | 12 | def test_build(inference_network, random_samples, random_conditions): |
@@ -83,58 +83,30 @@ def test_cycle_consistency(generative_inference_network, random_samples, random_ |
83 | 83 | forward_output, conditions=random_conditions, density=True, inverse=True |
84 | 84 | ) |
85 | 85 |
86 | | - assert allclose(random_samples, inverse_output, atol=1e-3, rtol=1e-3) |
87 | | - assert allclose(forward_log_density, inverse_log_density, atol=1e-3, rtol=1e-3) |
| 86 | + assert_allclose(random_samples, inverse_output, atol=1e-3, rtol=1e-3) |
| 87 | + assert_allclose(forward_log_density, inverse_log_density, atol=1e-3, rtol=1e-3) |
88 | 88 |
89 | 89 |
90 | | -# TODO: make this backend-agnostic |
91 | | -@pytest.mark.torch |
92 | 90 | def test_density_numerically(generative_inference_network, random_samples, random_conditions): |
93 | | - import torch |
| 91 | + from bayesflow.utils import jacobian |
94 | 92 |
95 | | - forward_output, forward_log_density = generative_inference_network( |
96 | | - random_samples, conditions=random_conditions, density=True |
97 | | - ) |
| 93 | + output, log_density = generative_inference_network(random_samples, conditions=random_conditions, density=True) |
98 | 94 |
99 | 95 | def f(x): |
100 | 96 | return generative_inference_network(x, conditions=random_conditions) |
101 | 97 |
102 | | - numerical_forward_jacobian, *_ = torch.autograd.functional.jacobian(f, random_samples, vectorize=True) |
103 | | - |
104 | | - # TODO: torch is somehow permuted wrt keras |
105 | | - numerical_forward_log_det = [ |
106 | | - keras.ops.log(keras.ops.abs(keras.ops.det(numerical_forward_jacobian[:, i, :]))) |
107 | | - for i in range(keras.ops.shape(random_samples)[0]) |
108 | | - ] |
109 | | - numerical_forward_log_det = keras.ops.stack(numerical_forward_log_det, axis=0) |
110 | | - |
111 | | - log_prob = generative_inference_network.base_distribution.log_prob(forward_output) |
112 | | - |
113 | | - numerical_forward_log_density = log_prob + numerical_forward_log_det |
114 | | - |
115 | | - assert allclose(forward_log_density, numerical_forward_log_density, rtol=1e-4, atol=1e-5) |
116 | | - |
117 | | - inverse_output, inverse_log_density = generative_inference_network( |
118 | | - random_samples, conditions=random_conditions, density=True, inverse=True |
119 | | - ) |
120 | | - |
121 | | - def f(x): |
122 | | - return generative_inference_network(x, conditions=random_conditions, inverse=True) |
123 | | - |
124 | | - numerical_inverse_jacobian, *_ = torch.autograd.functional.jacobian(f, random_samples, vectorize=True) |
| 98 | + numerical_output, numerical_jacobian = jacobian(f, random_samples, return_output=True) |
125 | 99 |
126 | | - # TODO: torch is somehow permuted wrt keras |
127 | | - numerical_inverse_log_det = [ |
128 | | - keras.ops.log(keras.ops.abs(keras.ops.det(numerical_inverse_jacobian[:, i, :]))) |
129 | | - for i in range(keras.ops.shape(random_samples)[0]) |
130 | | - ] |
131 | | - numerical_inverse_log_det = keras.ops.stack(numerical_inverse_log_det, axis=0) |
| 100 | + # output should be identical, otherwise this test does not work (e.g. for stochastic networks) |
| 101 | + assert keras.ops.all(keras.ops.isclose(output, numerical_output)) |
132 | 102 |
133 | | - log_prob = generative_inference_network.base_distribution.log_prob(random_samples) |
| 103 | + log_prob = generative_inference_network.base_distribution.log_prob(output) |
134 | 104 |
135 | | - numerical_inverse_log_density = log_prob - numerical_inverse_log_det |
| 105 | + # use change of variables to compute the numerical log density |
| 106 | + numerical_log_density = log_prob + keras.ops.log(keras.ops.abs(keras.ops.det(numerical_jacobian))) |
136 | 107 |
137 | | - assert allclose(inverse_log_density, numerical_inverse_log_density, rtol=1e-4, atol=1e-5) |
| 108 | + # use a high tolerance because the numerical jacobian is not very accurate |
| 109 | + assert_allclose(log_density, numerical_log_density, rtol=1e-3, atol=1e-3) |
138 | 110 |
139 | 111 |
140 | 112 | def test_serialize_deserialize(inference_network_subnet, subnet, random_samples, random_conditions): |
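For reference, the rewritten density test relies on the change-of-variables identity log p_X(x) = log p_Z(f(x)) + log |det J_f(x)|, which is exactly what `numerical_log_density` computes from the autodiff Jacobian above. A minimal numpy-only sketch of the same identity on a 1D affine map, with hypothetical names and no bayesflow dependency:

```python
# Change-of-variables check on z = a * x + b, whose model density is known in
# closed form. Illustrative sketch only; not part of this PR or the test suite.
import numpy as np

a, b = 2.0, 0.5
x = np.linspace(-2.0, 2.0, 5)

z = a * x + b                                          # "forward" pass: data -> latent
log_prob_base = -0.5 * (z**2 + np.log(2.0 * np.pi))    # standard normal log-density at f(x)
log_det_jac = np.log(np.abs(a))                        # log |det J_f(x)| of the affine map
log_density = log_prob_base + log_det_jac              # change of variables

# Closed form: if z = a*x + b is standard normal, then x ~ N(-b/a, 1/a^2).
mu, sigma = -b / a, 1.0 / abs(a)
expected = -0.5 * ((x - mu) / sigma) ** 2 - np.log(sigma) - 0.5 * np.log(2.0 * np.pi)

np.testing.assert_allclose(log_density, expected, rtol=1e-12)
```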