Skip to content

Commit b9b00a4

Browse files
committed
jacobian -> log_det_jac
1 parent 399c6dc commit b9b00a4

File tree

7 files changed

+31
-30
lines changed

7 files changed

+31
-30
lines changed

bayesflow/adapters/adapter.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def get_config(self) -> dict:
8080
return serialize(config)
8181

8282
def forward(
83-
self, data: dict[str, any], *, stage: str = "inference", jacobian: bool = False, **kwargs
83+
self, data: dict[str, any], *, stage: str = "inference", log_det_jac: bool = False, **kwargs
8484
) -> dict[str, np.ndarray] | tuple[dict[str, np.ndarray], dict[str, np.ndarray]]:
8585
"""Apply the transforms in the forward direction.
8686
@@ -90,18 +90,18 @@ def forward(
9090
The data to be transformed.
9191
stage : str, one of ["training", "validation", "inference"]
9292
The stage the function is called in.
93-
jacobian: bool, optional
93+
log_det_jac: bool, optional
9494
Whether to return the log determinant jacobians of the transforms.
9595
**kwargs : dict
9696
Additional keyword arguments passed to each transform.
9797
9898
Returns
9999
-------
100100
dict | tuple[dict, dict]
101-
The transformed data or tuple of transformed data and jacobians.
101+
The transformed data or tuple of transformed data and log determinant jacobians.
102102
"""
103103
data = data.copy()
104-
if not jacobian:
104+
if not log_det_jac:
105105
for transform in self.transforms:
106106
data = transform(data, stage=stage, **kwargs)
107107
return data
@@ -114,7 +114,7 @@ def forward(
114114
return data, log_det_jac
115115

116116
def inverse(
117-
self, data: dict[str, np.ndarray], *, stage: str = "inference", jacobian: bool = False, **kwargs
117+
self, data: dict[str, np.ndarray], *, stage: str = "inference", log_det_jac: bool = False, **kwargs
118118
) -> dict[str, np.ndarray] | tuple[dict[str, np.ndarray], dict[str, np.ndarray]]:
119119
"""Apply the transforms in the inverse direction.
120120
@@ -124,18 +124,18 @@ def inverse(
124124
The data to be transformed.
125125
stage : str, one of ["training", "validation", "inference"]
126126
The stage the function is called in.
127-
jacobian: bool, optional
127+
log_det_jac: bool, optional
128128
Whether to return the log determinant jacobians of the transforms.
129129
**kwargs : dict
130130
Additional keyword arguments passed to each transform.
131131
132132
Returns
133133
-------
134134
dict | tuple[dict, dict]
135-
The transformed data or tuple of transformed data and jacobians.
135+
The transformed data or tuple of transformed data and log determinant jacobians.
136136
"""
137137
data = data.copy()
138-
if not jacobian:
138+
if not log_det_jac:
139139
for transform in reversed(self.transforms):
140140
data = transform(data, stage=stage, inverse=True, **kwargs)
141141
return data
@@ -166,7 +166,7 @@ def __call__(
166166
Returns
167167
-------
168168
dict | tuple[dict, dict]
169-
The transformed data or tuple of transformed data and jacobians.
169+
The transformed data or tuple of transformed data and log determinant jacobians.
170170
"""
171171
if inverse:
172172
return self.inverse(data, stage=stage, **kwargs)

bayesflow/adapters/transforms/concatenate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def log_det_jac(
131131
if inverse:
132132
if log_det_jac.get(self.into) is not None:
133133
raise ValueError(
134-
"Cannot obtain an inverse jacobian of concatenation. "
134+
"Cannot obtain an inverse jacobian determinant of concatenation. "
135135
"Transform your variables before you concatenate."
136136
)
137137

bayesflow/adapters/transforms/numpy_transform.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,4 +74,4 @@ def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
7474
return self._inverse(data)
7575

7676
def log_det_jac(self, data, inverse=False, **kwargs):
77-
raise NotImplementedError("Jacobian of the numpy transforms are not implemented yet")
77+
raise NotImplementedError("Log determinand jacobian of the numpy transforms are not implemented yet")

bayesflow/adapters/transforms/sqrt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def get_config(self) -> dict:
2424
return {}
2525

2626
def log_det_jac(self, data: np.ndarray, inverse: bool = False, **kwargs) -> np.ndarray:
27-
ldj = -0.5 * np.log(data) + 0.5
27+
ldj = -0.5 * np.log(data) - np.log(2)
2828
if inverse:
2929
ldj = -ldj
3030
return np.sum(ldj, axis=tuple(range(1, ldj.ndim)))

bayesflow/approximators/continuous_approximator.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -417,14 +417,15 @@ def log_prob(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray | dic
417417
np.ndarray
418418
Log-probabilities of the distribution `p(inference_variables | inference_conditions, h(summary_conditions))`
419419
"""
420-
data, jacobian = self.adapter(data, strict=False, stage="inference", jacobian=True, **kwargs)
420+
data, log_det_jac = self.adapter(data, strict=False, stage="inference", log_det_jac=True, **kwargs)
421421
data = keras.tree.map_structure(keras.ops.convert_to_tensor, data)
422422
log_prob = self._log_prob(**data, **kwargs)
423423
log_prob = keras.tree.map_structure(keras.ops.convert_to_numpy, log_prob)
424424

425-
jacobian = jacobian.get("inference_variables")
426-
if jacobian is not None:
427-
log_prob = log_prob + jacobian
425+
# change of variables formula
426+
log_det_jac = log_det_jac.get("inference_variables")
427+
if log_det_jac is not None:
428+
log_prob = log_prob + log_det_jac
428429

429430
return log_prob
430431

tests/test_adapters/conftest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def random_data():
6464

6565

6666
@pytest.fixture()
67-
def adapter_jacobian():
67+
def adapter_log_det_jac():
6868
from bayesflow.adapters import Adapter
6969

7070
adapter = (
@@ -83,7 +83,7 @@ def adapter_jacobian():
8383

8484

8585
@pytest.fixture()
86-
def adapter_jacobian_inverse():
86+
def adapter_log_det_jac_inverse():
8787
from bayesflow.adapters import Adapter
8888

8989
adapter = (

tests/test_adapters/test_adapters.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -232,34 +232,34 @@ def test_to_dict_transform():
232232
assert processed["category"].shape[-1] == 5
233233

234234

235-
def test_jacobian(adapter_jacobian, random_data):
236-
d, jacobian = adapter_jacobian(random_data, jacobian=True)
235+
def test_log_det_jac(adapter_log_det_jac, random_data):
236+
d, log_det_jac = adapter_log_det_jac(random_data, log_det_jac=True)
237237

238-
assert np.allclose(jacobian["x1"], np.log(2))
238+
assert np.allclose(log_det_jac["x1"], np.log(2))
239239

240240
p1 = -np.log1p(random_data["p1"])
241-
p2 = -0.5 * np.log(random_data["p2"]) + 0.5
241+
p2 = -0.5 * np.log(random_data["p2"]) - np.log(2)
242242
p3 = random_data["p3"] - np.log(np.exp(random_data["p3"]) - 1)
243243
p = np.sum(p1, axis=-1) + np.sum(p2, axis=-1) + np.sum(p3, axis=-1)
244244

245-
assert np.allclose(jacobian["p"], p)
245+
assert np.allclose(log_det_jac["p"], p)
246246

247247
n1 = -(random_data["n1"] - 1)
248248
n1 = n1 - np.log(np.exp(n1) - 1)
249249
n1 = np.sum(n1, axis=-1)
250250

251-
assert np.allclose(jacobian["n1"], n1)
251+
assert np.allclose(log_det_jac["n1"], n1)
252252

253253
u1 = random_data["u1"]
254254
u1 = (u1 + 1) / 3
255255
u1 = -np.log(u1) - np.log1p(-u1) - np.log(3)
256256

257-
assert np.allclose(jacobian["u"], u1[:, 0])
257+
assert np.allclose(log_det_jac["u"], u1[:, 0])
258258

259259

260-
def test_jacobian_inverse(adapter_jacobian_inverse, random_data):
261-
d, forward_jacobian = adapter_jacobian_inverse(random_data, jacobian=True)
262-
d, inverse_jacobian = adapter_jacobian_inverse(d, inverse=True, jacobian=True)
260+
def test_log_det_jac_inverse(adapter_log_det_jac_inverse, random_data):
261+
d, forward_log_det_jac = adapter_log_det_jac_inverse(random_data, log_det_jac=True)
262+
d, inverse_log_det_jac = adapter_log_det_jac_inverse(d, inverse=True, log_det_jac=True)
263263

264-
for key in forward_jacobian.keys():
265-
assert np.allclose(forward_jacobian[key], -inverse_jacobian[key])
264+
for key in forward_log_det_jac.keys():
265+
assert np.allclose(forward_log_det_jac[key], -inverse_log_det_jac[key])

0 commit comments

Comments
 (0)