Skip to content

Commit e4bbdc7

Browse files
committed
inverse for Transforms
1 parent 0fcd19f commit e4bbdc7

File tree

5 files changed

+23
-8
lines changed

5 files changed

+23
-8
lines changed

bayesflow/adapters/transforms/concatenate.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,11 +128,26 @@ def extra_repr(self) -> str:
128128
return result
129129

130130
def log_det_jac(
131-
self, data: dict[str, np.ndarray], log_det_jac: dict[str, np.ndarray], *, strict: bool = False, **kwargs
131+
self,
132+
data: dict[str, np.ndarray],
133+
log_det_jac: dict[str, np.ndarray],
134+
*,
135+
strict: bool = False,
136+
inverse: bool = False,
137+
**kwargs,
132138
) -> dict[str, np.ndarray]:
133139
# copy to avoid side effects
134140
log_det_jac = log_det_jac.copy()
135141

142+
if inverse:
143+
if log_det_jac.get(self.into) is not None:
144+
raise ValueError(
145+
"Cannot obtain an inverse jacobian of concatenation. "
146+
"Transform your variables before you concatenate."
147+
)
148+
149+
return log_det_jac
150+
136151
required_keys = set(self.keys)
137152
available_keys = set(log_det_jac.keys())
138153
common_keys = available_keys & required_keys

bayesflow/adapters/transforms/drop.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,5 +55,5 @@ def inverse(self, data: dict[str, any], **kwargs) -> dict[str, any]:
5555
def extra_repr(self) -> str:
5656
return "[" + ", ".join(map(repr, self.keys)) + "]"
5757

58-
def log_det_jac(self, data: dict[str, any], log_det_jac: dict[str, any], **kwargs):
59-
return self.forward(data=log_det_jac)
58+
def log_det_jac(self, data: dict[str, any], log_det_jac: dict[str, any], inverse: bool = False, **kwargs):
59+
return self.inverse(data=log_det_jac) if inverse else self.forward(data=log_det_jac)

bayesflow/adapters/transforms/keep.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,5 +66,5 @@ def inverse(self, data: dict[str, any], **kwargs) -> dict[str, any]:
6666
def extra_repr(self) -> str:
6767
return "[" + ", ".join(map(repr, self.keys)) + "]"
6868

69-
def log_det_jac(self, data: dict[str, any], log_det_jac: dict[str, any], **kwargs):
70-
return self.forward(data=log_det_jac)
69+
def log_det_jac(self, data: dict[str, any], log_det_jac: dict[str, any], inverse: bool = False, **kwargs):
70+
return self.inverse(data=log_det_jac) if inverse else self.forward(data=log_det_jac)

bayesflow/adapters/transforms/rename.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,5 +68,5 @@ def inverse(self, data: dict[str, any], *, strict: bool = False, **kwargs) -> di
6868
def extra_repr(self) -> str:
6969
return f"{self.from_key!r} -> {self.to_key!r}"
7070

71-
def log_det_jac(self, data: dict[str, any], log_det_jac: dict[str, any], **kwargs):
72-
return self.forward(data=log_det_jac, strict=False)
71+
def log_det_jac(self, data: dict[str, any], log_det_jac: dict[str, any], inverse: bool = False, **kwargs):
72+
return self.inverse(data=log_det_jac) if inverse else self.forward(data=log_det_jac, strict=False)

bayesflow/adapters/transforms/transform.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,6 @@ def extra_repr(self) -> str:
3636
return ""
3737

3838
def log_det_jac(
39-
self, data: dict[str, np.ndarray], log_det_jac: dict[str, np.ndarray], **kwargs
39+
self, data: dict[str, np.ndarray], log_det_jac: dict[str, np.ndarray], inverse: bool = False, **kwargs
4040
) -> dict[str, np.ndarray]:
4141
return log_det_jac

0 commit comments

Comments
 (0)