Skip to content

Commit 1634bdd

Browse files
committed
loop for inverse jacobian
1 parent 5b9fdf4 commit 1634bdd

File tree

1 file changed

+11
-5
lines changed

1 file changed

+11
-5
lines changed

bayesflow/adapters/adapter.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,9 @@ def forward(
109109

110110
return data, log_det_jac
111111

112-
def inverse(self, data: dict[str, np.ndarray], jacobian: bool = False, **kwargs) -> dict[str, any]:
112+
def inverse(
113+
self, data: dict[str, np.ndarray], jacobian: bool = False, **kwargs
114+
) -> dict[str, np.ndarray] | tuple[dict[str, np.ndarray], dict[str, np.ndarray]]:
113115
"""Apply the transforms in the inverse direction.
114116
115117
Parameters
@@ -125,13 +127,17 @@ def inverse(self, data: dict[str, np.ndarray], jacobian: bool = False, **kwargs)
125127
The transformed data.
126128
"""
127129
data = data.copy()
128-
if jacobian:
129-
data = self._init_jacobian(data)
130+
if not jacobian:
131+
for transform in reversed(self.transforms):
132+
data = transform(data, inverse=True, **kwargs)
133+
return data
130134

135+
log_det_jac = {}
131136
for transform in reversed(self.transforms):
132-
data = transform(data, inverse=True, jacobian=jacobian, **kwargs)
137+
data = transform(data, inverse=True, **kwargs)
138+
log_det_jac = transform.log_det_jac(data, log_det_jac, inverse=True, **kwargs)
133139

134-
return data
140+
return data, log_det_jac
135141

136142
def __call__(self, data: Mapping[str, any], *, inverse: bool = False, **kwargs) -> dict[str, np.ndarray]:
137143
"""Apply the transforms in the given direction.

0 commit comments

Comments
 (0)