Skip to content

Commit 7558acb

Browse files
committed
add test for inverse concatenation
1 parent 1983839 commit 7558acb

File tree

4 files changed

+33
-7
lines changed

4 files changed

+33
-7
lines changed

bayesflow/adapters/adapter.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -91,14 +91,14 @@ def forward(
9191
stage : str, one of ["training", "validation", "inference"]
9292
The stage the function is called in.
9393
log_det_jac: bool, optional
94-
Whether to return the log determinant jacobians of the transforms.
94+
Whether to return the log determinant of the Jacobian of the transforms.
9595
**kwargs : dict
9696
Additional keyword arguments passed to each transform.
9797
9898
Returns
9999
-------
100100
dict | tuple[dict, dict]
101-
The transformed data or tuple of transformed data and log determinant jacobians.
101+
The transformed data or tuple of transformed data and log determinant of the Jacobian.
102102
"""
103103
data = data.copy()
104104
if not log_det_jac:
@@ -125,14 +125,14 @@ def inverse(
125125
stage : str, one of ["training", "validation", "inference"]
126126
The stage the function is called in.
127127
log_det_jac: bool, optional
128-
Whether to return the log determinant jacobians of the transforms.
128+
Whether to return the log determinant of the Jacobian of the transforms.
129129
**kwargs : dict
130130
Additional keyword arguments passed to each transform.
131131
132132
Returns
133133
-------
134134
dict | tuple[dict, dict]
135-
The transformed data or tuple of transformed data and log determinant jacobians.
135+
The transformed data or tuple of transformed data and log determinant of the Jacobian.
136136
"""
137137
data = data.copy()
138138
if not log_det_jac:
@@ -166,7 +166,7 @@ def __call__(
166166
Returns
167167
-------
168168
dict | tuple[dict, dict]
169-
The transformed data or tuple of transformed data and log determinant jacobians.
169+
The transformed data or tuple of transformed data and log determinant of the Jacobian.
170170
"""
171171
if inverse:
172172
return self.inverse(data, stage=stage, **kwargs)

bayesflow/adapters/transforms/concatenate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def log_det_jac(
131131
if inverse:
132132
if log_det_jac.get(self.into) is not None:
133133
raise ValueError(
134-
"Cannot obtain an inverse jacobian determinant of concatenation. "
134+
"Cannot obtain an inverse Jacobian of concatenation. "
135135
"Transform your variables before you concatenate."
136136
)
137137

bayesflow/adapters/transforms/numpy_transform.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,4 +74,4 @@ def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
7474
return self._inverse(data)
7575

7676
def log_det_jac(self, data, inverse=False, **kwargs):
77-
raise NotImplementedError("Log determinand jacobian of the numpy transforms are not implemented yet")
77+
raise NotImplementedError("log determinant of the Jacobian of the numpy transforms are not implemented yet")

tests/test_adapters/test_adapters.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,3 +263,29 @@ def test_log_det_jac_inverse(adapter_log_det_jac_inverse, random_data):
263263

264264
for key in forward_log_det_jac.keys():
265265
assert np.allclose(forward_log_det_jac[key], -inverse_log_det_jac[key])
266+
267+
268+
def test_log_det_jac_exceptions(random_data):
269+
# Test cannot compute inverse log_det_jac
270+
# e.g., when we apply a concat and then a transform that
271+
# we cannot "unconcatenate" the log_det_jac
272+
# (because the log_det_jac are summed, not concatenated)
273+
adapter = bf.Adapter().concatenate(["p1", "p2", "p3"], into="p").sqrt("p")
274+
transformed_data, log_det_jac = adapter(random_data, log_det_jac=True)
275+
276+
# test that inverse raises error
277+
with pytest.raises(ValueError):
278+
adapter(transformed_data, inverse=True, log_det_jac=True)
279+
280+
# test resolvable order: first transform, then concatenate
281+
adapter = bf.Adapter().sqrt(["p1", "p2", "p3"]).concatenate(["p1", "p2", "p3"], into="p")
282+
283+
transformed_data, forward_log_det_jac = adapter(random_data, log_det_jac=True)
284+
data, inverse_log_det_jac = adapter(transformed_data, inverse=True, log_det_jac=True)
285+
inverse_log_det_jac = sum(inverse_log_det_jac.values())
286+
287+
# forward is the same regardless
288+
assert np.allclose(forward_log_det_jac["p"], log_det_jac["p"])
289+
290+
# inverse works when concatenation is used after transforms
291+
assert np.allclose(forward_log_det_jac["p"], -inverse_log_det_jac)

0 commit comments

Comments
 (0)