@@ -75,19 +75,20 @@ def call(
7575 If True, apply standardization: (x - mean) / std. Otherwise, inverse transform.
7676 log_det_jac : bool, optional
7777 Whether to return the log determinant of the Jacobian. Default is False.
78+ transformation_type: str, optional
79+ The type of inverse transform to apply. Only relevant if used with arbitrary point estimates.
80+ Default is "rank1+shift", i.e., undo standardization.
7881
7982 Returns
8083 -------
8184 Tensor or Sequence[Tensor]
8285 Transformed tensor, and optionally the log-determinant if `log_det_jac=True`.
8386 """
84- msg = """
85- Non-default transformation (i.e. transformation_type != "rank1+shift")
86- is not supported for forward or log_det_jac.
87- """
88- if forward or log_det_jac :
89- if transformation_type != "rank1+shift" : # non default transformation
90- raise ValueError (msg )
87+ if (forward or log_det_jac ) and transformation_type != "rank1+shift" :
88+ raise ValueError (
89+ 'Non-default transformation (i.e. transformation_type != "rank1+shift") '
90+ "is not supported for forward or log_det_jac."
91+ )
9192
9293 flattened = keras .tree .flatten (x )
9394 outputs , log_det_jacs = [], []
@@ -112,6 +113,8 @@ def call(
112113 case "rank02" :
113114 # x_ij = x_ij * sigma_i * sigma_j
114115 out = val * std * keras .ops .moveaxis (std , - 1 , - 2 )
116+ case _:
117+ out = val
115118
116119 outputs .append (out )
117120
0 commit comments