Skip to content

Commit aa4ae28

Browse files
committed
Add new transformation classes: Positive, Negative, ScaledSigmoid, Power, Ordered, Simplex, and UnitVector; implement log_abs_det_jacobian methods
1 parent 5e77111 commit aa4ae28

File tree

5 files changed

+799
-52
lines changed

5 files changed

+799
-52
lines changed

braintools/param/_base.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ class Transform(ABC):
4747
Apply the forward transformation :math:`y = f(x)`
4848
inverse(y)
4949
Apply the inverse transformation :math:`x = f^{-1}(y)`
50+
log_abs_det_jacobian(x, y)
51+
Compute the log absolute determinant of the Jacobian
5052
5153
Notes
5254
-----
@@ -64,6 +66,10 @@ class Transform(ABC):
6466
"""
6567
__module__ = 'braintools.param'
6668

69+
def __repr__(self) -> str:
70+
"""Return a string representation of the transform."""
71+
return f"{self.__class__.__name__}()"
72+
6773
def __call__(self, x: ArrayLike) -> Array:
6874
r"""
6975
Apply the forward transformation to the input.
@@ -136,12 +142,60 @@ def inverse(self, y: ArrayLike) -> Array:
136142
"""
137143
pass
138144

145+
def log_abs_det_jacobian(self, x: ArrayLike, y: ArrayLike) -> Array:
146+
r"""
147+
Compute the log absolute determinant of the Jacobian of the forward transformation.
148+
149+
For a bijective transformation :math:`f: \mathcal{X} \rightarrow \mathcal{Y}`,
150+
this computes:
151+
152+
.. math::
153+
\log \left| \det \frac{\partial f(x)}{\partial x} \right|
154+
155+
This is essential for computing probability densities under change of variables
156+
and is widely used in normalizing flows and variational inference.
157+
158+
Parameters
159+
----------
160+
x : array_like
161+
Input in the unconstrained domain.
162+
y : array_like
163+
Output in the constrained domain (i.e., y = forward(x)).
164+
This parameter is provided for efficiency since it may already
165+
be computed.
166+
167+
Returns
168+
-------
169+
Array
170+
Log absolute determinant of the Jacobian.
171+
172+
Notes
173+
-----
174+
The default implementation raises NotImplementedError. Subclasses
175+
should override this method to provide an efficient implementation.
176+
177+
For element-wise transformations, the log determinant is simply
178+
the sum of log absolute derivatives:
179+
180+
.. math::
181+
\log \left| \det J \right| = \sum_i \log \left| \frac{\partial f(x_i)}{\partial x_i} \right|
182+
"""
183+
raise NotImplementedError(
184+
f"{self.__class__.__name__} does not implement log_abs_det_jacobian. "
185+
"Override this method in your subclass."
186+
)
187+
139188

140189
class Identity(Transform):
190+
"""Identity transformation (no-op)."""
141191
__module__ = 'braintools.param'
142192

143193
def forward(self, x: ArrayLike) -> Array:
144194
return x
145195

146196
def inverse(self, y: ArrayLike) -> Array:
147197
return y
198+
199+
def log_abs_det_jacobian(self, x: ArrayLike, y: ArrayLike) -> Array:
200+
"""Log determinant is 0 for identity (det(I) = 1)."""
201+
return jnp.zeros(jnp.shape(x)[:-1] if jnp.ndim(x) > 0 else ())

braintools/param/_state.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,9 @@ def __init__(
122122
super().__init__(value)
123123
self.transform = transform
124124

125+
def __repr__(self) -> str:
126+
return f"Param(data={self.data}, transform={repr(self.transform)})"
127+
125128
@property
126129
def data(self):
127130
return self.transform(self.value)

0 commit comments

Comments
 (0)