-
Stan's Constraint Transforms are an elegant way to express interesting constraints on parameters at the point where the parameter is declared: the constrained parameter is transformed into a set of unconstrained parameters. I've found myself doing these transformations by hand in JAX, and I haven't been able to find anything like this in JAX add-on libraries. I was wondering whether there's any work on building these on top of JAX, or whether the experts have immediate ideas on how it might be done nicely. Thanks!
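As a rough illustration of what doing these transforms by hand in JAX looks like, here is a minimal sketch of two Stan-style constraints: a lower bound (via `exp`) and an interval (via a scaled logistic sigmoid), each with its inverse. The helper names (`lower_bounded`, `interval`, and their `_inv` counterparts) are hypothetical; this is not an existing JAX API.

```python
import jax
import jax.numpy as jnp


def lower_bounded(u, lower):
    """Map unconstrained u to (lower, inf) via x = lower + exp(u)."""
    return lower + jnp.exp(u)


def lower_bounded_inv(x, lower):
    """Inverse map: constrained x in (lower, inf) back to unconstrained u."""
    return jnp.log(x - lower)


def interval(u, lower, upper):
    """Map unconstrained u to (lower, upper) via a scaled logistic sigmoid."""
    return lower + (upper - lower) * jax.nn.sigmoid(u)


def interval_inv(x, lower, upper):
    """Inverse map: constrained x in (lower, upper) back to unconstrained u."""
    p = (x - lower) / (upper - lower)
    return jnp.log(p) - jnp.log1p(-p)  # logit(p)


# The optimiser (or sampler) works on unconstrained values; the model sees constrained ones.
u = jnp.array([-2.0, 0.0, 3.0])
sigma = lower_bounded(u, 0.0)   # strictly positive
rho = interval(u, -1.0, 1.0)    # strictly inside (-1, 1)
```

Keeping the unconstrained array as the actual parameter means plain gradient-based optimisers never step outside the valid region.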
Replies: 3 comments
-
I think in JAX you only need to implement the transforms themselves, rather than both the transforms and their Jacobians, since JAX can compute the Jacobians automatically.
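To illustrate the point about Jacobians: given only the forward transform, `jax.jacfwd` can build the Jacobian matrix and `jnp.linalg.slogdet` its log-absolute-determinant. This sketch assumes a vector-to-vector transform (an elementwise interval map, chosen for illustration) and compares the result with the closed-form log-Jacobian. Forming a dense Jacobian and its determinant costs roughly O(n³), so for large parameters a hand-derived formula is still worth having.

```python
import jax
import jax.numpy as jnp


def interval(u, lower=-1.0, upper=1.0):
    """Elementwise map from R^n to (lower, upper)^n."""
    return lower + (upper - lower) * jax.nn.sigmoid(u)


def autodiff_log_det_jacobian(f, u):
    """log|det J_f(u)| computed from the forward map alone."""
    jac = jax.jacfwd(f)(u)  # (n, n) Jacobian matrix
    _, logabsdet = jnp.linalg.slogdet(jac)
    return logabsdet


def analytic_log_det_jacobian(u, lower=-1.0, upper=1.0):
    """Closed form: sum of log((upper - lower) * s * (1 - s)) with s = sigmoid(u)."""
    return jnp.sum(jnp.log(upper - lower) + jax.nn.log_sigmoid(u) + jax.nn.log_sigmoid(-u))


u = jnp.array([-2.0, 0.5, 3.0])
auto = autodiff_log_det_jacobian(interval, u)
exact = analytic_log_det_jacobian(u)
```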
-
I'm not familiar with Stan, but this sounds like, e.g., wanting your parameter to be a symmetric matrix: you use an arbitrary matrix as your parameter and then use … to make it symmetric. For an example of this in JAX, have a look at spectral normalisation in Equinox (which constrains a matrix to have unit spectral norm, i.e. a largest singular value of 1). This is done by defining a custom class with …
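A minimal sketch of this reparameterisation idea, assuming two illustrative constraints: a symmetric matrix obtained by symmetrising an unconstrained square matrix as (A + Aᵀ)/2, and a matrix with unit spectral norm obtained by dividing by its largest singular value. The function names are hypothetical. The norm here is computed exactly with `jnp.linalg.norm(..., ord=2)`; as far as I know Equinox's `SpectralNorm` uses power iteration instead, so this is not its implementation.

```python
import jax
import jax.numpy as jnp


def to_symmetric(a):
    """Map an arbitrary square matrix to a symmetric one."""
    return 0.5 * (a + a.T)


def to_unit_spectral_norm(w, eps=1e-12):
    """Rescale w so its largest singular value is 1 (exact SVD-based norm)."""
    sigma_max = jnp.linalg.norm(w, ord=2)
    return w / (sigma_max + eps)


key_a, key_w = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.normal(key_a, (4, 4))  # unconstrained parameter
w = jax.random.normal(key_w, (4, 3))  # unconstrained parameter

s = to_symmetric(a)
w_hat = to_unit_spectral_norm(w)
```

Because both maps are differentiable, gradients flow back to the unconstrained `a` and `w` as usual.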
-
As @YouJiacheng alludes to, Stan-style constraint transforms do a little bit more than just map parameters between a constrained and an unconstrained space. They also implicitly define a bijective function between the two spaces, which usually has a tractable log-det-Jacobian; that matters because we need to compute the volume change when doing a change of variables in Hamiltonian Monte Carlo.
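For reference, the volume correction mentioned above is the standard change-of-variables formula. With a bijection $f$ from an unconstrained $u \in \mathbb{R}^n$ to a constrained $x = f(u)$, the log-density HMC targets in the unconstrained space is:

$$
\log p_U(u) = \log p_X\bigl(f(u)\bigr) + \log\bigl|\det J_f(u)\bigr|, \qquad x = f(u).
$$

For example, with $x = \exp(u)$ the Jacobian is $\exp(u)$, so the correction term is simply $u$.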
If you're interested in all of those aspects (the bijective function and the volume-change computation), I encourage you to check out TFP-on-JAX, a library that has `Bijector`s, i.e. invertible functions with tractable log-det-Jacobians that are used with HMC to sample parameters in an unconstrained space. For…
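As a rough sketch of the bijector idea without pulling in TFP, the hypothetical `ExpBijector` class below bundles a forward map, its inverse, and a forward log-det-Jacobian. Its method names are loosely modelled on TFP's `Bijector`, but this is not TFP's API. It turns a log-density over a positive parameter into a log-density over an unconstrained value, which is what an HMC sampler would differentiate. The Exponential(1) target is an arbitrary choice for illustration.

```python
import jax
import jax.numpy as jnp


class ExpBijector:
    """Hypothetical bijector from R to (0, inf): x = exp(u)."""

    def forward(self, u):
        return jnp.exp(u)

    def inverse(self, x):
        return jnp.log(x)

    def forward_log_det_jacobian(self, u):
        # d/du exp(u) = exp(u), so log|det J| = u (elementwise, then summed).
        return jnp.sum(u)


def log_prob_constrained(x):
    """Exponential(1) log-density on x > 0 (illustrative target)."""
    return jnp.sum(-x)


def log_prob_unconstrained(u, bijector=ExpBijector()):
    """Change of variables: log p_X(f(u)) + log|det J_f(u)|."""
    x = bijector.forward(u)
    return log_prob_constrained(x) + bijector.forward_log_det_jacobian(u)


# Gradients for HMC are taken with respect to the unconstrained value.
grad_fn = jax.grad(log_prob_unconstrained)
u = jnp.array([0.3])
```

Dropping the log-det-Jacobian term would still give a valid density over `u`, just not the one whose pushforward is the intended target over `x`.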