diff --git a/bayesflow/adapters/adapter.py b/bayesflow/adapters/adapter.py index 4db738eef..ab6800d8a 100644 --- a/bayesflow/adapters/adapter.py +++ b/bayesflow/adapters/adapter.py @@ -79,7 +79,9 @@ def get_config(self) -> dict: return serialize(config) - def forward(self, data: dict[str, any], *, stage: str = "inference", **kwargs) -> dict[str, np.ndarray]: + def forward( + self, data: dict[str, any], *, stage: str = "inference", log_det_jac: bool = False, **kwargs + ) -> dict[str, np.ndarray] | tuple[dict[str, np.ndarray], dict[str, np.ndarray]]: """Apply the transforms in the forward direction. Parameters @@ -88,22 +90,33 @@ def forward(self, data: dict[str, any], *, stage: str = "inference", **kwargs) - The data to be transformed. stage : str, one of ["training", "validation", "inference"] The stage the function is called in. + log_det_jac: bool, optional + Whether to return the log determinant of the Jacobian of the transforms. **kwargs : dict Additional keyword arguments passed to each transform. Returns ------- - dict - The transformed data. + dict | tuple[dict, dict] + The transformed data or tuple of transformed data and log determinant of the Jacobian. """ data = data.copy() + if not log_det_jac: + for transform in self.transforms: + data = transform(data, stage=stage, **kwargs) + return data + log_det_jac = {} for transform in self.transforms: - data = transform(data, stage=stage, **kwargs) + transformed_data = transform(data, stage=stage, **kwargs) + log_det_jac = transform.log_det_jac(data, log_det_jac, **kwargs) + data = transformed_data - return data + return data, log_det_jac - def inverse(self, data: dict[str, np.ndarray], *, stage: str = "inference", **kwargs) -> dict[str, any]: + def inverse( + self, data: dict[str, np.ndarray], *, stage: str = "inference", log_det_jac: bool = False, **kwargs + ) -> dict[str, np.ndarray] | tuple[dict[str, np.ndarray], dict[str, np.ndarray]]: """Apply the transforms in the inverse direction. Parameters @@ -112,24 +125,32 @@ def inverse(self, data: dict[str, np.ndarray], *, stage: str = "inference", **kw The data to be transformed. stage : str, one of ["training", "validation", "inference"] The stage the function is called in. + log_det_jac: bool, optional + Whether to return the log determinant of the Jacobian of the transforms. **kwargs : dict Additional keyword arguments passed to each transform. Returns ------- - dict - The transformed data. + dict | tuple[dict, dict] + The transformed data or tuple of transformed data and log determinant of the Jacobian. """ data = data.copy() + if not log_det_jac: + for transform in reversed(self.transforms): + data = transform(data, stage=stage, inverse=True, **kwargs) + return data + log_det_jac = {} for transform in reversed(self.transforms): data = transform(data, stage=stage, inverse=True, **kwargs) + log_det_jac = transform.log_det_jac(data, log_det_jac, inverse=True, **kwargs) - return data + return data, log_det_jac def __call__( self, data: Mapping[str, any], *, inverse: bool = False, stage="inference", **kwargs - ) -> dict[str, np.ndarray]: + ) -> dict[str, np.ndarray] | tuple[dict[str, np.ndarray], dict[str, np.ndarray]]: """Apply the transforms in the given direction. Parameters @@ -145,8 +166,8 @@ def __call__( Returns ------- - dict - The transformed data. + dict | tuple[dict, dict] + The transformed data or tuple of transformed data and log determinant of the Jacobian. 
""" if inverse: return self.inverse(data, stage=stage, **kwargs) diff --git a/bayesflow/adapters/transforms/concatenate.py b/bayesflow/adapters/transforms/concatenate.py index deb54fc3f..91ea9178b 100644 --- a/bayesflow/adapters/transforms/concatenate.py +++ b/bayesflow/adapters/transforms/concatenate.py @@ -115,3 +115,37 @@ def extra_repr(self) -> str: result += f", axis={self.axis}" return result + + def log_det_jac( + self, + data: dict[str, np.ndarray], + log_det_jac: dict[str, np.ndarray], + *, + strict: bool = False, + inverse: bool = False, + **kwargs, + ) -> dict[str, np.ndarray]: + # copy to avoid side effects + log_det_jac = log_det_jac.copy() + + if inverse: + if log_det_jac.get(self.into) is not None: + raise ValueError( + "Cannot obtain an inverse Jacobian of concatenation. " + "Transform your variables before you concatenate." + ) + + return log_det_jac + + required_keys = set(self.keys) + available_keys = set(log_det_jac.keys()) + common_keys = available_keys & required_keys + + if len(common_keys) == 0: + return log_det_jac + + parts = [log_det_jac.pop(key) for key in common_keys] + + log_det_jac[self.into] = sum(parts) + + return log_det_jac diff --git a/bayesflow/adapters/transforms/constrain.py b/bayesflow/adapters/transforms/constrain.py index 5f93135a1..a4ca0be25 100644 --- a/bayesflow/adapters/transforms/constrain.py +++ b/bayesflow/adapters/transforms/constrain.py @@ -87,6 +87,11 @@ def constrain(x): def unconstrain(x): return inverse_sigmoid((x - lower) / (upper - lower)) + + def ldj(x): + x = (x - lower) / (upper - lower) + return -np.log(x) - np.log1p(-x) - np.log(upper - lower) + case str() as name: raise ValueError(f"Unsupported method name for double bounded constraint: '{name}'.") case other: @@ -101,6 +106,11 @@ def constrain(x): def unconstrain(x): return inverse_softplus(x - lower) + + def ldj(x): + x = x - lower + return x - np.log(np.exp(x) - 1) + case "exp" | "log": def constrain(x): @@ -108,6 +118,10 @@ def constrain(x): def unconstrain(x): return np.log(x - lower) + + def ldj(x): + return -np.log(x - lower) + case str() as name: raise ValueError(f"Unsupported method name for single bounded constraint: '{name}'.") case other: @@ -122,6 +136,11 @@ def constrain(x): def unconstrain(x): return -inverse_softplus(-(x - upper)) + + def ldj(x): + x = -(x - upper) + return x - np.log(np.exp(x) - 1) + case "exp" | "log": def constrain(x): @@ -129,6 +148,9 @@ def constrain(x): def unconstrain(x): return -np.log(-x + upper) + + def ldj(x): + return -np.log(-x + upper) case str() as name: raise ValueError(f"Unsupported method name for single bounded constraint: '{name}'.") case other: @@ -142,6 +164,7 @@ def unconstrain(x): self.constrain = constrain self.unconstrain = unconstrain + self.ldj = ldj # do this last to avoid serialization issues match inclusive: @@ -178,3 +201,9 @@ def forward(self, data: np.ndarray, **kwargs) -> np.ndarray: def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray: # inverse means network space -> data space, so constrain the data return self.constrain(data) + + def log_det_jac(self, data: np.ndarray, inverse: bool = False, **kwargs) -> np.ndarray: + ldj = self.ldj(data) + if inverse: + ldj = -ldj + return np.sum(ldj, axis=tuple(range(1, ldj.ndim))) diff --git a/bayesflow/adapters/transforms/drop.py b/bayesflow/adapters/transforms/drop.py index 51615d632..91dcd6a28 100644 --- a/bayesflow/adapters/transforms/drop.py +++ b/bayesflow/adapters/transforms/drop.py @@ -46,3 +46,6 @@ def inverse(self, data: dict[str, any], **kwargs) -> 
dict[str, any]: def extra_repr(self) -> str: return "[" + ", ".join(map(repr, self.keys)) + "]" + + def log_det_jac(self, data: dict[str, any], log_det_jac: dict[str, any], inverse: bool = False, **kwargs): + return self.inverse(data=log_det_jac) if inverse else self.forward(data=log_det_jac) diff --git a/bayesflow/adapters/transforms/elementwise_transform.py b/bayesflow/adapters/transforms/elementwise_transform.py index 3bde5a1da..7d603d517 100644 --- a/bayesflow/adapters/transforms/elementwise_transform.py +++ b/bayesflow/adapters/transforms/elementwise_transform.py @@ -25,3 +25,6 @@ def forward(self, data: np.ndarray, **kwargs) -> np.ndarray: def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray: raise NotImplementedError + + def log_det_jac(self, data: np.ndarray, inverse: bool = False, **kwargs) -> np.ndarray | None: + return None diff --git a/bayesflow/adapters/transforms/filter_transform.py b/bayesflow/adapters/transforms/filter_transform.py index e1920e73c..7eccf370b 100644 --- a/bayesflow/adapters/transforms/filter_transform.py +++ b/bayesflow/adapters/transforms/filter_transform.py @@ -150,9 +150,35 @@ def _should_transform(self, key: str, value: np.ndarray, inverse: bool = False) return predicate(key, value, inverse=inverse) def _apply_transform(self, key: str, value: np.ndarray, inverse: bool = False, **kwargs) -> np.ndarray: + transform = self._get_transform(key) + + return transform(value, inverse=inverse, **kwargs) + + def _get_transform(self, key: str) -> ElementwiseTransform: if key not in self.transform_map: self.transform_map[key] = self.transform_constructor(**self.kwargs) - transform = self.transform_map[key] + return self.transform_map[key] - return transform(value, inverse=inverse, **kwargs) + def log_det_jac( + self, data: dict[str, np.ndarray], log_det_jac: dict[str, np.ndarray], *, strict: bool = True, **kwargs + ): + data = data.copy() + + if strict and self.include is not None: + missing_keys = set(self.include) - set(data.keys()) + if missing_keys: + raise KeyError(f"Missing keys from include list: {missing_keys!r}") + + for key, value in data.items(): + if self._should_transform(key, value, inverse=False): + transform = self._get_transform(key) + ldj = transform.log_det_jac(value, **kwargs) + if ldj is None: + continue + elif key in log_det_jac: + log_det_jac[key] += ldj + else: + log_det_jac[key] = ldj + + return log_det_jac diff --git a/bayesflow/adapters/transforms/keep.py b/bayesflow/adapters/transforms/keep.py index 62373071f..56f395166 100644 --- a/bayesflow/adapters/transforms/keep.py +++ b/bayesflow/adapters/transforms/keep.py @@ -57,3 +57,6 @@ def inverse(self, data: dict[str, any], **kwargs) -> dict[str, any]: def extra_repr(self) -> str: return "[" + ", ".join(map(repr, self.keys)) + "]" + + def log_det_jac(self, data: dict[str, any], log_det_jac: dict[str, any], inverse: bool = False, **kwargs): + return self.inverse(data=log_det_jac) if inverse else self.forward(data=log_det_jac) diff --git a/bayesflow/adapters/transforms/log.py b/bayesflow/adapters/transforms/log.py index 3184ab979..d5f559b4f 100644 --- a/bayesflow/adapters/transforms/log.py +++ b/bayesflow/adapters/transforms/log.py @@ -37,3 +37,12 @@ def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray: def get_config(self) -> dict: return serialize({"p1": self.p1}) + + def log_det_jac(self, data: np.ndarray, inverse: bool = False, **kwargs) -> np.ndarray: + if self.p1: + ldj = -np.log1p(data) + else: + ldj = -np.log(data) + if inverse: + ldj = -ldj + return np.sum(ldj, 
axis=tuple(range(1, ldj.ndim))) diff --git a/bayesflow/adapters/transforms/map_transform.py b/bayesflow/adapters/transforms/map_transform.py index 7820ce611..5da8292af 100644 --- a/bayesflow/adapters/transforms/map_transform.py +++ b/bayesflow/adapters/transforms/map_transform.py @@ -41,12 +41,8 @@ def get_config(self) -> dict: def forward(self, data: dict[str, np.ndarray], *, strict: bool = True, **kwargs) -> dict[str, np.ndarray]: data = data.copy() - required_keys = set(self.transform_map.keys()) - available_keys = set(data.keys()) - missing_keys = required_keys - available_keys - - if strict and missing_keys: - raise KeyError(f"Missing keys: {missing_keys!r}") + if strict: + self._check_keys(data) for key, transform in self.transform_map.items(): if key in data: @@ -57,15 +53,40 @@ def forward(self, data: dict[str, np.ndarray], *, strict: bool = True, **kwargs) def inverse(self, data: dict[str, np.ndarray], *, strict: bool = False, **kwargs) -> dict[str, np.ndarray]: data = data.copy() - required_keys = set(self.transform_map.keys()) - available_keys = set(data.keys()) - missing_keys = required_keys - available_keys - - if strict and missing_keys: - raise KeyError(f"Missing keys: {missing_keys!r}") + if strict: + self._check_keys(data) for key, transform in self.transform_map.items(): if key in data: data[key] = transform.inverse(data[key], **kwargs) return data + + def log_det_jac( + self, data: dict[str, np.ndarray], log_det_jac: dict[str, np.ndarray], *, strict: bool = True, **kwargs + ) -> dict[str, np.ndarray]: + data = data.copy() + + if strict: + self._check_keys(data) + + for key, transform in self.transform_map.items(): + if key in data: + ldj = transform.log_det_jac(data[key], **kwargs) + + if ldj is None: + continue + elif key in log_det_jac: + log_det_jac[key] += ldj + else: + log_det_jac[key] = ldj + + return log_det_jac + + def _check_keys(self, data: dict[str, np.ndarray]): + required_keys = set(self.transform_map.keys()) + available_keys = set(data.keys()) + missing_keys = required_keys - available_keys + + if missing_keys: + raise KeyError(f"Missing keys: {missing_keys!r}") diff --git a/bayesflow/adapters/transforms/numpy_transform.py b/bayesflow/adapters/transforms/numpy_transform.py index aecf03bba..29d25dc67 100644 --- a/bayesflow/adapters/transforms/numpy_transform.py +++ b/bayesflow/adapters/transforms/numpy_transform.py @@ -72,3 +72,6 @@ def forward(self, data: dict[str, any], **kwargs) -> dict[str, any]: def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray: return self._inverse(data) + + def log_det_jac(self, data, inverse=False, **kwargs): + raise NotImplementedError("the log determinant of the Jacobian of numpy transforms is not implemented yet") diff --git a/bayesflow/adapters/transforms/rename.py b/bayesflow/adapters/transforms/rename.py index 49cc52eba..746ef5a80 100644 --- a/bayesflow/adapters/transforms/rename.py +++ b/bayesflow/adapters/transforms/rename.py @@ -58,3 +58,6 @@ def inverse(self, data: dict[str, any], *, strict: bool = False, **kwargs) -> di def extra_repr(self) -> str: return f"{self.from_key!r} -> {self.to_key!r}" + + def log_det_jac(self, data: dict[str, any], log_det_jac: dict[str, any], inverse: bool = False, **kwargs): + return self.inverse(data=log_det_jac) if inverse else self.forward(data=log_det_jac, strict=False) diff --git a/bayesflow/adapters/transforms/scale.py b/bayesflow/adapters/transforms/scale.py index 8d9dce1be..96b2ff927 100644 --- a/bayesflow/adapters/transforms/scale.py +++ 
b/bayesflow/adapters/transforms/scale.py @@ -18,3 +18,10 @@ def forward(self, data: np.ndarray, **kwargs) -> np.ndarray: def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray: return data / self.scale + + def log_det_jac(self, data: np.ndarray, inverse: bool = False, **kwargs) -> np.ndarray: + ldj = np.log(np.abs(self.scale)) + ldj = np.full(data.shape, ldj) + if inverse: + ldj = -ldj + return np.sum(ldj, axis=tuple(range(1, ldj.ndim))) diff --git a/bayesflow/adapters/transforms/sqrt.py b/bayesflow/adapters/transforms/sqrt.py index 617f892bc..4ef1370dc 100644 --- a/bayesflow/adapters/transforms/sqrt.py +++ b/bayesflow/adapters/transforms/sqrt.py @@ -22,3 +22,9 @@ def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray: def get_config(self) -> dict: return {} + + def log_det_jac(self, data: np.ndarray, inverse: bool = False, **kwargs) -> np.ndarray: + ldj = -0.5 * np.log(data) - np.log(2) + if inverse: + ldj = -ldj + return np.sum(ldj, axis=tuple(range(1, ldj.ndim))) diff --git a/bayesflow/adapters/transforms/standardize.py b/bayesflow/adapters/transforms/standardize.py index 52899917c..9699819b9 100644 --- a/bayesflow/adapters/transforms/standardize.py +++ b/bayesflow/adapters/transforms/standardize.py @@ -120,3 +120,10 @@ def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray: std = np.broadcast_to(self.std, data.shape) return data * std + mean + + def log_det_jac(self, data, inverse: bool = False, **kwargs) -> np.ndarray: + std = np.broadcast_to(self.std, data.shape) + ldj = -np.log(np.abs(std)) + if inverse: + ldj = -ldj + return np.sum(ldj, axis=tuple(range(1, ldj.ndim))) diff --git a/bayesflow/adapters/transforms/transform.py b/bayesflow/adapters/transforms/transform.py index 4642c1165..ed3058e15 100644 --- a/bayesflow/adapters/transforms/transform.py +++ b/bayesflow/adapters/transforms/transform.py @@ -35,3 +35,8 @@ def inverse(self, data: dict[str, np.ndarray], **kwargs) -> dict[str, np.ndarray def extra_repr(self) -> str: return "" + + def log_det_jac( + self, data: dict[str, np.ndarray], log_det_jac: dict[str, np.ndarray], inverse: bool = False, **kwargs + ) -> dict[str, np.ndarray]: + return log_det_jac diff --git a/bayesflow/approximators/continuous_approximator.py b/bayesflow/approximators/continuous_approximator.py index dcb661ca0..bf4e263a0 100644 --- a/bayesflow/approximators/continuous_approximator.py +++ b/bayesflow/approximators/continuous_approximator.py @@ -417,11 +417,16 @@ def log_prob(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray | dic np.ndarray Log-probabilities of the distribution `p(inference_variables | inference_conditions, h(summary_conditions))` """ - data = self.adapter(data, strict=False, stage="inference", **kwargs) + data, log_det_jac = self.adapter(data, strict=False, stage="inference", log_det_jac=True, **kwargs) data = keras.tree.map_structure(keras.ops.convert_to_tensor, data) log_prob = self._log_prob(**data, **kwargs) log_prob = keras.tree.map_structure(keras.ops.convert_to_numpy, log_prob) + # change of variables formula + log_det_jac = log_det_jac.get("inference_variables") + if log_det_jac is not None: + log_prob = log_prob + log_det_jac + return log_prob def _log_prob( diff --git a/tests/test_adapters/conftest.py b/tests/test_adapters/conftest.py index 873279f09..d69cd4be4 100644 --- a/tests/test_adapters/conftest.py +++ b/tests/test_adapters/conftest.py @@ -49,6 +49,8 @@ def random_data(): "z1": np.random.standard_normal(size=(32, 2)), "p1": np.random.lognormal(size=(32, 2)), "p2": np.random.lognormal(size=(32, 
2)), + "p3": np.random.lognormal(size=(32, 2)), + "n1": 1 - np.random.lognormal(size=(32, 2)), "s1": np.random.standard_normal(size=(32, 3, 2)), "s2": np.random.standard_normal(size=(32, 3, 2)), "t1": np.zeros((3, 2)), @@ -56,5 +58,43 @@ def random_data(): "d1": np.random.standard_normal(size=(32, 2)), "d2": np.random.standard_normal(size=(32, 2)), "o1": np.random.randint(0, 9, size=(32, 2)), + "u1": np.random.uniform(low=-1, high=2, size=(32, 1)), "key_to_split": np.random.standard_normal(size=(32, 10)), } + + +@pytest.fixture() +def adapter_log_det_jac(): + from bayesflow.adapters import Adapter + + adapter = ( + Adapter() + .scale("x1", by=2) + .log("p1", p1=True) + .sqrt("p2") + .constrain("p3", lower=0) + .constrain("n1", upper=1) + .constrain("u1", lower=-1, upper=2) + .concatenate(["p1", "p2", "p3"], into="p") + .rename("u1", "u") + ) + + return adapter + + +@pytest.fixture() +def adapter_log_det_jac_inverse(): + from bayesflow.adapters import Adapter + + adapter = ( + Adapter() + .standardize("x1", mean=1, std=2) + .log("p1") + .sqrt("p2") + .constrain("p3", lower=0, method="log") + .constrain("n1", upper=1, method="log") + .constrain("u1", lower=-1, upper=2) + .scale(["p1", "p2", "p3"], by=3.5) + ) + + return adapter diff --git a/tests/test_adapters/test_adapters.py b/tests/test_adapters/test_adapters.py index d6215170e..1784befb7 100644 --- a/tests/test_adapters/test_adapters.py +++ b/tests/test_adapters/test_adapters.py @@ -13,7 +13,7 @@ def test_cycle_consistency(adapter, random_data): deprocessed = adapter(processed, inverse=True) for key, value in random_data.items(): - if key in ["d1", "d2"]: + if key in ["d1", "d2", "p3", "n1", "u1"]: # dropped continue assert key in deprocessed @@ -230,3 +230,62 @@ def test_to_dict_transform(): # category should have 5 one-hot categories, even though it was only passed 4 assert processed["category"].shape[-1] == 5 + + +def test_log_det_jac(adapter_log_det_jac, random_data): + d, log_det_jac = adapter_log_det_jac(random_data, log_det_jac=True) + + assert np.allclose(log_det_jac["x1"], np.log(2)) + + p1 = -np.log1p(random_data["p1"]) + p2 = -0.5 * np.log(random_data["p2"]) - np.log(2) + p3 = random_data["p3"] - np.log(np.exp(random_data["p3"]) - 1) + p = np.sum(p1, axis=-1) + np.sum(p2, axis=-1) + np.sum(p3, axis=-1) + + assert np.allclose(log_det_jac["p"], p) + + n1 = -(random_data["n1"] - 1) + n1 = n1 - np.log(np.exp(n1) - 1) + n1 = np.sum(n1, axis=-1) + + assert np.allclose(log_det_jac["n1"], n1) + + u1 = random_data["u1"] + u1 = (u1 + 1) / 3 + u1 = -np.log(u1) - np.log1p(-u1) - np.log(3) + + assert np.allclose(log_det_jac["u"], u1[:, 0]) + + +def test_log_det_jac_inverse(adapter_log_det_jac_inverse, random_data): + d, forward_log_det_jac = adapter_log_det_jac_inverse(random_data, log_det_jac=True) + d, inverse_log_det_jac = adapter_log_det_jac_inverse(d, inverse=True, log_det_jac=True) + + for key in forward_log_det_jac.keys(): + assert np.allclose(forward_log_det_jac[key], -inverse_log_det_jac[key]) + + +def test_log_det_jac_exceptions(random_data): + # Test cannot compute inverse log_det_jac + # e.g., when we apply a concat and then a transform that + # we cannot "unconcatenate" the log_det_jac + # (because the log_det_jac are summed, not concatenated) + adapter = bf.Adapter().concatenate(["p1", "p2", "p3"], into="p").sqrt("p") + transformed_data, log_det_jac = adapter(random_data, log_det_jac=True) + + # test that inverse raises error + with pytest.raises(ValueError): + adapter(transformed_data, inverse=True, log_det_jac=True) + + # 
test resolvable order: first transform, then concatenate + adapter = bf.Adapter().sqrt(["p1", "p2", "p3"]).concatenate(["p1", "p2", "p3"], into="p") + + transformed_data, forward_log_det_jac = adapter(random_data, log_det_jac=True) + data, inverse_log_det_jac = adapter(transformed_data, inverse=True, log_det_jac=True) + inverse_log_det_jac = sum(inverse_log_det_jac.values()) + + # forward is the same regardless + assert np.allclose(forward_log_det_jac["p"], log_det_jac["p"]) + + # inverse works when concatenation is used after transforms + assert np.allclose(forward_log_det_jac["p"], -inverse_log_det_jac)
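For orientation, a minimal usage sketch of the new log_det_jac option follows. The data and variable names are illustrative only; nothing beyond the Adapter API added in this diff is assumed, and the returned values follow the convention used above: forward (data space -> network space) log-determinants, summed over all non-batch axes.

import numpy as np
from bayesflow.adapters import Adapter

# illustrative data: a positive (log-normal) variable mapped to an unconstrained space
data = {"theta": np.random.lognormal(size=(32, 2))}

adapter = Adapter().log("theta")

# default behaviour is unchanged: only the transformed data is returned
transformed = adapter(data)

# with log_det_jac=True, a second dict holds the forward log|det J| per variable,
# summed over all non-batch axes, i.e. one value per batch element
transformed, log_det_jac = adapter(data, log_det_jac=True)

# for y = log(x) the forward log-determinant is -sum(log(x)) over the last axis
expected = -np.sum(np.log(data["theta"]), axis=-1)
assert np.allclose(log_det_jac["theta"], expected)

As the new test_log_det_jac_exceptions covers, the inverse log-determinant is only recoverable when concatenation is applied after the per-variable transforms, because concatenated contributions are summed and cannot be split back into individual keys.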