From 1f576ef691e2339484463613a2ffa601ab615c48 Mon Sep 17 00:00:00 2001 From: Kucharssim Date: Thu, 10 Apr 2025 11:02:21 +0200 Subject: [PATCH 01/20] minimal working case (.scale) --- bayesflow/adapters/adapter.py | 16 ++++--- .../transforms/elementwise_transform.py | 3 ++ .../adapters/transforms/map_transform.py | 44 ++++++++++++++----- bayesflow/adapters/transforms/scale.py | 4 ++ 4 files changed, 50 insertions(+), 17 deletions(-) diff --git a/bayesflow/adapters/adapter.py b/bayesflow/adapters/adapter.py index 4db738eef..722b96db6 100644 --- a/bayesflow/adapters/adapter.py +++ b/bayesflow/adapters/adapter.py @@ -79,7 +79,9 @@ def get_config(self) -> dict: return serialize(config) - def forward(self, data: dict[str, any], *, stage: str = "inference", **kwargs) -> dict[str, np.ndarray]: + def forward( + self, data: dict[str, any], *, stage: str = "inference", jacobian: bool = False, **kwargs + ) -> dict[str, np.ndarray] | tuple[dict[str, np.ndarray], dict[str, np.ndarray]]: """Apply the transforms in the forward direction. Parameters @@ -97,13 +99,15 @@ def forward(self, data: dict[str, any], *, stage: str = "inference", **kwargs) - The transformed data. """ data = data.copy() + log_det_jac = {} for transform in self.transforms: data = transform(data, stage=stage, **kwargs) + log_det_jac = transform.log_det_jac(data, log_det_jac, **kwargs) - return data + return data, log_det_jac - def inverse(self, data: dict[str, np.ndarray], *, stage: str = "inference", **kwargs) -> dict[str, any]: + def inverse(self, data: dict[str, np.ndarray], *, stage: str = "inference", jacobian: bool = False, **kwargs) -> dict[str, any]: """Apply the transforms in the inverse direction. Parameters @@ -121,10 +125,12 @@ def inverse(self, data: dict[str, np.ndarray], *, stage: str = "inference", **kw The transformed data. """ data = data.copy() + if jacobian: + data = self._init_jacobian(data) for transform in reversed(self.transforms): - data = transform(data, stage=stage, inverse=True, **kwargs) - + data = transform(data, stage=stage, inverse=True, jacobian=jacobian, **kwargs) + return data def __call__( diff --git a/bayesflow/adapters/transforms/elementwise_transform.py b/bayesflow/adapters/transforms/elementwise_transform.py index 3bde5a1da..9809a72f8 100644 --- a/bayesflow/adapters/transforms/elementwise_transform.py +++ b/bayesflow/adapters/transforms/elementwise_transform.py @@ -25,3 +25,6 @@ def forward(self, data: np.ndarray, **kwargs) -> np.ndarray: def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray: raise NotImplementedError + + def log_det_jac(self, data: np.ndarray, **kwargs) -> np.ndarray | None: + return None diff --git a/bayesflow/adapters/transforms/map_transform.py b/bayesflow/adapters/transforms/map_transform.py index 7820ce611..eabfeb927 100644 --- a/bayesflow/adapters/transforms/map_transform.py +++ b/bayesflow/adapters/transforms/map_transform.py @@ -41,12 +41,8 @@ def get_config(self) -> dict: def forward(self, data: dict[str, np.ndarray], *, strict: bool = True, **kwargs) -> dict[str, np.ndarray]: data = data.copy() - required_keys = set(self.transform_map.keys()) - available_keys = set(data.keys()) - missing_keys = required_keys - available_keys - - if strict and missing_keys: - raise KeyError(f"Missing keys: {missing_keys!r}") + if strict: + self._check_keys(data) for key, transform in self.transform_map.items(): if key in data: @@ -57,15 +53,39 @@ def forward(self, data: dict[str, np.ndarray], *, strict: bool = True, **kwargs) def inverse(self, data: dict[str, np.ndarray], *, strict: bool = False, **kwargs) -> dict[str, np.ndarray]: data = data.copy() - required_keys = set(self.transform_map.keys()) - available_keys = set(data.keys()) - missing_keys = required_keys - available_keys - - if strict and missing_keys: - raise KeyError(f"Missing keys: {missing_keys!r}") + if strict: + self._check_keys(data) for key, transform in self.transform_map.items(): if key in data: data[key] = transform.inverse(data[key], **kwargs) return data + + def log_det_jac( + self, data: dict[str, np.ndarray], log_det_jac: dict[str, np.ndarray], *, strict: bool = True, **kwargs + ) -> dict[str, np.ndarray]: + data = data.copy() + + if strict: + self._check_keys(data) + + for key, transform in self.transform_map.items(): + if key in data: + ldj = transform.log_det_jac(data[key], **kwargs) + if ldj is None: + continue + elif key in log_det_jac: + log_det_jac[key] += ldj + else: + log_det_jac[key] = ldj + + return log_det_jac + + def _check_keys(self, data: dict[str, np.ndarray]): + required_keys = set(self.transform_map.keys()) + available_keys = set(data.keys()) + missing_keys = required_keys - available_keys + + if missing_keys: + raise KeyError(f"Missing keys: {missing_keys!r}") diff --git a/bayesflow/adapters/transforms/scale.py b/bayesflow/adapters/transforms/scale.py index 8d9dce1be..18e54a5ff 100644 --- a/bayesflow/adapters/transforms/scale.py +++ b/bayesflow/adapters/transforms/scale.py @@ -18,3 +18,7 @@ def forward(self, data: np.ndarray, **kwargs) -> np.ndarray: def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray: return data / self.scale + + def log_det_jac(self, data: np.ndarray, **kwargs) -> np.ndarray: + ldj = np.log(np.abs(self.scale)) + return np.repeat(ldj, data.shape[0]) From 82109ad6867c976f0de3a28f3ef05697618f838e Mon Sep 17 00:00:00 2001 From: Kucharssim Date: Thu, 10 Apr 2025 11:04:11 +0200 Subject: [PATCH 02/20] concatenate --- bayesflow/adapters/transforms/concatenate.py | 16 ++++++++++++++++ bayesflow/adapters/transforms/transform.py | 5 +++++ 2 files changed, 21 insertions(+) diff --git a/bayesflow/adapters/transforms/concatenate.py b/bayesflow/adapters/transforms/concatenate.py index deb54fc3f..5249bd0b4 100644 --- a/bayesflow/adapters/transforms/concatenate.py +++ b/bayesflow/adapters/transforms/concatenate.py @@ -115,3 +115,19 @@ def extra_repr(self) -> str: result += f", axis={self.axis}" return result + + def log_det_jac( + self, data: dict[str, np.ndarray], log_det_jac: dict[str, np.ndarray], *, strict: bool = False, **kwargs + ) -> dict[str, np.ndarray]: + # copy to avoid side effects + log_det_jac = log_det_jac.copy() + + required_keys = set(self.keys) + available_keys = set(log_det_jac.keys()) + common_keys = available_keys & required_keys + + parts = [log_det_jac.pop(key) for key in common_keys] + + log_det_jac[self.into] = sum(parts) + + return log_det_jac diff --git a/bayesflow/adapters/transforms/transform.py b/bayesflow/adapters/transforms/transform.py index 4642c1165..fc8d633b3 100644 --- a/bayesflow/adapters/transforms/transform.py +++ b/bayesflow/adapters/transforms/transform.py @@ -35,3 +35,8 @@ def inverse(self, data: dict[str, np.ndarray], **kwargs) -> dict[str, np.ndarray def extra_repr(self) -> str: return "" + + def log_det_jac( + self, data: dict[str, np.ndarray], log_det_jac: dict[str, np.ndarray], **kwargs + ) -> dict[str, np.ndarray]: + return log_det_jac From df1a91a174eba503838063c9fb73f071f94ae337 Mon Sep 17 00:00:00 2001 From: Kucharssim Date: Thu, 10 Apr 2025 11:08:19 +0200 Subject: [PATCH 03/20] keep, drop, rename --- bayesflow/adapters/transforms/drop.py | 3 +++ bayesflow/adapters/transforms/keep.py | 3 +++ bayesflow/adapters/transforms/rename.py | 3 +++ 3 files changed, 9 insertions(+) diff --git a/bayesflow/adapters/transforms/drop.py b/bayesflow/adapters/transforms/drop.py index 51615d632..dfcdeb394 100644 --- a/bayesflow/adapters/transforms/drop.py +++ b/bayesflow/adapters/transforms/drop.py @@ -46,3 +46,6 @@ def inverse(self, data: dict[str, any], **kwargs) -> dict[str, any]: def extra_repr(self) -> str: return "[" + ", ".join(map(repr, self.keys)) + "]" + + def log_det_jac(self, data: dict[str, any], log_det_jac: dict[str, any], **kwargs): + return self.forward(data=log_det_jac) diff --git a/bayesflow/adapters/transforms/keep.py b/bayesflow/adapters/transforms/keep.py index 62373071f..af3b02f10 100644 --- a/bayesflow/adapters/transforms/keep.py +++ b/bayesflow/adapters/transforms/keep.py @@ -57,3 +57,6 @@ def inverse(self, data: dict[str, any], **kwargs) -> dict[str, any]: def extra_repr(self) -> str: return "[" + ", ".join(map(repr, self.keys)) + "]" + + def log_det_jac(self, data: dict[str, any], log_det_jac: dict[str, any], **kwargs): + return self.forward(data=log_det_jac) diff --git a/bayesflow/adapters/transforms/rename.py b/bayesflow/adapters/transforms/rename.py index 49cc52eba..22ac75102 100644 --- a/bayesflow/adapters/transforms/rename.py +++ b/bayesflow/adapters/transforms/rename.py @@ -58,3 +58,6 @@ def inverse(self, data: dict[str, any], *, strict: bool = False, **kwargs) -> di def extra_repr(self) -> str: return f"{self.from_key!r} -> {self.to_key!r}" + + def log_det_jac(self, data: dict[str, any], log_det_jac: dict[str, any], **kwargs): + return self.forward(data=log_det_jac, strict=False) From 324f2f726fc81446f9c6cf4f19d6e28b5d8ae38c Mon Sep 17 00:00:00 2001 From: Kucharssim Date: Thu, 10 Apr 2025 11:33:23 +0200 Subject: [PATCH 04/20] scale, log, sqrt --- bayesflow/adapters/transforms/log.py | 4 ++++ bayesflow/adapters/transforms/scale.py | 3 ++- bayesflow/adapters/transforms/sqrt.py | 4 ++++ 3 files changed, 10 insertions(+), 1 deletion(-) diff --git a/bayesflow/adapters/transforms/log.py b/bayesflow/adapters/transforms/log.py index 3184ab979..aea9f4761 100644 --- a/bayesflow/adapters/transforms/log.py +++ b/bayesflow/adapters/transforms/log.py @@ -37,3 +37,7 @@ def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray: def get_config(self) -> dict: return serialize({"p1": self.p1}) + + def log_det_jac(self, data: np.ndarray, **kwargs) -> np.ndarray: + ldj = -np.log(data) + return np.sum(ldj, axis=tuple(range(1, ldj.ndim))) diff --git a/bayesflow/adapters/transforms/scale.py b/bayesflow/adapters/transforms/scale.py index 18e54a5ff..dba712904 100644 --- a/bayesflow/adapters/transforms/scale.py +++ b/bayesflow/adapters/transforms/scale.py @@ -21,4 +21,5 @@ def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray: def log_det_jac(self, data: np.ndarray, **kwargs) -> np.ndarray: ldj = np.log(np.abs(self.scale)) - return np.repeat(ldj, data.shape[0]) + ldj = np.full(data.shape, ldj) + return np.sum(ldj, axis=tuple(range(1, ldj.ndim))) diff --git a/bayesflow/adapters/transforms/sqrt.py b/bayesflow/adapters/transforms/sqrt.py index 617f892bc..c2f15ad83 100644 --- a/bayesflow/adapters/transforms/sqrt.py +++ b/bayesflow/adapters/transforms/sqrt.py @@ -22,3 +22,7 @@ def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray: def get_config(self) -> dict: return {} + + def log_det_jac(self, data: np.ndarray, **kwargs) -> np.ndarray: + ldj = -0.5 * np.log(data) + 0.5 + return np.sum(ldj, axis=tuple(range(1, ldj.ndim))) From e91a93591cf8ee6f706f6654379e76799deb182b Mon Sep 17 00:00:00 2001 From: Kucharssim Date: Thu, 10 Apr 2025 11:58:38 +0200 Subject: [PATCH 05/20] standardize --- .../adapters/transforms/filter_transform.py | 30 +++++++++++++++++-- bayesflow/adapters/transforms/standardize.py | 7 +++++ 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/bayesflow/adapters/transforms/filter_transform.py b/bayesflow/adapters/transforms/filter_transform.py index e1920e73c..4d167fdae 100644 --- a/bayesflow/adapters/transforms/filter_transform.py +++ b/bayesflow/adapters/transforms/filter_transform.py @@ -150,9 +150,35 @@ def _should_transform(self, key: str, value: np.ndarray, inverse: bool = False) return predicate(key, value, inverse=inverse) def _apply_transform(self, key: str, value: np.ndarray, inverse: bool = False, **kwargs) -> np.ndarray: + transform = self._get_transform(key) + + return transform(value, inverse=inverse, **kwargs) + + def _get_transform(self, key: str) -> ElementwiseTransform: if key not in self.transform_map: self.transform_map[key] = self.transform_constructor(**self.kwargs) - transform = self.transform_map[key] + return self.transform_map[key] - return transform(value, inverse=inverse, **kwargs) + def log_det_jac( + self, data: dict[str, np.ndarray], log_det_jac: dict[str, np.ndarray], *, strict: bool = True, **kwargs + ): + data = data.copy() + + if strict and self.include is not None: + missing_keys = set(self.include) - set(data.keys()) + if missing_keys: + raise KeyError(f"Missing keys from include list: {missing_keys!r}") + + for key, value in data.items(): + if self._should_transform(key, value, inverse=False): + transform = self._get_transform(key) + ldj = transform.log_det_jac(value, **kwargs) + if ldj is None: + continue + elif key in log_det_jac: + log_det_jac[key] += ldj + else: + log_det_jac[key] = ldj + + return data diff --git a/bayesflow/adapters/transforms/standardize.py b/bayesflow/adapters/transforms/standardize.py index 52899917c..901d104f0 100644 --- a/bayesflow/adapters/transforms/standardize.py +++ b/bayesflow/adapters/transforms/standardize.py @@ -120,3 +120,10 @@ def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray: std = np.broadcast_to(self.std, data.shape) return data * std + mean + + def log_det_jac(self, data, **kwargs) -> np.ndarray: + if self.std is None: + return None + std = np.broadcast_to(self.std, data.shape) + ldj = np.log(np.abs(std)) + return np.sum(ldj, axis=tuple(range(1, ldj.ndim))) From 9f27782a9013c904112e7c8408166c024cd4700e Mon Sep 17 00:00:00 2001 From: Kucharssim Date: Thu, 10 Apr 2025 13:14:32 +0200 Subject: [PATCH 06/20] constraint transforms --- bayesflow/adapters/adapter.py | 6 ++++- bayesflow/adapters/transforms/constrain.py | 27 ++++++++++++++++++++++ 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/bayesflow/adapters/adapter.py b/bayesflow/adapters/adapter.py index 722b96db6..40133fb24 100644 --- a/bayesflow/adapters/adapter.py +++ b/bayesflow/adapters/adapter.py @@ -99,8 +99,12 @@ def forward( The transformed data. """ data = data.copy() - log_det_jac = {} + if not jacobian: + for transform in self.transforms: + data = transform(data, **kwargs) + return data + log_det_jac = {} for transform in self.transforms: data = transform(data, stage=stage, **kwargs) log_det_jac = transform.log_det_jac(data, log_det_jac, **kwargs) diff --git a/bayesflow/adapters/transforms/constrain.py b/bayesflow/adapters/transforms/constrain.py index 5f93135a1..0598c2cdc 100644 --- a/bayesflow/adapters/transforms/constrain.py +++ b/bayesflow/adapters/transforms/constrain.py @@ -87,6 +87,11 @@ def constrain(x): def unconstrain(x): return inverse_sigmoid((x - lower) / (upper - lower)) + + def ldj(x): + x = (x - lower) / (upper - lower) + return -np.log(x) - np.log1p(-x) - np.log(upper - lower) + case str() as name: raise ValueError(f"Unsupported method name for double bounded constraint: '{name}'.") case other: @@ -101,6 +106,11 @@ def constrain(x): def unconstrain(x): return inverse_softplus(x - lower) + + def ldj(x): + x = x - lower + return x - np.log(np.exp(x) - 1) + case "exp" | "log": def constrain(x): @@ -108,6 +118,10 @@ def constrain(x): def unconstrain(x): return np.log(x - lower) + + def ldj(x): + return -np.log(x - lower) + case str() as name: raise ValueError(f"Unsupported method name for single bounded constraint: '{name}'.") case other: @@ -122,6 +136,11 @@ def constrain(x): def unconstrain(x): return -inverse_softplus(-(x - upper)) + + def ldj(x): + x = -(x - upper) + return x - np.log(np.exp(x) - 1) + case "exp" | "log": def constrain(x): @@ -129,6 +148,9 @@ def constrain(x): def unconstrain(x): return -np.log(-x + upper) + + def ldj(x): + return -np.log(-x + upper) case str() as name: raise ValueError(f"Unsupported method name for single bounded constraint: '{name}'.") case other: @@ -142,6 +164,7 @@ def unconstrain(x): self.constrain = constrain self.unconstrain = unconstrain + self.ldj = ldj # do this last to avoid serialization issues match inclusive: @@ -178,3 +201,7 @@ def forward(self, data: np.ndarray, **kwargs) -> np.ndarray: def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray: # inverse means network space -> data space, so constrain the data return self.constrain(data) + + def log_det_jac(self, data: np.ndarray, **kwargs) -> np.ndarray: + ldj = self.ldj(data) + return np.sum(ldj, axis=tuple(range(1, ldj.ndim))) From d7276aab1dcfd909851effb81e60bc68f14d295a Mon Sep 17 00:00:00 2001 From: Kucharssim Date: Thu, 10 Apr 2025 13:28:33 +0200 Subject: [PATCH 07/20] continuous approximator returns log_prob with volume correction --- bayesflow/approximators/continuous_approximator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bayesflow/approximators/continuous_approximator.py b/bayesflow/approximators/continuous_approximator.py index dcb661ca0..a88cabbe2 100644 --- a/bayesflow/approximators/continuous_approximator.py +++ b/bayesflow/approximators/continuous_approximator.py @@ -417,12 +417,12 @@ def log_prob(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray | dic np.ndarray Log-probabilities of the distribution `p(inference_variables | inference_conditions, h(summary_conditions))` """ - data = self.adapter(data, strict=False, stage="inference", **kwargs) + data, jacobian = self.adapter(data, strict=False, stage="inference", jacobian=True, **kwargs) data = keras.tree.map_structure(keras.ops.convert_to_tensor, data) log_prob = self._log_prob(**data, **kwargs) log_prob = keras.tree.map_structure(keras.ops.convert_to_numpy, log_prob) - return log_prob + return log_prob + jacobian["inference_variables"] def _log_prob( self, From dbf199b2515dc922d0cb9f08fe47562c55fdb498 Mon Sep 17 00:00:00 2001 From: Kucharssim Date: Tue, 15 Apr 2025 09:29:01 +0200 Subject: [PATCH 08/20] loop for inverse jacobian --- bayesflow/adapters/adapter.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/bayesflow/adapters/adapter.py b/bayesflow/adapters/adapter.py index 40133fb24..c60df7805 100644 --- a/bayesflow/adapters/adapter.py +++ b/bayesflow/adapters/adapter.py @@ -101,17 +101,19 @@ def forward( data = data.copy() if not jacobian: for transform in self.transforms: - data = transform(data, **kwargs) + data = transform(data, stage=stage, **kwargs) return data log_det_jac = {} for transform in self.transforms: - data = transform(data, stage=stage, **kwargs) log_det_jac = transform.log_det_jac(data, log_det_jac, **kwargs) + data = transform(data, stage=stage, **kwargs) return data, log_det_jac - def inverse(self, data: dict[str, np.ndarray], *, stage: str = "inference", jacobian: bool = False, **kwargs) -> dict[str, any]: + def inverse( + self, data: dict[str, np.ndarray], *, stage: str = "inference", jacobian: bool = False, **kwargs + ) -> dict[str, np.ndarray] | tuple[dict[str, np.ndarray], dict[str, np.ndarray]]: """Apply the transforms in the inverse direction. Parameters @@ -129,13 +131,17 @@ def inverse(self, data: dict[str, np.ndarray], *, stage: str = "inference", jaco The transformed data. """ data = data.copy() - if jacobian: - data = self._init_jacobian(data) + if not jacobian: + for transform in reversed(self.transforms): + data = transform(data, stage=stage, inverse=True, **kwargs) + return data + log_det_jac = {} for transform in reversed(self.transforms): - data = transform(data, stage=stage, inverse=True, jacobian=jacobian, **kwargs) - - return data + data = transform(data, stage=stage, inverse=True, **kwargs) + log_det_jac = transform.log_det_jac(data, log_det_jac, inverse=True, **kwargs) + + return data, log_det_jac def __call__( self, data: Mapping[str, any], *, inverse: bool = False, stage="inference", **kwargs From f7aa2452cae0a104f65783f005c085d113f3e6e3 Mon Sep 17 00:00:00 2001 From: Kucharssim Date: Wed, 16 Apr 2025 12:41:15 +0200 Subject: [PATCH 09/20] inverse for elementwise --- bayesflow/adapters/transforms/constrain.py | 4 +++- bayesflow/adapters/transforms/elementwise_transform.py | 2 +- bayesflow/adapters/transforms/log.py | 4 +++- bayesflow/adapters/transforms/scale.py | 4 +++- bayesflow/adapters/transforms/sqrt.py | 4 +++- bayesflow/adapters/transforms/standardize.py | 4 +++- 6 files changed, 16 insertions(+), 6 deletions(-) diff --git a/bayesflow/adapters/transforms/constrain.py b/bayesflow/adapters/transforms/constrain.py index 0598c2cdc..a4ca0be25 100644 --- a/bayesflow/adapters/transforms/constrain.py +++ b/bayesflow/adapters/transforms/constrain.py @@ -202,6 +202,8 @@ def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray: # inverse means network space -> data space, so constrain the data return self.constrain(data) - def log_det_jac(self, data: np.ndarray, **kwargs) -> np.ndarray: + def log_det_jac(self, data: np.ndarray, inverse: bool = False, **kwargs) -> np.ndarray: ldj = self.ldj(data) + if inverse: + ldj = -ldj return np.sum(ldj, axis=tuple(range(1, ldj.ndim))) diff --git a/bayesflow/adapters/transforms/elementwise_transform.py b/bayesflow/adapters/transforms/elementwise_transform.py index 9809a72f8..7d603d517 100644 --- a/bayesflow/adapters/transforms/elementwise_transform.py +++ b/bayesflow/adapters/transforms/elementwise_transform.py @@ -26,5 +26,5 @@ def forward(self, data: np.ndarray, **kwargs) -> np.ndarray: def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray: raise NotImplementedError - def log_det_jac(self, data: np.ndarray, **kwargs) -> np.ndarray | None: + def log_det_jac(self, data: np.ndarray, inverse: bool = False, **kwargs) -> np.ndarray | None: return None diff --git a/bayesflow/adapters/transforms/log.py b/bayesflow/adapters/transforms/log.py index aea9f4761..2431bc3b1 100644 --- a/bayesflow/adapters/transforms/log.py +++ b/bayesflow/adapters/transforms/log.py @@ -38,6 +38,8 @@ def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray: def get_config(self) -> dict: return serialize({"p1": self.p1}) - def log_det_jac(self, data: np.ndarray, **kwargs) -> np.ndarray: + def log_det_jac(self, data: np.ndarray, inverse: bool = False, **kwargs) -> np.ndarray: ldj = -np.log(data) + if inverse: + ldj = -ldj return np.sum(ldj, axis=tuple(range(1, ldj.ndim))) diff --git a/bayesflow/adapters/transforms/scale.py b/bayesflow/adapters/transforms/scale.py index dba712904..96b2ff927 100644 --- a/bayesflow/adapters/transforms/scale.py +++ b/bayesflow/adapters/transforms/scale.py @@ -19,7 +19,9 @@ def forward(self, data: np.ndarray, **kwargs) -> np.ndarray: def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray: return data / self.scale - def log_det_jac(self, data: np.ndarray, **kwargs) -> np.ndarray: + def log_det_jac(self, data: np.ndarray, inverse: bool = False, **kwargs) -> np.ndarray: ldj = np.log(np.abs(self.scale)) ldj = np.full(data.shape, ldj) + if inverse: + ldj = -ldj return np.sum(ldj, axis=tuple(range(1, ldj.ndim))) diff --git a/bayesflow/adapters/transforms/sqrt.py b/bayesflow/adapters/transforms/sqrt.py index c2f15ad83..e1d0fa2da 100644 --- a/bayesflow/adapters/transforms/sqrt.py +++ b/bayesflow/adapters/transforms/sqrt.py @@ -23,6 +23,8 @@ def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray: def get_config(self) -> dict: return {} - def log_det_jac(self, data: np.ndarray, **kwargs) -> np.ndarray: + def log_det_jac(self, data: np.ndarray, inverse: bool = False, **kwargs) -> np.ndarray: ldj = -0.5 * np.log(data) + 0.5 + if inverse: + ldj = -ldj return np.sum(ldj, axis=tuple(range(1, ldj.ndim))) diff --git a/bayesflow/adapters/transforms/standardize.py b/bayesflow/adapters/transforms/standardize.py index 901d104f0..5972aac02 100644 --- a/bayesflow/adapters/transforms/standardize.py +++ b/bayesflow/adapters/transforms/standardize.py @@ -121,9 +121,11 @@ def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray: return data * std + mean - def log_det_jac(self, data, **kwargs) -> np.ndarray: + def log_det_jac(self, data, inverse: bool = False, **kwargs) -> np.ndarray: if self.std is None: return None std = np.broadcast_to(self.std, data.shape) ldj = np.log(np.abs(std)) + if inverse: + ldj = -ldj return np.sum(ldj, axis=tuple(range(1, ldj.ndim))) From 7c6120c4173fde8579ac663a577d55b7c28e0443 Mon Sep 17 00:00:00 2001 From: Kucharssim Date: Wed, 16 Apr 2025 13:04:34 +0200 Subject: [PATCH 10/20] inverse for Transforms --- bayesflow/adapters/transforms/concatenate.py | 17 ++++++++++++++++- bayesflow/adapters/transforms/drop.py | 4 ++-- bayesflow/adapters/transforms/keep.py | 4 ++-- bayesflow/adapters/transforms/rename.py | 4 ++-- bayesflow/adapters/transforms/transform.py | 2 +- 5 files changed, 23 insertions(+), 8 deletions(-) diff --git a/bayesflow/adapters/transforms/concatenate.py b/bayesflow/adapters/transforms/concatenate.py index 5249bd0b4..56eb82e1e 100644 --- a/bayesflow/adapters/transforms/concatenate.py +++ b/bayesflow/adapters/transforms/concatenate.py @@ -117,11 +117,26 @@ def extra_repr(self) -> str: return result def log_det_jac( - self, data: dict[str, np.ndarray], log_det_jac: dict[str, np.ndarray], *, strict: bool = False, **kwargs + self, + data: dict[str, np.ndarray], + log_det_jac: dict[str, np.ndarray], + *, + strict: bool = False, + inverse: bool = False, + **kwargs, ) -> dict[str, np.ndarray]: # copy to avoid side effects log_det_jac = log_det_jac.copy() + if inverse: + if log_det_jac.get(self.into) is not None: + raise ValueError( + "Cannot obtain an inverse jacobian of concatenation. " + "Transform your variables before you concatenate." + ) + + return log_det_jac + required_keys = set(self.keys) available_keys = set(log_det_jac.keys()) common_keys = available_keys & required_keys diff --git a/bayesflow/adapters/transforms/drop.py b/bayesflow/adapters/transforms/drop.py index dfcdeb394..91dcd6a28 100644 --- a/bayesflow/adapters/transforms/drop.py +++ b/bayesflow/adapters/transforms/drop.py @@ -47,5 +47,5 @@ def inverse(self, data: dict[str, any], **kwargs) -> dict[str, any]: def extra_repr(self) -> str: return "[" + ", ".join(map(repr, self.keys)) + "]" - def log_det_jac(self, data: dict[str, any], log_det_jac: dict[str, any], **kwargs): - return self.forward(data=log_det_jac) + def log_det_jac(self, data: dict[str, any], log_det_jac: dict[str, any], inverse: bool = False, **kwargs): + return self.inverse(data=log_det_jac) if inverse else self.forward(data=log_det_jac) diff --git a/bayesflow/adapters/transforms/keep.py b/bayesflow/adapters/transforms/keep.py index af3b02f10..56f395166 100644 --- a/bayesflow/adapters/transforms/keep.py +++ b/bayesflow/adapters/transforms/keep.py @@ -58,5 +58,5 @@ def inverse(self, data: dict[str, any], **kwargs) -> dict[str, any]: def extra_repr(self) -> str: return "[" + ", ".join(map(repr, self.keys)) + "]" - def log_det_jac(self, data: dict[str, any], log_det_jac: dict[str, any], **kwargs): - return self.forward(data=log_det_jac) + def log_det_jac(self, data: dict[str, any], log_det_jac: dict[str, any], inverse: bool = False, **kwargs): + return self.inverse(data=log_det_jac) if inverse else self.forward(data=log_det_jac) diff --git a/bayesflow/adapters/transforms/rename.py b/bayesflow/adapters/transforms/rename.py index 22ac75102..746ef5a80 100644 --- a/bayesflow/adapters/transforms/rename.py +++ b/bayesflow/adapters/transforms/rename.py @@ -59,5 +59,5 @@ def inverse(self, data: dict[str, any], *, strict: bool = False, **kwargs) -> di def extra_repr(self) -> str: return f"{self.from_key!r} -> {self.to_key!r}" - def log_det_jac(self, data: dict[str, any], log_det_jac: dict[str, any], **kwargs): - return self.forward(data=log_det_jac, strict=False) + def log_det_jac(self, data: dict[str, any], log_det_jac: dict[str, any], inverse: bool = False, **kwargs): + return self.inverse(data=log_det_jac) if inverse else self.forward(data=log_det_jac, strict=False) diff --git a/bayesflow/adapters/transforms/transform.py b/bayesflow/adapters/transforms/transform.py index fc8d633b3..ed3058e15 100644 --- a/bayesflow/adapters/transforms/transform.py +++ b/bayesflow/adapters/transforms/transform.py @@ -37,6 +37,6 @@ def extra_repr(self) -> str: return "" def log_det_jac( - self, data: dict[str, np.ndarray], log_det_jac: dict[str, np.ndarray], **kwargs + self, data: dict[str, np.ndarray], log_det_jac: dict[str, np.ndarray], inverse: bool = False, **kwargs ) -> dict[str, np.ndarray]: return log_det_jac From 6d0546421b3c4d4e3151b6fd68c09a07b7b13f7a Mon Sep 17 00:00:00 2001 From: Kucharssim Date: Thu, 17 Apr 2025 11:17:35 +0200 Subject: [PATCH 11/20] raise error with numpy transform (for now) --- bayesflow/adapters/transforms/numpy_transform.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/bayesflow/adapters/transforms/numpy_transform.py b/bayesflow/adapters/transforms/numpy_transform.py index aecf03bba..4c9fe3078 100644 --- a/bayesflow/adapters/transforms/numpy_transform.py +++ b/bayesflow/adapters/transforms/numpy_transform.py @@ -72,3 +72,6 @@ def forward(self, data: dict[str, any], **kwargs) -> dict[str, any]: def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray: return self._inverse(data) + + def log_det_jac(self, data, inverse=False, **kwargs): + raise NotImplementedError("Jacobian of the numpy transforms are not implemented yet") From 37dce55778f83cb674c189a69970d7b479d15c9c Mon Sep 17 00:00:00 2001 From: Kucharssim Date: Thu, 17 Apr 2025 11:23:22 +0200 Subject: [PATCH 12/20] do not fail if no transform is used --- bayesflow/approximators/continuous_approximator.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/bayesflow/approximators/continuous_approximator.py b/bayesflow/approximators/continuous_approximator.py index a88cabbe2..bdfd1414b 100644 --- a/bayesflow/approximators/continuous_approximator.py +++ b/bayesflow/approximators/continuous_approximator.py @@ -422,7 +422,11 @@ def log_prob(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray | dic log_prob = self._log_prob(**data, **kwargs) log_prob = keras.tree.map_structure(keras.ops.convert_to_numpy, log_prob) - return log_prob + jacobian["inference_variables"] + jacobian = jacobian.get("inference_variables") + if jacobian is not None: + log_prob = log_prob + jacobian + + return log_prob def _log_prob( self, From 4e3967db0d4c95e764d527c999e4b947276ad3b3 Mon Sep 17 00:00:00 2001 From: Kucharssim Date: Thu, 17 Apr 2025 11:27:13 +0200 Subject: [PATCH 13/20] take care of log1p as well --- bayesflow/adapters/transforms/log.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/bayesflow/adapters/transforms/log.py b/bayesflow/adapters/transforms/log.py index 2431bc3b1..d5f559b4f 100644 --- a/bayesflow/adapters/transforms/log.py +++ b/bayesflow/adapters/transforms/log.py @@ -39,7 +39,10 @@ def get_config(self) -> dict: return serialize({"p1": self.p1}) def log_det_jac(self, data: np.ndarray, inverse: bool = False, **kwargs) -> np.ndarray: - ldj = -np.log(data) + if self.p1: + ldj = -np.log1p(data) + else: + ldj = -np.log(data) if inverse: ldj = -ldj return np.sum(ldj, axis=tuple(range(1, ldj.ndim))) From 632f1086ff0a8a43dcabd10d4039aa804414ecd9 Mon Sep 17 00:00:00 2001 From: Kucharssim Date: Thu, 17 Apr 2025 11:49:45 +0200 Subject: [PATCH 14/20] fix filter transforms, boundary condition --- bayesflow/adapters/transforms/concatenate.py | 3 +++ bayesflow/adapters/transforms/filter_transform.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/bayesflow/adapters/transforms/concatenate.py b/bayesflow/adapters/transforms/concatenate.py index 56eb82e1e..1f1c84a4c 100644 --- a/bayesflow/adapters/transforms/concatenate.py +++ b/bayesflow/adapters/transforms/concatenate.py @@ -141,6 +141,9 @@ def log_det_jac( available_keys = set(log_det_jac.keys()) common_keys = available_keys & required_keys + if len(common_keys) == 0: + return log_det_jac + parts = [log_det_jac.pop(key) for key in common_keys] log_det_jac[self.into] = sum(parts) diff --git a/bayesflow/adapters/transforms/filter_transform.py b/bayesflow/adapters/transforms/filter_transform.py index 4d167fdae..7eccf370b 100644 --- a/bayesflow/adapters/transforms/filter_transform.py +++ b/bayesflow/adapters/transforms/filter_transform.py @@ -181,4 +181,4 @@ def log_det_jac( else: log_det_jac[key] = ldj - return data + return log_det_jac From d8a8dedbb24ec7cdaa2d04af15629f9f5bc8b6d7 Mon Sep 17 00:00:00 2001 From: Kucharssim Date: Wed, 23 Apr 2025 12:01:08 +0200 Subject: [PATCH 15/20] add tests for adapter jacobians --- tests/test_adapters/conftest.py | 40 ++++++++++++++++++++++++++++ tests/test_adapters/test_adapters.py | 35 +++++++++++++++++++++++- 2 files changed, 74 insertions(+), 1 deletion(-) diff --git a/tests/test_adapters/conftest.py b/tests/test_adapters/conftest.py index 873279f09..26da72e5b 100644 --- a/tests/test_adapters/conftest.py +++ b/tests/test_adapters/conftest.py @@ -49,6 +49,8 @@ def random_data(): "z1": np.random.standard_normal(size=(32, 2)), "p1": np.random.lognormal(size=(32, 2)), "p2": np.random.lognormal(size=(32, 2)), + "p3": np.random.lognormal(size=(32, 2)), + "n1": 1 - np.random.lognormal(size=(32, 2)), "s1": np.random.standard_normal(size=(32, 3, 2)), "s2": np.random.standard_normal(size=(32, 3, 2)), "t1": np.zeros((3, 2)), @@ -56,5 +58,43 @@ def random_data(): "d1": np.random.standard_normal(size=(32, 2)), "d2": np.random.standard_normal(size=(32, 2)), "o1": np.random.randint(0, 9, size=(32, 2)), + "u1": np.random.uniform(low=-1, high=2, size=(32, 1)), "key_to_split": np.random.standard_normal(size=(32, 10)), } + + +@pytest.fixture() +def adapter_jacobian(): + from bayesflow.adapters import Adapter + + adapter = ( + Adapter() + .scale("x1", by=2) + .log("p1", p1=True) + .sqrt("p2") + .constrain("p3", lower=0) + .constrain("n1", upper=1) + .constrain("u1", lower=-1, upper=2) + .concatenate(["p1", "p2", "p3"], into="p") + .rename("u1", "u") + ) + + return adapter + + +@pytest.fixture() +def adapter_jacobian_inverse(): + from bayesflow.adapters import Adapter + + adapter = ( + Adapter() + .standardize("x1", mean=1, std=2) + .log("p1") + .sqrt("p2") + .constrain("p3", lower=0, method="log") + .constrain("n1", upper=1, method="log") + .constrain("u1", lower=-1, upper=2) + .scale(["p1", "p2", "p3"], by=3.5) + ) + + return adapter diff --git a/tests/test_adapters/test_adapters.py b/tests/test_adapters/test_adapters.py index d6215170e..6466564bd 100644 --- a/tests/test_adapters/test_adapters.py +++ b/tests/test_adapters/test_adapters.py @@ -13,7 +13,7 @@ def test_cycle_consistency(adapter, random_data): deprocessed = adapter(processed, inverse=True) for key, value in random_data.items(): - if key in ["d1", "d2"]: + if key in ["d1", "d2", "p3", "n1", "u1"]: # dropped continue assert key in deprocessed @@ -230,3 +230,36 @@ def test_to_dict_transform(): # category should have 5 one-hot categories, even though it was only passed 4 assert processed["category"].shape[-1] == 5 + + +def test_jacobian(adapter_jacobian, random_data): + d, jacobian = adapter_jacobian(random_data, jacobian=True) + + assert np.allclose(jacobian["x1"], np.log(2)) + + p1 = -np.log1p(random_data["p1"]) + p2 = -0.5 * np.log(random_data["p2"]) + 0.5 + p3 = random_data["p3"] - np.log(np.exp(random_data["p3"]) - 1) + p = np.sum(p1, axis=-1) + np.sum(p2, axis=-1) + np.sum(p3, axis=-1) + + assert np.allclose(jacobian["p"], p) + + n1 = -(random_data["n1"] - 1) + n1 = n1 - np.log(np.exp(n1) - 1) + n1 = np.sum(n1, axis=-1) + + assert np.allclose(jacobian["n1"], n1) + + u1 = random_data["u1"] + u1 = (u1 + 1) / 3 + u1 = -np.log(u1) - np.log1p(-u1) - np.log(3) + + assert np.allclose(jacobian["u"], u1[:, 0]) + + +def test_jacobian_inverse(adapter_jacobian_inverse, random_data): + d, forward_jacobian = adapter_jacobian_inverse(random_data, jacobian=True) + d, inverse_jacobian = adapter_jacobian_inverse(d, inverse=True, jacobian=True) + + for key in forward_jacobian.keys(): + assert np.allclose(forward_jacobian[key], -inverse_jacobian[key]) From 399c6dcf3fbbeb9ffd5f804e88413333fa2dcb96 Mon Sep 17 00:00:00 2001 From: Kucharssim Date: Wed, 23 Apr 2025 12:44:32 +0200 Subject: [PATCH 16/20] document jacobian arg --- bayesflow/adapters/adapter.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/bayesflow/adapters/adapter.py b/bayesflow/adapters/adapter.py index c60df7805..ff11317d2 100644 --- a/bayesflow/adapters/adapter.py +++ b/bayesflow/adapters/adapter.py @@ -90,13 +90,15 @@ def forward( The data to be transformed. stage : str, one of ["training", "validation", "inference"] The stage the function is called in. + jacobian: bool, optional + Whether to return the log determinant jacobians of the transforms. **kwargs : dict Additional keyword arguments passed to each transform. Returns ------- - dict - The transformed data. + dict | tuple[dict, dict] + The transformed data or tuple of transformed data and jacobians. """ data = data.copy() if not jacobian: @@ -122,13 +124,15 @@ def inverse( The data to be transformed. stage : str, one of ["training", "validation", "inference"] The stage the function is called in. + jacobian: bool, optional + Whether to return the log determinant jacobians of the transforms. **kwargs : dict Additional keyword arguments passed to each transform. Returns ------- - dict - The transformed data. + dict | tuple[dict, dict] + The transformed data or tuple of transformed data and jacobians. """ data = data.copy() if not jacobian: @@ -145,7 +149,7 @@ def inverse( def __call__( self, data: Mapping[str, any], *, inverse: bool = False, stage="inference", **kwargs - ) -> dict[str, np.ndarray]: + ) -> dict[str, np.ndarray] | tuple[dict[str, np.ndarray], dict[str, np.ndarray]]: """Apply the transforms in the given direction. Parameters @@ -161,8 +165,8 @@ def __call__( Returns ------- - dict - The transformed data. + dict | tuple[dict, dict] + The transformed data or tuple of transformed data and jacobians. """ if inverse: return self.inverse(data, stage=stage, **kwargs) From b9b00a49117dd4a74ea6e714a6f70e53e81d2ed8 Mon Sep 17 00:00:00 2001 From: Kucharssim Date: Thu, 24 Apr 2025 08:38:20 +0200 Subject: [PATCH 17/20] jacobian -> log_det_jac --- bayesflow/adapters/adapter.py | 18 +++++++------- bayesflow/adapters/transforms/concatenate.py | 2 +- .../adapters/transforms/numpy_transform.py | 2 +- bayesflow/adapters/transforms/sqrt.py | 2 +- .../approximators/continuous_approximator.py | 9 +++---- tests/test_adapters/conftest.py | 4 ++-- tests/test_adapters/test_adapters.py | 24 +++++++++---------- 7 files changed, 31 insertions(+), 30 deletions(-) diff --git a/bayesflow/adapters/adapter.py b/bayesflow/adapters/adapter.py index ff11317d2..440ac635a 100644 --- a/bayesflow/adapters/adapter.py +++ b/bayesflow/adapters/adapter.py @@ -80,7 +80,7 @@ def get_config(self) -> dict: return serialize(config) def forward( - self, data: dict[str, any], *, stage: str = "inference", jacobian: bool = False, **kwargs + self, data: dict[str, any], *, stage: str = "inference", log_det_jac: bool = False, **kwargs ) -> dict[str, np.ndarray] | tuple[dict[str, np.ndarray], dict[str, np.ndarray]]: """Apply the transforms in the forward direction. @@ -90,7 +90,7 @@ def forward( The data to be transformed. stage : str, one of ["training", "validation", "inference"] The stage the function is called in. - jacobian: bool, optional + log_det_jac: bool, optional Whether to return the log determinant jacobians of the transforms. **kwargs : dict Additional keyword arguments passed to each transform. @@ -98,10 +98,10 @@ def forward( Returns ------- dict | tuple[dict, dict] - The transformed data or tuple of transformed data and jacobians. + The transformed data or tuple of transformed data and log determinant jacobians. """ data = data.copy() - if not jacobian: + if not log_det_jac: for transform in self.transforms: data = transform(data, stage=stage, **kwargs) return data @@ -114,7 +114,7 @@ def forward( return data, log_det_jac def inverse( - self, data: dict[str, np.ndarray], *, stage: str = "inference", jacobian: bool = False, **kwargs + self, data: dict[str, np.ndarray], *, stage: str = "inference", log_det_jac: bool = False, **kwargs ) -> dict[str, np.ndarray] | tuple[dict[str, np.ndarray], dict[str, np.ndarray]]: """Apply the transforms in the inverse direction. @@ -124,7 +124,7 @@ def inverse( The data to be transformed. stage : str, one of ["training", "validation", "inference"] The stage the function is called in. - jacobian: bool, optional + log_det_jac: bool, optional Whether to return the log determinant jacobians of the transforms. **kwargs : dict Additional keyword arguments passed to each transform. @@ -132,10 +132,10 @@ def inverse( Returns ------- dict | tuple[dict, dict] - The transformed data or tuple of transformed data and jacobians. + The transformed data or tuple of transformed data and log determinant jacobians. """ data = data.copy() - if not jacobian: + if not log_det_jac: for transform in reversed(self.transforms): data = transform(data, stage=stage, inverse=True, **kwargs) return data @@ -166,7 +166,7 @@ def __call__( Returns ------- dict | tuple[dict, dict] - The transformed data or tuple of transformed data and jacobians. + The transformed data or tuple of transformed data and log determinant jacobians. """ if inverse: return self.inverse(data, stage=stage, **kwargs) diff --git a/bayesflow/adapters/transforms/concatenate.py b/bayesflow/adapters/transforms/concatenate.py index 1f1c84a4c..4c0361af8 100644 --- a/bayesflow/adapters/transforms/concatenate.py +++ b/bayesflow/adapters/transforms/concatenate.py @@ -131,7 +131,7 @@ def log_det_jac( if inverse: if log_det_jac.get(self.into) is not None: raise ValueError( - "Cannot obtain an inverse jacobian of concatenation. " + "Cannot obtain an inverse jacobian determinant of concatenation. " "Transform your variables before you concatenate." ) diff --git a/bayesflow/adapters/transforms/numpy_transform.py b/bayesflow/adapters/transforms/numpy_transform.py index 4c9fe3078..c01cd60b3 100644 --- a/bayesflow/adapters/transforms/numpy_transform.py +++ b/bayesflow/adapters/transforms/numpy_transform.py @@ -74,4 +74,4 @@ def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray: return self._inverse(data) def log_det_jac(self, data, inverse=False, **kwargs): - raise NotImplementedError("Jacobian of the numpy transforms are not implemented yet") + raise NotImplementedError("Log determinand jacobian of the numpy transforms are not implemented yet") diff --git a/bayesflow/adapters/transforms/sqrt.py b/bayesflow/adapters/transforms/sqrt.py index e1d0fa2da..4ef1370dc 100644 --- a/bayesflow/adapters/transforms/sqrt.py +++ b/bayesflow/adapters/transforms/sqrt.py @@ -24,7 +24,7 @@ def get_config(self) -> dict: return {} def log_det_jac(self, data: np.ndarray, inverse: bool = False, **kwargs) -> np.ndarray: - ldj = -0.5 * np.log(data) + 0.5 + ldj = -0.5 * np.log(data) - np.log(2) if inverse: ldj = -ldj return np.sum(ldj, axis=tuple(range(1, ldj.ndim))) diff --git a/bayesflow/approximators/continuous_approximator.py b/bayesflow/approximators/continuous_approximator.py index bdfd1414b..bf4e263a0 100644 --- a/bayesflow/approximators/continuous_approximator.py +++ b/bayesflow/approximators/continuous_approximator.py @@ -417,14 +417,15 @@ def log_prob(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray | dic np.ndarray Log-probabilities of the distribution `p(inference_variables | inference_conditions, h(summary_conditions))` """ - data, jacobian = self.adapter(data, strict=False, stage="inference", jacobian=True, **kwargs) + data, log_det_jac = self.adapter(data, strict=False, stage="inference", log_det_jac=True, **kwargs) data = keras.tree.map_structure(keras.ops.convert_to_tensor, data) log_prob = self._log_prob(**data, **kwargs) log_prob = keras.tree.map_structure(keras.ops.convert_to_numpy, log_prob) - jacobian = jacobian.get("inference_variables") - if jacobian is not None: - log_prob = log_prob + jacobian + # change of variables formula + log_det_jac = log_det_jac.get("inference_variables") + if log_det_jac is not None: + log_prob = log_prob + log_det_jac return log_prob diff --git a/tests/test_adapters/conftest.py b/tests/test_adapters/conftest.py index 26da72e5b..d69cd4be4 100644 --- a/tests/test_adapters/conftest.py +++ b/tests/test_adapters/conftest.py @@ -64,7 +64,7 @@ def random_data(): @pytest.fixture() -def adapter_jacobian(): +def adapter_log_det_jac(): from bayesflow.adapters import Adapter adapter = ( @@ -83,7 +83,7 @@ def adapter_jacobian(): @pytest.fixture() -def adapter_jacobian_inverse(): +def adapter_log_det_jac_inverse(): from bayesflow.adapters import Adapter adapter = ( diff --git a/tests/test_adapters/test_adapters.py b/tests/test_adapters/test_adapters.py index 6466564bd..5b5011b71 100644 --- a/tests/test_adapters/test_adapters.py +++ b/tests/test_adapters/test_adapters.py @@ -232,34 +232,34 @@ def test_to_dict_transform(): assert processed["category"].shape[-1] == 5 -def test_jacobian(adapter_jacobian, random_data): - d, jacobian = adapter_jacobian(random_data, jacobian=True) +def test_log_det_jac(adapter_log_det_jac, random_data): + d, log_det_jac = adapter_log_det_jac(random_data, log_det_jac=True) - assert np.allclose(jacobian["x1"], np.log(2)) + assert np.allclose(log_det_jac["x1"], np.log(2)) p1 = -np.log1p(random_data["p1"]) - p2 = -0.5 * np.log(random_data["p2"]) + 0.5 + p2 = -0.5 * np.log(random_data["p2"]) - np.log(2) p3 = random_data["p3"] - np.log(np.exp(random_data["p3"]) - 1) p = np.sum(p1, axis=-1) + np.sum(p2, axis=-1) + np.sum(p3, axis=-1) - assert np.allclose(jacobian["p"], p) + assert np.allclose(log_det_jac["p"], p) n1 = -(random_data["n1"] - 1) n1 = n1 - np.log(np.exp(n1) - 1) n1 = np.sum(n1, axis=-1) - assert np.allclose(jacobian["n1"], n1) + assert np.allclose(log_det_jac["n1"], n1) u1 = random_data["u1"] u1 = (u1 + 1) / 3 u1 = -np.log(u1) - np.log1p(-u1) - np.log(3) - assert np.allclose(jacobian["u"], u1[:, 0]) + assert np.allclose(log_det_jac["u"], u1[:, 0]) -def test_jacobian_inverse(adapter_jacobian_inverse, random_data): - d, forward_jacobian = adapter_jacobian_inverse(random_data, jacobian=True) - d, inverse_jacobian = adapter_jacobian_inverse(d, inverse=True, jacobian=True) +def test_log_det_jac_inverse(adapter_log_det_jac_inverse, random_data): + d, forward_log_det_jac = adapter_log_det_jac_inverse(random_data, log_det_jac=True) + d, inverse_log_det_jac = adapter_log_det_jac_inverse(d, inverse=True, log_det_jac=True) - for key in forward_jacobian.keys(): - assert np.allclose(forward_jacobian[key], -inverse_jacobian[key]) + for key in forward_log_det_jac.keys(): + assert np.allclose(forward_log_det_jac[key], -inverse_log_det_jac[key]) From 6a58f921bdd904ad874086ab1c6e55f594164a37 Mon Sep 17 00:00:00 2001 From: Kucharssim Date: Thu, 24 Apr 2025 09:10:17 +0200 Subject: [PATCH 18/20] add test for inverse concatenation --- bayesflow/adapters/adapter.py | 10 +++---- bayesflow/adapters/transforms/concatenate.py | 2 +- .../adapters/transforms/numpy_transform.py | 2 +- tests/test_adapters/test_adapters.py | 26 +++++++++++++++++++ 4 files changed, 33 insertions(+), 7 deletions(-) diff --git a/bayesflow/adapters/adapter.py b/bayesflow/adapters/adapter.py index 440ac635a..6c7266b7a 100644 --- a/bayesflow/adapters/adapter.py +++ b/bayesflow/adapters/adapter.py @@ -91,14 +91,14 @@ def forward( stage : str, one of ["training", "validation", "inference"] The stage the function is called in. log_det_jac: bool, optional - Whether to return the log determinant jacobians of the transforms. + Whether to return the log determinant of the Jacobian of the transforms. **kwargs : dict Additional keyword arguments passed to each transform. Returns ------- dict | tuple[dict, dict] - The transformed data or tuple of transformed data and log determinant jacobians. + The transformed data or tuple of transformed data and log determinant of the Jacobian. """ data = data.copy() if not log_det_jac: @@ -125,14 +125,14 @@ def inverse( stage : str, one of ["training", "validation", "inference"] The stage the function is called in. log_det_jac: bool, optional - Whether to return the log determinant jacobians of the transforms. + Whether to return the log determinant of the Jacobian of the transforms. **kwargs : dict Additional keyword arguments passed to each transform. Returns ------- dict | tuple[dict, dict] - The transformed data or tuple of transformed data and log determinant jacobians. + The transformed data or tuple of transformed data and log determinant of the Jacobian. """ data = data.copy() if not log_det_jac: @@ -166,7 +166,7 @@ def __call__( Returns ------- dict | tuple[dict, dict] - The transformed data or tuple of transformed data and log determinant jacobians. + The transformed data or tuple of transformed data and log determinant of the Jacobian. """ if inverse: return self.inverse(data, stage=stage, **kwargs) diff --git a/bayesflow/adapters/transforms/concatenate.py b/bayesflow/adapters/transforms/concatenate.py index 4c0361af8..91ea9178b 100644 --- a/bayesflow/adapters/transforms/concatenate.py +++ b/bayesflow/adapters/transforms/concatenate.py @@ -131,7 +131,7 @@ def log_det_jac( if inverse: if log_det_jac.get(self.into) is not None: raise ValueError( - "Cannot obtain an inverse jacobian determinant of concatenation. " + "Cannot obtain an inverse Jacobian of concatenation. " "Transform your variables before you concatenate." ) diff --git a/bayesflow/adapters/transforms/numpy_transform.py b/bayesflow/adapters/transforms/numpy_transform.py index c01cd60b3..29d25dc67 100644 --- a/bayesflow/adapters/transforms/numpy_transform.py +++ b/bayesflow/adapters/transforms/numpy_transform.py @@ -74,4 +74,4 @@ def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray: return self._inverse(data) def log_det_jac(self, data, inverse=False, **kwargs): - raise NotImplementedError("Log determinand jacobian of the numpy transforms are not implemented yet") + raise NotImplementedError("log determinant of the Jacobian of the numpy transforms are not implemented yet") diff --git a/tests/test_adapters/test_adapters.py b/tests/test_adapters/test_adapters.py index 5b5011b71..1784befb7 100644 --- a/tests/test_adapters/test_adapters.py +++ b/tests/test_adapters/test_adapters.py @@ -263,3 +263,29 @@ def test_log_det_jac_inverse(adapter_log_det_jac_inverse, random_data): for key in forward_log_det_jac.keys(): assert np.allclose(forward_log_det_jac[key], -inverse_log_det_jac[key]) + + +def test_log_det_jac_exceptions(random_data): + # Test cannot compute inverse log_det_jac + # e.g., when we apply a concat and then a transform that + # we cannot "unconcatenate" the log_det_jac + # (because the log_det_jac are summed, not concatenated) + adapter = bf.Adapter().concatenate(["p1", "p2", "p3"], into="p").sqrt("p") + transformed_data, log_det_jac = adapter(random_data, log_det_jac=True) + + # test that inverse raises error + with pytest.raises(ValueError): + adapter(transformed_data, inverse=True, log_det_jac=True) + + # test resolvable order: first transform, then concatenate + adapter = bf.Adapter().sqrt(["p1", "p2", "p3"]).concatenate(["p1", "p2", "p3"], into="p") + + transformed_data, forward_log_det_jac = adapter(random_data, log_det_jac=True) + data, inverse_log_det_jac = adapter(transformed_data, inverse=True, log_det_jac=True) + inverse_log_det_jac = sum(inverse_log_det_jac.values()) + + # forward is the same regardless + assert np.allclose(forward_log_det_jac["p"], log_det_jac["p"]) + + # inverse works when concatenation is used after transforms + assert np.allclose(forward_log_det_jac["p"], -inverse_log_det_jac) From f137cf8fe1afcab1d8733dc02fe6954c6ecaea66 Mon Sep 17 00:00:00 2001 From: Kucharssim Date: Thu, 24 Apr 2025 09:17:56 +0200 Subject: [PATCH 19/20] fix standardize --- bayesflow/adapters/adapter.py | 3 ++- bayesflow/adapters/transforms/standardize.py | 2 -- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/bayesflow/adapters/adapter.py b/bayesflow/adapters/adapter.py index 6c7266b7a..ab6800d8a 100644 --- a/bayesflow/adapters/adapter.py +++ b/bayesflow/adapters/adapter.py @@ -108,8 +108,9 @@ def forward( log_det_jac = {} for transform in self.transforms: + transformed_data = transform(data, stage=stage, **kwargs) log_det_jac = transform.log_det_jac(data, log_det_jac, **kwargs) - data = transform(data, stage=stage, **kwargs) + data = transformed_data return data, log_det_jac diff --git a/bayesflow/adapters/transforms/standardize.py b/bayesflow/adapters/transforms/standardize.py index 5972aac02..9699819b9 100644 --- a/bayesflow/adapters/transforms/standardize.py +++ b/bayesflow/adapters/transforms/standardize.py @@ -122,8 +122,6 @@ def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray: return data * std + mean def log_det_jac(self, data, inverse: bool = False, **kwargs) -> np.ndarray: - if self.std is None: - return None std = np.broadcast_to(self.std, data.shape) ldj = np.log(np.abs(std)) if inverse: From d8710dcf85857f91c3d1d1021c0bf836bf6ba77d Mon Sep 17 00:00:00 2001 From: Kucharssim Date: Mon, 28 Apr 2025 09:17:48 +0200 Subject: [PATCH 20/20] correct nesting in map_transform --- bayesflow/adapters/transforms/map_transform.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/bayesflow/adapters/transforms/map_transform.py b/bayesflow/adapters/transforms/map_transform.py index eabfeb927..5da8292af 100644 --- a/bayesflow/adapters/transforms/map_transform.py +++ b/bayesflow/adapters/transforms/map_transform.py @@ -73,12 +73,13 @@ def log_det_jac( for key, transform in self.transform_map.items(): if key in data: ldj = transform.log_det_jac(data[key], **kwargs) - if ldj is None: - continue - elif key in log_det_jac: - log_det_jac[key] += ldj - else: - log_det_jac[key] = ldj + + if ldj is None: + continue + elif key in log_det_jac: + log_det_jac[key] += ldj + else: + log_det_jac[key] = ldj return log_det_jac