@@ -47,6 +47,8 @@ class Transform(ABC):
4747 Apply the forward transformation :math:`y = f(x)`
4848 inverse(y)
4949 Apply the inverse transformation :math:`x = f^{-1}(y)`
50+ log_abs_det_jacobian(x, y)
51+ Compute the log absolute determinant of the Jacobian
5052
5153 Notes
5254 -----
@@ -64,6 +66,10 @@ class Transform(ABC):
6466 """
6567 __module__ = 'braintools.param'
6668
69+ def __repr__ (self ) -> str :
70+ """Return a string representation of the transform."""
71+ return f"{ self .__class__ .__name__ } ()"
72+
6773 def __call__ (self , x : ArrayLike ) -> Array :
6874 r"""
6975 Apply the forward transformation to the input.
@@ -136,12 +142,60 @@ def inverse(self, y: ArrayLike) -> Array:
136142 """
137143 pass
138144
145+ def log_abs_det_jacobian (self , x : ArrayLike , y : ArrayLike ) -> Array :
146+ r"""
147+ Compute the log absolute determinant of the Jacobian of the forward transformation.
148+
149+ For a bijective transformation :math:`f: \mathcal{X} \rightarrow \mathcal{Y}`,
150+ this computes:
151+
152+ .. math::
153+ \log \left| \det \frac{\partial f(x)}{\partial x} \right|
154+
155+ This is essential for computing probability densities under change of variables
156+ and is widely used in normalizing flows and variational inference.
157+
158+ Parameters
159+ ----------
160+ x : array_like
161+ Input in the unconstrained domain.
162+ y : array_like
163+ Output in the constrained domain (i.e., y = forward(x)).
164+ This parameter is provided for efficiency since it may already
165+ be computed.
166+
167+ Returns
168+ -------
169+ Array
170+ Log absolute determinant of the Jacobian.
171+
172+ Notes
173+ -----
174+ The default implementation raises NotImplementedError. Subclasses
175+ should override this method to provide an efficient implementation.
176+
177+ For element-wise transformations, the log determinant is simply
178+ the sum of log absolute derivatives:
179+
180+ .. math::
181+ \log \left| \det J \right| = \sum_i \log \left| \frac{\partial f(x_i)}{\partial x_i} \right|
182+ """
183+ raise NotImplementedError (
184+ f"{ self .__class__ .__name__ } does not implement log_abs_det_jacobian. "
185+ "Override this method in your subclass."
186+ )
187+
139188
140189class Identity (Transform ):
190+ """Identity transformation (no-op)."""
141191 __module__ = 'braintools.param'
142192
143193 def forward (self , x : ArrayLike ) -> Array :
144194 return x
145195
146196 def inverse (self , y : ArrayLike ) -> Array :
147197 return y
198+
199+ def log_abs_det_jacobian (self , x : ArrayLike , y : ArrayLike ) -> Array :
200+ """Log determinant is 0 for identity (det(I) = 1)."""
201+ return jnp .zeros (jnp .shape (x )[:- 1 ] if jnp .ndim (x ) > 0 else ())
0 commit comments