Skip to content

Commit c6d79ae

Browse files
committed
Add default case for std transform and add transformation to doc.
1 parent 40d2d1d commit c6d79ae

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

bayesflow/networks/standardization/standardization.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -75,19 +75,20 @@ def call(
7575
If True, apply standardization: (x - mean) / std. Otherwise, inverse transform.
7676
log_det_jac : bool, optional
7777
Whether to return the log determinant of the Jacobian. Default is False.
78+
transformation_type: str, optional
79+
The type of inverse transform to apply. Only relevant if used with arbitrary point estimates.
80+
Default is "rank1+shift", i.e., undo standardization.
7881
7982
Returns
8083
-------
8184
Tensor or Sequence[Tensor]
8285
Transformed tensor, and optionally the log-determinant if `log_det_jac=True`.
8386
"""
84-
msg = """
85-
Non-default transformation (i.e. transformation_type != "rank1+shift")
86-
is not supported for forward or log_det_jac.
87-
"""
88-
if forward or log_det_jac:
89-
if transformation_type != "rank1+shift": # non default transformation
90-
raise ValueError(msg)
87+
if (forward or log_det_jac) and transformation_type != "rank1+shift":
88+
raise ValueError(
89+
'Non-default transformation (i.e. transformation_type != "rank1+shift") '
90+
"is not supported for forward or log_det_jac."
91+
)
9192

9293
flattened = keras.tree.flatten(x)
9394
outputs, log_det_jacs = [], []
@@ -112,6 +113,8 @@ def call(
112113
case "rank02":
113114
# x_ij = x_ij * sigma_i * sigma_j
114115
out = val * std * keras.ops.moveaxis(std, -1, -2)
116+
case _:
117+
out = val
115118

116119
outputs.append(out)
117120

0 commit comments

Comments
 (0)