99import flowjax .distributions
1010import flowjax .flows
1111import numpy as np
12- from paramax import Parameterize , unwrap
12+ from paramax import Parameterize
1313
1414
1515def _generate_sequences (k , r_vals ):
@@ -35,6 +35,7 @@ def _generate_sequences(k, r_vals):
3535 all_sequences .append (sequences )
3636 return np .concatenate (all_sequences , axis = 0 )
3737
38+
3839def _max_run_length (seq ):
3940 """
4041 Given a 1D boolean NumPy array 'seq', compute the maximum run length of consecutive
@@ -68,6 +69,7 @@ def _max_run_length(seq):
6869 run_lengths = np .diff (boundaries )
6970 return int (run_lengths .max ())
7071
72+
7173def _filter_sequences (sequences , m ):
7274 """
7375 Filter a 2D NumPy boolean array 'sequences' (each row a binary sequence) so that
@@ -103,7 +105,9 @@ def _generate_permutations(rng, n_dim, n_layers, max_run=3):
103105 all_sequences = _generate_sequences (n_layers , r )
104106 valid_sequences = _filter_sequences (all_sequences , max_run )
105107
106- valid_sequences = np .repeat (valid_sequences , n_dim // len (valid_sequences ) + 1 , axis = 0 )
108+ valid_sequences = np .repeat (
109+ valid_sequences , n_dim // len (valid_sequences ) + 1 , axis = 0
110+ )
107111 rng .shuffle (valid_sequences , axis = 0 )
108112 is_in_first = valid_sequences [:n_dim ]
109113 rng = np .random .default_rng (42 )
@@ -149,7 +153,6 @@ def __init__(
149153 Likewise `out_features` can also be a string `"scalar"`, in which case the
150154 output from the layer will have shape `()`.
151155 """
152- #dtype = default_floating_dtype() if dtype is None else dtype
153156 dtype = np .float32 if dtype is None else dtype
154157 wkey , bkey = jax .random .split (key , 2 )
155158 in_features_ = 1 if in_features == "scalar" else in_features
@@ -161,7 +164,9 @@ def __init__(
161164 wshape = (out_features_ , in_features_ )
162165 self .weight = eqx .nn ._misc .default_init (wkey , wshape , dtype , lim )
163166 bshape = (out_features_ ,)
164- self .bias = eqx .nn ._misc .default_init (bkey , bshape , dtype , lim ) if use_bias else None
167+ self .bias = (
168+ eqx .nn ._misc .default_init (bkey , bshape , dtype , lim ) if use_bias else None
169+ )
165170
166171 self .in_features = in_features
167172 self .out_features = out_features
@@ -205,6 +210,7 @@ def __call__(self, x: jax.Array, *, key=None) -> jax.Array:
205210 x = jnp .squeeze (x )
206211 return x
207212
213+
208214class FactoredMLP (eqx .Module , strict = True ):
209215 """Standard Multi-Layer Perceptron; also known as a feed-forward network.
210216
@@ -268,7 +274,6 @@ def __init__(
268274 Likewise `out_size` can also be a string `"scalar"`, in which case the
269275 output from the module will have shape `()`.
270276 """
271- #dtype = default_floating_dtype() if dtype is None else dtype
272277 keys = jax .random .split (key , depth + 1 )
273278 layers = []
274279 if isinstance (width_size , int ):
@@ -290,9 +295,7 @@ def __init__(
290295 layers .append ((U , K ))
291296 else :
292297 k = width_size [0 ]
293- layers .append (
294- Linear (in_size , k , use_bias , dtype = dtype , key = keys [0 ])
295- )
298+ layers .append (Linear (in_size , k , use_bias , dtype = dtype , key = keys [0 ]))
296299 activations .append (eqx .filter_vmap (lambda : activation , axis_size = k )())
297300
298301 for i in range (depth - 1 ):
@@ -331,9 +334,6 @@ def __init__(
331334 # In case `activation` or `final_activation` are learnt, then make a separate
332335 # copy of their weights for every neuron.
333336 self .activation = tuple (activations )
334- #self.activation = eqx.filter_vmap(
335- # eqx.filter_vmap(lambda: activation), axis_size=depth
336- #)()
337337 if out_size == "scalar" :
338338 self .final_activation = final_activation
339339 else :
@@ -344,7 +344,7 @@ def __init__(
344344 self .use_final_bias = use_final_bias
345345
346346 @jax .named_scope ("eqx.nn.MLP" )
347- def __call__ (self , x : jax .Array , * , key = None ) -> jax .Array :
347+ def __call__ (self , x : jax .Array , * , key = None ) -> jax .Array :
348348 """**Arguments:**
349349
350350 - `x`: A JAX array with shape `(in_size,)`. (Or shape `()` if
@@ -382,7 +382,6 @@ def __call__(self, x: jax.Array, *, key = None) -> jax.Array:
382382 return x
383383
384384
385-
386385def make_mvscale (key , n_dim , size , randomize_base = False ):
387386 def make_single_hh (key , idx ):
388387 key1 , key2 = jax .random .split (key )
@@ -399,7 +398,10 @@ def make_single_hh(key, idx):
399398 else :
400399 indices = [val % n_dim for val in range (size )]
401400
402- return bijections .Chain ([make_single_hh (key , idx ) for key , idx in zip (keys , indices )])
401+ return bijections .Chain (
402+ [make_single_hh (key , idx ) for key , idx in zip (keys , indices )]
403+ )
404+
403405
404406def make_hh (key , n_dim , size , randomize_base = False ):
405407 def make_single_hh (key , idx ):
@@ -415,19 +417,16 @@ def make_single_hh(key, idx):
415417 else :
416418 indices = [val % n_dim for val in range (size )]
417419
418- return bijections .Chain ([make_single_hh (key , idx ) for key , idx in zip (keys , indices )])
420+ return bijections .Chain (
421+ [make_single_hh (key , idx ) for key , idx in zip (keys , indices )]
422+ )
423+
419424
420425def make_elemwise_trafo (key , n_dim , * , count = 1 ):
421426 def make_elemwise (key , loc ):
422427 key1 , key2 = jax .random .split (key )
423- scale = Parameterize (
424- lambda x : x + jnp .sqrt (1 + x ** 2 ),
425- jnp .zeros (())
426- )
427- theta = Parameterize (
428- lambda x : x + jnp .sqrt (1 + x ** 2 ),
429- jnp .zeros (())
430- )
428+ scale = Parameterize (lambda x : x + jnp .sqrt (1 + x ** 2 ), jnp .zeros (()))
429+ theta = Parameterize (lambda x : x + jnp .sqrt (1 + x ** 2 ), jnp .zeros (()))
431430
432431 affine = bijections .AsymmetricAffine (
433432 loc ,
@@ -459,6 +458,7 @@ def make(key):
459458 make_affine = eqx .filter_vmap (make , axis_size = n_dim )(keys )
460459 return bijections .Vmap (make_affine , in_axes = eqx .if_array (0 ))
461460
461+
462462def make_elemwise_trafo_ (key , n_dim , * , count = 1 ):
463463 def make_elemwise (key ):
464464 scale = Parameterize (
@@ -497,6 +497,7 @@ def make(key):
497497 make_affine = eqx .filter_vmap (make )(keys )
498498 return bijections .Vmap (make_affine ())
499499
500+
500501def make_coupling (key , dim , n_untransformed , ** kwargs ):
501502 n_transformed = dim - n_untransformed
502503
@@ -510,10 +511,12 @@ def make_coupling(key, dim, n_untransformed, **kwargs):
510511 else :
511512 nn_width = 2 * dim
512513
513- transformer = bijections .Chain ([
514- make_elemwise_trafo (key , n_transformed , count = 3 ),
515- mvscale ,
516- ])
514+ transformer = bijections .Chain (
515+ [
516+ make_elemwise_trafo (key , n_transformed , count = 3 ),
517+ mvscale ,
518+ ]
519+ )
517520
518521 def make_mlp (out_size ):
519522 if isinstance (nn_width , tuple ):
@@ -541,6 +544,7 @@ def make_mlp(out_size):
541544 ** kwargs ,
542545 )
543546
547+
544548def make_flow (
545549 seed ,
546550 positions ,
@@ -601,16 +605,6 @@ def make_flow(
601605 if n_layers == 0 :
602606 return bijections .Chain (flows )
603607
604- scale = Parameterize (
605- lambda x : x + jnp .sqrt (1 + x ** 2 ),
606- jnp .zeros (n_dim ),
607- )
608- affine = eqx .tree_at (
609- where = lambda aff : aff .scale ,
610- pytree = bijections .Affine (jnp .zeros (n_dim ), jnp .ones (n_dim )),
611- replace = scale ,
612- )
613-
614608 def make_layer (key , untransformed_dim : int | None , permutation = None ):
615609 key , key_couple , key_permute , key_hh = jax .random .split (key , 4 )
616610
@@ -625,7 +619,7 @@ def make_layer(key, untransformed_dim: int | None, permutation=None):
625619 n_dim ,
626620 untransformed_dim ,
627621 nn_activation = jax .nn .gelu ,
628- nn_width = nn_width
622+ nn_width = nn_width ,
629623 )
630624
631625 if zero_init :
@@ -646,9 +640,7 @@ def add_default_permute(bijection, dim, key):
646640 if dim == 2 :
647641 outer = bijections .Flip ((dim ,))
648642 else :
649- outer = bijections .Permute (
650- jax .random .permutation (key , jnp .arange (dim ))
651- )
643+ outer = bijections .Permute (jax .random .permutation (key , jnp .arange (dim )))
652644
653645 return bijections .Sandwich (outer , bijection )
654646
@@ -698,6 +690,7 @@ def add_default_permute(bijection, dim, key):
698690
699691 return bijections .Chain ([bijection , * flows ])
700692
693+
701694def extend_flow (
702695 key ,
703696 base ,
@@ -714,6 +707,8 @@ def extend_flow(
714707 dct : bool = False ,
715708 extension_var_trafo_count = 2 ,
716709 verbose : bool = False ,
710+ nn_width = None ,
711+ nn_depth = None ,
717712):
718713 n_draws , n_dim = positions .shape
719714
@@ -871,9 +866,7 @@ def extend_flow(
871866 inner .outer ,
872867 bijections .Chain (
873868 [
874- bijections .Sandwich (
875- bijections .Flip (shape = (n_dim ,)), coupling
876- ),
869+ bijections .Sandwich (bijections .Flip (shape = (n_dim ,)), coupling ),
877870 inner .inner ,
878871 ]
879872 ),
0 commit comments