Commit d711b18

Merge pull request #189 from danielward27/reduce_coupling: Reduce coupling

2 parents: c006585 + 0b870b6

4 files changed: +24 −17 lines

flowjax/bijections/coupling.py (6 additions, 1 deletion)

@@ -10,6 +10,7 @@
 import jax.numpy as jnp
 from jaxtyping import PRNGKeyArray

+from flowjax import wrappers
 from flowjax.bijections.bijection import AbstractBijection
 from flowjax.bijections.jax_transforms import Vmap
 from flowjax.utils import Array, get_ravelled_pytree_constructor
@@ -55,7 +56,11 @@ def __init__(
             "Only unconditional transformers with shape () are supported.",
         )

-        constructor, num_params = get_ravelled_pytree_constructor(transformer)
+        constructor, num_params = get_ravelled_pytree_constructor(
+            transformer,
+            filter_spec=eqx.is_inexact_array,
+            is_leaf=lambda leaf: isinstance(leaf, wrappers.NonTrainable),
+        )

         self.transformer_constructor = constructor
         self.untransformed_dim = untransformed_dim
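For illustration (not part of the commit), a minimal sketch of what the explicit arguments do: treating ``NonTrainable`` as a leaf during partitioning sends the whole wrapper to the static partition, keeping its contents out of the ravelled parameter vector. The example tree below is made up; only ``eqx.partition`` and ``wrappers.NonTrainable`` come from the code above:

    # Sketch only: assumes flowjax.wrappers.NonTrainable wraps its pytree positionally.
    import equinox as eqx
    import jax.numpy as jnp
    from flowjax import wrappers

    tree = {"loc": jnp.zeros(3), "scale": wrappers.NonTrainable(jnp.ones(3))}
    params, static = eqx.partition(
        tree,
        eqx.is_inexact_array,
        is_leaf=lambda leaf: isinstance(leaf, wrappers.NonTrainable),
    )
    # The NonTrainable wrapper is treated as a single leaf; it is not an inexact
    # array, so it lands in `static` and params["scale"] is None.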

flowjax/bijections/masked_autoregressive.py (7 additions, 3 deletions)

@@ -9,11 +9,11 @@
 import jax.numpy as jnp
 from jaxtyping import Array, Int, PRNGKeyArray

+from flowjax import wrappers
 from flowjax.bijections.bijection import AbstractBijection
 from flowjax.bijections.jax_transforms import Vmap
 from flowjax.masks import rank_based_mask
 from flowjax.utils import get_ravelled_pytree_constructor
-from flowjax.wrappers import Parameterize


 class MaskedAutoregressive(AbstractBijection):
@@ -58,7 +58,11 @@ def __init__(
             "Only unconditional transformers with shape () are supported.",
         )

-        constructor, num_params = get_ravelled_pytree_constructor(transformer)
+        constructor, num_params = get_ravelled_pytree_constructor(
+            transformer,
+            filter_spec=eqx.is_inexact_array,
+            is_leaf=lambda leaf: isinstance(leaf, wrappers.NonTrainable),
+        )

         if cond_dim is None:
             self.cond_shape = None
@@ -162,7 +166,7 @@ def masked_autoregressive_mlp(
         masked_linear = eqx.tree_at(
             lambda linear: linear.weight,
             linear,
-            Parameterize(jnp.where, mask, linear.weight, 0),
+            wrappers.Parameterize(jnp.where, mask, linear.weight, 0),
        )
         masked_layers.append(masked_linear)
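As a rough usage sketch (assuming ``flowjax.wrappers.Parameterize`` and ``wrappers.unwrap`` behave as in this release), the ``eqx.tree_at`` pattern above stores the weight as a wrapper so the mask is re-applied lazily each time the module is unwrapped. The layer size and mask here are hypothetical:

    # Sketch only: a masked linear layer in the style of masked_autoregressive_mlp.
    import equinox as eqx
    import jax.numpy as jnp
    import jax.random as jr
    from flowjax import wrappers

    linear = eqx.nn.Linear(3, 3, key=jr.PRNGKey(0))
    mask = jnp.tril(jnp.ones((3, 3), dtype=bool))  # hypothetical autoregressive mask

    # Replace the weight leaf with a Parameterize wrapper holding the masking op.
    masked_linear = eqx.tree_at(
        lambda lin: lin.weight,
        linear,
        wrappers.Parameterize(jnp.where, mask, linear.weight, 0),
    )

    # unwrap evaluates jnp.where(mask, weight, 0), restoring a concrete weight array.
    concrete = wrappers.unwrap(masked_linear)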

flowjax/utils.py (10 additions, 12 deletions)

@@ -7,8 +7,6 @@
 from jax.flatten_util import ravel_pytree
 from jaxtyping import Array, ArrayLike

-import flowjax
-

 def inv_softplus(x: ArrayLike) -> Array:
     """The inverse of the softplus function, checking for positive inputs."""
@@ -70,28 +68,28 @@ def _shapes_to_str(shapes):
     return f"{in_shapes_str}->{out_shapes_str}"


-def get_ravelled_pytree_constructor(tree, filter_spec=eqx.is_inexact_array) -> tuple:
+def get_ravelled_pytree_constructor(
+    tree,
+    *args,
+    **kwargs,
+) -> tuple:
     """Get a pytree constructor taking ravelled parameters, and the number of params.

     The constructor takes a single argument as input, which is all the bijection
     parameters flattened into a single contiguous vector. This is useful when we wish to
     parameterize a pytree with a neural network. Calling the constructor
-    at the zero vector will return the initial pytree. Parameters wrapped in
-    ``NonTrainable`` are treated as leaves during partitioning.
+    at the zero vector will return the initial pytree. When using it, you may wish to
+    specify ``NonTrainable`` nodes as leaves, using the ``is_leaf`` argument.

     Args:
         tree: Pytree to form constructor for.
-        filter_spec: Filter function to specify parameters. Defaults to
-            eqx.is_inexact_array.
+        *args: Arguments passed to ``eqx.partition``.
+        **kwargs: Keyword arguments passed to ``eqx.partition``.

     Returns:
         tuple: Tuple containing the constructor, and the number of parameters.
     """
-    params, static = eqx.partition(
-        tree,
-        filter_spec,
-        is_leaf=lambda leaf: isinstance(leaf, flowjax.wrappers.NonTrainable),
-    )
+    params, static = eqx.partition(tree, *args, **kwargs)
     init, unravel = ravel_pytree(params)

     def constructor(ravelled_params: Array):
def constructor(ravelled_params: Array):

pyproject.toml (1 addition, 1 deletion)

@@ -16,7 +16,7 @@ license = { file = "LICENSE" }
 name = "flowjax"
 readme = "README.md"
 requires-python = ">=3.10"
-version = "15.1.0"
+version = "16.0.0"

 [project.urls]
 repository = "https://github.com/danielward27/flowjax"
