Skip to content

Commit e2ade26

Browse files
committed
test improvements
1 parent 35385df commit e2ade26

File tree

1 file changed

+13
-41
lines changed

1 file changed

+13
-41
lines changed

tests/test_networks/test_inference_networks.py

Lines changed: 13 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
serialize_keras_object as serialize,
77
)
88

9-
from tests.utils import allclose, assert_layers_equal
9+
from tests.utils import assert_allclose, assert_layers_equal
1010

1111

1212
def test_build(inference_network, random_samples, random_conditions):
@@ -83,58 +83,30 @@ def test_cycle_consistency(generative_inference_network, random_samples, random_
8383
forward_output, conditions=random_conditions, density=True, inverse=True
8484
)
8585

86-
assert allclose(random_samples, inverse_output, atol=1e-3, rtol=1e-3)
87-
assert allclose(forward_log_density, inverse_log_density, atol=1e-3, rtol=1e-3)
86+
assert_allclose(random_samples, inverse_output, atol=1e-3, rtol=1e-3)
87+
assert_allclose(forward_log_density, inverse_log_density, atol=1e-3, rtol=1e-3)
8888

8989

90-
# TODO: make this backend-agnostic
91-
@pytest.mark.torch
9290
def test_density_numerically(generative_inference_network, random_samples, random_conditions):
93-
import torch
91+
from bayesflow.utils import jacobian
9492

95-
forward_output, forward_log_density = generative_inference_network(
96-
random_samples, conditions=random_conditions, density=True
97-
)
93+
output, log_density = generative_inference_network(random_samples, conditions=random_conditions, density=True)
9894

9995
def f(x):
10096
return generative_inference_network(x, conditions=random_conditions)
10197

102-
numerical_forward_jacobian, *_ = torch.autograd.functional.jacobian(f, random_samples, vectorize=True)
103-
104-
# TODO: torch is somehow permuted wrt keras
105-
numerical_forward_log_det = [
106-
keras.ops.log(keras.ops.abs(keras.ops.det(numerical_forward_jacobian[:, i, :])))
107-
for i in range(keras.ops.shape(random_samples)[0])
108-
]
109-
numerical_forward_log_det = keras.ops.stack(numerical_forward_log_det, axis=0)
110-
111-
log_prob = generative_inference_network.base_distribution.log_prob(forward_output)
112-
113-
numerical_forward_log_density = log_prob + numerical_forward_log_det
114-
115-
assert allclose(forward_log_density, numerical_forward_log_density, rtol=1e-4, atol=1e-5)
116-
117-
inverse_output, inverse_log_density = generative_inference_network(
118-
random_samples, conditions=random_conditions, density=True, inverse=True
119-
)
120-
121-
def f(x):
122-
return generative_inference_network(x, conditions=random_conditions, inverse=True)
123-
124-
numerical_inverse_jacobian, *_ = torch.autograd.functional.jacobian(f, random_samples, vectorize=True)
98+
numerical_output, numerical_jacobian = jacobian(f, random_samples, return_output=True)
12599

126-
# TODO: torch is somehow permuted wrt keras
127-
numerical_inverse_log_det = [
128-
keras.ops.log(keras.ops.abs(keras.ops.det(numerical_inverse_jacobian[:, i, :])))
129-
for i in range(keras.ops.shape(random_samples)[0])
130-
]
131-
numerical_inverse_log_det = keras.ops.stack(numerical_inverse_log_det, axis=0)
100+
# output should be identical, otherwise this test does not work (e.g. for stochastic networks)
101+
assert keras.ops.all(keras.ops.isclose(output, numerical_output))
132102

133-
log_prob = generative_inference_network.base_distribution.log_prob(random_samples)
103+
log_prob = generative_inference_network.base_distribution.log_prob(output)
134104

135-
numerical_inverse_log_density = log_prob - numerical_inverse_log_det
105+
# use change of variables to compute the numerical log density
106+
numerical_log_density = log_prob + keras.ops.log(keras.ops.abs(keras.ops.det(numerical_jacobian)))
136107

137-
assert allclose(inverse_log_density, numerical_inverse_log_density, rtol=1e-4, atol=1e-5)
108+
# use a high tolerance because the numerical jacobian is not very accurate
109+
assert_allclose(log_density, numerical_log_density, rtol=1e-3, atol=1e-3)
138110

139111

140112
def test_serialize_deserialize(inference_network_subnet, subnet, random_samples, random_conditions):

0 commit comments

Comments
 (0)