@@ -109,7 +109,9 @@ def forward(
109109
110110 return data , log_det_jac
111111
112- def inverse (self , data : dict [str , np .ndarray ], jacobian : bool = False , ** kwargs ) -> dict [str , any ]:
112+ def inverse (
113+ self , data : dict [str , np .ndarray ], jacobian : bool = False , ** kwargs
114+ ) -> dict [str , np .ndarray ] | tuple [dict [str , np .ndarray ], dict [str , np .ndarray ]]:
113115 """Apply the transforms in the inverse direction.
114116
115117 Parameters
@@ -125,13 +127,17 @@ def inverse(self, data: dict[str, np.ndarray], jacobian: bool = False, **kwargs)
125127 The transformed data.
126128 """
127129 data = data .copy ()
128- if jacobian :
129- data = self ._init_jacobian (data )
130+ if not jacobian :
131+ for transform in reversed (self .transforms ):
132+ data = transform (data , inverse = True , ** kwargs )
133+ return data
130134
135+ log_det_jac = {}
131136 for transform in reversed (self .transforms ):
132- data = transform (data , inverse = True , jacobian = jacobian , ** kwargs )
137+ data = transform (data , inverse = True , ** kwargs )
138+ log_det_jac = transform .log_det_jac (data , log_det_jac , inverse = True , ** kwargs )
133139
134- return data
140+ return data , log_det_jac
135141
136142 def __call__ (self , data : Mapping [str , any ], * , inverse : bool = False , ** kwargs ) -> dict [str , np .ndarray ]:
137143 """Apply the transforms in the given direction.
0 commit comments