
Commit 2e347ee

style: Reformat code
1 parent db139d6 commit 2e347ee

File tree

2 files changed: +56 -71 lines changed


python/nutpie/normalizing_flow.py

Lines changed: 37 additions & 44 deletions
@@ -9,7 +9,7 @@
 import flowjax.distributions
 import flowjax.flows
 import numpy as np
-from paramax import Parameterize, unwrap
+from paramax import Parameterize
 
 
 def _generate_sequences(k, r_vals):
@@ -35,6 +35,7 @@ def _generate_sequences(k, r_vals):
         all_sequences.append(sequences)
     return np.concatenate(all_sequences, axis=0)
 
+
 def _max_run_length(seq):
     """
     Given a 1D boolean NumPy array 'seq', compute the maximum run length of consecutive
@@ -68,6 +69,7 @@ def _max_run_length(seq):
     run_lengths = np.diff(boundaries)
     return int(run_lengths.max())
 
+
 def _filter_sequences(sequences, m):
     """
     Filter a 2D NumPy boolean array 'sequences' (each row a binary sequence) so that
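
Note: the context lines show `_max_run_length` reducing flip boundaries with `np.diff` and taking the max. A minimal standalone sketch of the same boundary-difference idea (the helper name and padding details below are assumptions, not nutpie's exact code):

    import numpy as np

    def max_true_run(seq):
        # Pad with False so every run of True values has a start and an end flip.
        padded = np.concatenate([[False], np.asarray(seq, dtype=bool), [False]])
        # Indices where the padded sequence flips between False and True.
        boundaries = np.flatnonzero(padded[1:] != padded[:-1])
        if boundaries.size == 0:
            return 0
        # Flips come in (start, end) pairs; their differences are run lengths.
        return int((boundaries[1::2] - boundaries[0::2]).max())

    assert max_true_run([True, True, False, True]) == 2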
@@ -103,7 +105,9 @@ def _generate_permutations(rng, n_dim, n_layers, max_run=3):
     all_sequences = _generate_sequences(n_layers, r)
     valid_sequences = _filter_sequences(all_sequences, max_run)
 
-    valid_sequences = np.repeat(valid_sequences, n_dim // len(valid_sequences) + 1, axis=0)
+    valid_sequences = np.repeat(
+        valid_sequences, n_dim // len(valid_sequences) + 1, axis=0
+    )
     rng.shuffle(valid_sequences, axis=0)
     is_in_first = valid_sequences[:n_dim]
     rng = np.random.default_rng(42)
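
Note: the reformatted `np.repeat` call is part of a tile-shuffle-truncate pattern: repeat the pool of valid sequences until it covers at least `n_dim` rows, shuffle, then keep the first `n_dim`. A toy run with assumed shapes:

    import numpy as np

    rng = np.random.default_rng(0)
    valid_sequences = np.array([[True, False, True], [False, True, False]])
    n_dim = 5
    # Repeat each row often enough that the pool has at least n_dim rows.
    pool = np.repeat(valid_sequences, n_dim // len(valid_sequences) + 1, axis=0)
    rng.shuffle(pool, axis=0)
    is_in_first = pool[:n_dim]  # one row per model dimension
    print(is_in_first.shape)    # (5, 3)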
@@ -149,7 +153,6 @@ def __init__(
         Likewise `out_features` can also be a string `"scalar"`, in which case the
         output from the layer will have shape `()`.
         """
-        #dtype = default_floating_dtype() if dtype is None else dtype
         dtype = np.float32 if dtype is None else dtype
         wkey, bkey = jax.random.split(key, 2)
         in_features_ = 1 if in_features == "scalar" else in_features
@@ -161,7 +164,9 @@ def __init__(
         wshape = (out_features_, in_features_)
         self.weight = eqx.nn._misc.default_init(wkey, wshape, dtype, lim)
         bshape = (out_features_,)
-        self.bias = eqx.nn._misc.default_init(bkey, bshape, dtype, lim) if use_bias else None
+        self.bias = (
+            eqx.nn._misc.default_init(bkey, bshape, dtype, lim) if use_bias else None
+        )
 
         self.in_features = in_features
         self.out_features = out_features
@@ -205,6 +210,7 @@ def __call__(self, x: jax.Array, *, key=None) -> jax.Array:
             x = jnp.squeeze(x)
         return x
 
+
 class FactoredMLP(eqx.Module, strict=True):
     """Standard Multi-Layer Perceptron; also known as a feed-forward network.
 
@@ -268,7 +274,6 @@ def __init__(
         Likewise `out_size` can also be a string `"scalar"`, in which case the
         output from the module will have shape `()`.
         """
-        #dtype = default_floating_dtype() if dtype is None else dtype
         keys = jax.random.split(key, depth + 1)
         layers = []
         if isinstance(width_size, int):
@@ -290,9 +295,7 @@ def __init__(
             layers.append((U, K))
         else:
             k = width_size[0]
-            layers.append(
-                Linear(in_size, k, use_bias, dtype=dtype, key=keys[0])
-            )
+            layers.append(Linear(in_size, k, use_bias, dtype=dtype, key=keys[0]))
         activations.append(eqx.filter_vmap(lambda: activation, axis_size=k)())
 
         for i in range(depth - 1):
@@ -331,9 +334,6 @@ def __init__(
         # In case `activation` or `final_activation` are learnt, then make a separate
         # copy of their weights for every neuron.
         self.activation = tuple(activations)
-        #self.activation = eqx.filter_vmap(
-        #    eqx.filter_vmap(lambda: activation), axis_size=depth
-        #)()
         if out_size == "scalar":
             self.final_activation = final_activation
         else:
@@ -344,7 +344,7 @@ def __init__(
         self.use_final_bias = use_final_bias
 
     @jax.named_scope("eqx.nn.MLP")
-    def __call__(self, x: jax.Array, *, key = None) -> jax.Array:
+    def __call__(self, x: jax.Array, *, key=None) -> jax.Array:
         """**Arguments:**
 
         - `x`: A JAX array with shape `(in_size,)`. (Or shape `()` if
@@ -382,7 +382,6 @@ def __call__(self, x: jax.Array, *, key=None) -> jax.Array:
         return x
 
 
-
 def make_mvscale(key, n_dim, size, randomize_base=False):
     def make_single_hh(key, idx):
         key1, key2 = jax.random.split(key)
@@ -399,7 +398,10 @@ def make_single_hh(key, idx):
     else:
         indices = [val % n_dim for val in range(size)]
 
-    return bijections.Chain([make_single_hh(key, idx) for key, idx in zip(keys, indices)])
+    return bijections.Chain(
+        [make_single_hh(key, idx) for key, idx in zip(keys, indices)]
+    )
+
 
 def make_hh(key, n_dim, size, randomize_base=False):
     def make_single_hh(key, idx):
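
Note: `make_mvscale` and `make_hh` both chain per-index `make_single_hh` bijections. For background, a plain-JAX sketch of a single Householder reflection (generic linear algebra, assumed rather than taken from nutpie):

    import jax.numpy as jnp

    def householder(x, v):
        # Reflect x across the hyperplane orthogonal to v. The map is
        # orthogonal, so |det| = 1 and its log-det-Jacobian is zero.
        v = v / jnp.linalg.norm(v)
        return x - 2.0 * jnp.dot(v, x) * v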
@@ -415,19 +417,16 @@ def make_single_hh(key, idx):
     else:
         indices = [val % n_dim for val in range(size)]
 
-    return bijections.Chain([make_single_hh(key, idx) for key, idx in zip(keys, indices)])
+    return bijections.Chain(
+        [make_single_hh(key, idx) for key, idx in zip(keys, indices)]
+    )
+
 
 def make_elemwise_trafo(key, n_dim, *, count=1):
     def make_elemwise(key, loc):
         key1, key2 = jax.random.split(key)
-        scale = Parameterize(
-            lambda x: x + jnp.sqrt(1 + x**2),
-            jnp.zeros(())
-        )
-        theta = Parameterize(
-            lambda x: x + jnp.sqrt(1 + x**2),
-            jnp.zeros(())
-        )
+        scale = Parameterize(lambda x: x + jnp.sqrt(1 + x**2), jnp.zeros(()))
+        theta = Parameterize(lambda x: x + jnp.sqrt(1 + x**2), jnp.zeros(()))
 
         affine = bijections.AsymmetricAffine(
             loc,
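
Note: both `scale` and `theta` are wrapped in `Parameterize` with `x + jnp.sqrt(1 + x**2)`, a smooth bijection from the reals onto the positive reals that equals 1 at x = 0 (it is exp(asinh(x))). A small sketch of the transform and its inverse, for intuition:

    import jax.numpy as jnp

    def positive(x):
        # Smooth map R -> (0, inf), used to keep scale parameters positive.
        return x + jnp.sqrt(1 + x**2)

    def positive_inv(y):
        # Solving y = x + sqrt(1 + x**2) for x gives x = (y - 1/y) / 2.
        return (y - 1 / y) / 2

    print(positive(0.0))                # 1.0
    print(positive_inv(positive(1.5)))  # 1.5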
@@ -459,6 +458,7 @@ def make(key):
     make_affine = eqx.filter_vmap(make, axis_size=n_dim)(keys)
     return bijections.Vmap(make_affine, in_axes=eqx.if_array(0))
 
+
 def make_elemwise_trafo_(key, n_dim, *, count=1):
     def make_elemwise(key):
         scale = Parameterize(
@@ -497,6 +497,7 @@ def make(key):
     make_affine = eqx.filter_vmap(make)(keys)
     return bijections.Vmap(make_affine())
 
+
 def make_coupling(key, dim, n_untransformed, **kwargs):
     n_transformed = dim - n_untransformed
 
@@ -510,10 +511,12 @@ def make_coupling(key, dim, n_untransformed, **kwargs):
     else:
         nn_width = 2 * dim
 
-    transformer = bijections.Chain([
-        make_elemwise_trafo(key, n_transformed, count=3),
-        mvscale,
-    ])
+    transformer = bijections.Chain(
+        [
+            make_elemwise_trafo(key, n_transformed, count=3),
+            mvscale,
+        ]
+    )
 
     def make_mlp(out_size):
         if isinstance(nn_width, tuple):
@@ -541,6 +544,7 @@ def make_mlp(out_size):
         **kwargs,
     )
 
+
 def make_flow(
     seed,
     positions,
@@ -601,16 +605,6 @@ def make_flow(
     if n_layers == 0:
         return bijections.Chain(flows)
 
-    scale = Parameterize(
-        lambda x: x + jnp.sqrt(1 + x**2),
-        jnp.zeros(n_dim),
-    )
-    affine = eqx.tree_at(
-        where=lambda aff: aff.scale,
-        pytree=bijections.Affine(jnp.zeros(n_dim), jnp.ones(n_dim)),
-        replace=scale,
-    )
-
     def make_layer(key, untransformed_dim: int | None, permutation=None):
         key, key_couple, key_permute, key_hh = jax.random.split(key, 4)
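
Note: the deleted block built an `Affine` bijection and then swapped its `scale` leaf for a `Parameterize` node with `eqx.tree_at`. The general out-of-place leaf-replacement pattern, shown on a hypothetical module:

    import equinox as eqx
    import jax.numpy as jnp

    class Pair(eqx.Module):
        a: jnp.ndarray
        b: jnp.ndarray

    p = Pair(jnp.zeros(3), jnp.ones(3))
    # Replace the leaf selected by `where`, returning a new pytree.
    p2 = eqx.tree_at(where=lambda m: m.b, pytree=p, replace=jnp.full(3, 2.0))
    print(p2.b)  # [2. 2. 2.]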

@@ -625,7 +619,7 @@ def make_layer(key, untransformed_dim: int | None, permutation=None):
             n_dim,
             untransformed_dim,
             nn_activation=jax.nn.gelu,
-            nn_width=nn_width
+            nn_width=nn_width,
         )
 
         if zero_init:
@@ -646,9 +640,7 @@ def add_default_permute(bijection, dim, key):
     if dim == 2:
         outer = bijections.Flip((dim,))
     else:
-        outer = bijections.Permute(
-            jax.random.permutation(key, jnp.arange(dim))
-        )
+        outer = bijections.Permute(jax.random.permutation(key, jnp.arange(dim)))
 
     return bijections.Sandwich(outer, bijection)
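
Note: `add_default_permute` sandwiches each layer in a coordinate permutation (`Flip` for dim == 2, a random `Permute` otherwise) so successive coupling layers act on different coordinates. The permutation itself is plain indexing:

    import jax
    import jax.numpy as jnp

    key = jax.random.PRNGKey(0)
    perm = jax.random.permutation(key, jnp.arange(4))
    x = jnp.array([10.0, 20.0, 30.0, 40.0])
    print(x[perm])  # coordinates reordered before the inner bijection sees them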

@@ -698,6 +690,7 @@ def add_default_permute(bijection, dim, key):
 
     return bijections.Chain([bijection, *flows])
 
+
 def extend_flow(
     key,
     base,
@@ -714,6 +707,8 @@ def extend_flow(
     dct: bool = False,
     extension_var_trafo_count=2,
     verbose: bool = False,
+    nn_width=None,
+    nn_depth=None,
 ):
     n_draws, n_dim = positions.shape
 
@@ -871,9 +866,7 @@ def extend_flow(
             inner.outer,
             bijections.Chain(
                 [
-                    bijections.Sandwich(
-                        bijections.Flip(shape=(n_dim,)), coupling
-                    ),
+                    bijections.Sandwich(bijections.Flip(shape=(n_dim,)), coupling),
                     inner.inner,
                 ]
             ),
