Skip to content

Commit 0fcd19f

Browse files
committed
inverse for elementwise
1 parent 1634bdd commit 0fcd19f

File tree

6 files changed

+16
-6
lines changed

6 files changed

+16
-6
lines changed

bayesflow/adapters/transforms/constrain.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,8 @@ def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
207207
# inverse means network space -> data space, so constrain the data
208208
return self.constrain(data)
209209

210-
def log_det_jac(self, data: np.ndarray, **kwargs) -> np.ndarray:
210+
def log_det_jac(self, data: np.ndarray, inverse: bool = False, **kwargs) -> np.ndarray:
211211
ldj = self.ldj(data)
212+
if inverse:
213+
ldj = -ldj
212214
return np.sum(ldj, axis=tuple(range(1, ldj.ndim)))

bayesflow/adapters/transforms/elementwise_transform.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,5 +25,5 @@ def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
2525
def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
2626
raise NotImplementedError
2727

28-
def log_det_jac(self, data: np.ndarray, **kwargs) -> np.ndarray | None:
28+
def log_det_jac(self, data: np.ndarray, inverse: bool = False, **kwargs) -> np.ndarray | None:
2929
return None

bayesflow/adapters/transforms/log.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ def get_config(self) -> dict:
4949
"p1": serialize(self.p1),
5050
}
5151

52-
def log_det_jac(self, data: np.ndarray, **kwargs) -> np.ndarray:
52+
def log_det_jac(self, data: np.ndarray, inverse: bool = False, **kwargs) -> np.ndarray:
5353
ldj = -np.log(data)
54+
if inverse:
55+
ldj = -ldj
5456
return np.sum(ldj, axis=tuple(range(1, ldj.ndim)))

bayesflow/adapters/transforms/scale.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@ def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
2626
def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
2727
return data / self.scale
2828

29-
def log_det_jac(self, data: np.ndarray, **kwargs) -> np.ndarray:
29+
def log_det_jac(self, data: np.ndarray, inverse: bool = False, **kwargs) -> np.ndarray:
3030
ldj = np.log(np.abs(self.scale))
3131
ldj = np.full(data.shape, ldj)
32+
if inverse:
33+
ldj = -ldj
3234
return np.sum(ldj, axis=tuple(range(1, ldj.ndim)))

bayesflow/adapters/transforms/sqrt.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ def from_config(cls, config: dict, custom_objects=None) -> "Sqrt":
2626
def get_config(self) -> dict:
2727
return {}
2828

29-
def log_det_jac(self, data: np.ndarray, **kwargs) -> np.ndarray:
29+
def log_det_jac(self, data: np.ndarray, inverse: bool = False, **kwargs) -> np.ndarray:
3030
ldj = -0.5 * np.log(data) + 0.5
31+
if inverse:
32+
ldj = -ldj
3133
return np.sum(ldj, axis=tuple(range(1, ldj.ndim)))

bayesflow/adapters/transforms/standardize.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,9 +130,11 @@ def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
130130

131131
return data * std + mean
132132

133-
def log_det_jac(self, data, **kwargs) -> np.ndarray:
133+
def log_det_jac(self, data, inverse: bool = False, **kwargs) -> np.ndarray:
134134
if self.std is None:
135135
return None
136136
std = np.broadcast_to(self.std, data.shape)
137137
ldj = np.log(np.abs(std))
138+
if inverse:
139+
ldj = -ldj
138140
return np.sum(ldj, axis=tuple(range(1, ldj.ndim)))

0 commit comments

Comments
 (0)