Skip to content

Commit 399c6dc

Browse files
committed
document jacobian arg
1 parent d8a8ded commit 399c6dc

File tree

1 file changed

+11
-7
lines changed

1 file changed

+11
-7
lines changed

bayesflow/adapters/adapter.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -90,13 +90,15 @@ def forward(
9090
The data to be transformed.
9191
stage : str, one of ["training", "validation", "inference"]
9292
The stage the function is called in.
93+
jacobian: bool, optional
94+
Whether to return the log determinant jacobians of the transforms.
9395
**kwargs : dict
9496
Additional keyword arguments passed to each transform.
9597
9698
Returns
9799
-------
98-
dict
99-
The transformed data.
100+
dict | tuple[dict, dict]
101+
The transformed data or tuple of transformed data and jacobians.
100102
"""
101103
data = data.copy()
102104
if not jacobian:
@@ -122,13 +124,15 @@ def inverse(
122124
The data to be transformed.
123125
stage : str, one of ["training", "validation", "inference"]
124126
The stage the function is called in.
127+
jacobian: bool, optional
128+
Whether to return the log determinant jacobians of the transforms.
125129
**kwargs : dict
126130
Additional keyword arguments passed to each transform.
127131
128132
Returns
129133
-------
130-
dict
131-
The transformed data.
134+
dict | tuple[dict, dict]
135+
The transformed data or tuple of transformed data and jacobians.
132136
"""
133137
data = data.copy()
134138
if not jacobian:
@@ -145,7 +149,7 @@ def inverse(
145149

146150
def __call__(
147151
self, data: Mapping[str, any], *, inverse: bool = False, stage="inference", **kwargs
148-
) -> dict[str, np.ndarray]:
152+
) -> dict[str, np.ndarray] | tuple[dict[str, np.ndarray], dict[str, np.ndarray]]:
149153
"""Apply the transforms in the given direction.
150154
151155
Parameters
@@ -161,8 +165,8 @@ def __call__(
161165
162166
Returns
163167
-------
164-
dict
165-
The transformed data.
168+
dict | tuple[dict, dict]
169+
The transformed data or tuple of transformed data and jacobians.
166170
"""
167171
if inverse:
168172
return self.inverse(data, stage=stage, **kwargs)

0 commit comments

Comments
 (0)