Skip to content

Commit 11814d2

Browse files
committed
fix: better initialization of masked flows
1 parent 41fa758 commit 11814d2

File tree

3 files changed

+29
-45
lines changed

3 files changed

+29
-45
lines changed

python/nutpie/normalizing_flow.py

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1054,13 +1054,19 @@ def make_transformer():
10541054
key, key1 = jax.random.split(key)
10551055
embed = eqx.nn.Sequential(
10561056
[
1057-
eqx.nn.Linear(dim, n_embed, key=key1, dtype=jnp.float32),
1057+
eqx.nn.Linear(dim, n_embed, key=key1, dtype=jnp.float32, use_bias=True),
10581058
# Activation(_NN_ACTIVATION),
10591059
# eqx.nn.LayerNorm(shape=(n_embed,), dtype=jnp.float32),
10601060
]
10611061
)
10621062
key, key1 = jax.random.split(key)
1063-
embed_back = eqx.nn.Linear(n_deembed, size, key=key1, dtype=jnp.float32)
1063+
embed_back = eqx.nn.Linear(
1064+
n_deembed, size, key=key1, dtype=jnp.float32, use_bias=True
1065+
)
1066+
embed_back = jax.tree_util.tree_map(
1067+
lambda x: x * 1e-3 if eqx.is_inexact_array(x) else x,
1068+
embed_back,
1069+
)
10641070

10651071
rng = np.random.default_rng(42) # TODO
10661072
order, counts = _generate_permutations(rng, dim, n_layers)
@@ -1077,20 +1083,25 @@ def make_mvscale(key, n_dim):
10771083
def make_layer(key, mask, embed, embed_back):
10781084
key1, key2, key3, key4, key5 = jax.random.split(key, 5)
10791085
transformer = make_transformer()
1080-
bias = Add(jax.random.normal(key5, (size,)) * 0.01)
1086+
bias = Add(jax.random.normal(key5, (size,)) * 0.001)
1087+
inner = eqx.nn.MLP(
1088+
n_embed,
1089+
n_deembed,
1090+
width_size=nn_width,
1091+
depth=nn_depth,
1092+
key=key2,
1093+
dtype=jnp.float32,
1094+
activation=_NN_ACTIVATION,
1095+
)
1096+
inner = jax.tree_util.tree_map(
1097+
lambda x: x * 1e-3 if eqx.is_inexact_array(x) else x,
1098+
inner,
1099+
)
10811100

10821101
conditioner = eqx.nn.Sequential(
10831102
[
10841103
embed,
1085-
eqx.nn.MLP(
1086-
n_embed,
1087-
n_deembed,
1088-
width_size=nn_width,
1089-
depth=nn_depth,
1090-
key=key2,
1091-
dtype=jnp.float32,
1092-
activation=_NN_ACTIVATION,
1093-
),
1104+
inner,
10941105
eqx.nn.Sequential(
10951106
[
10961107
embed_back,
@@ -1110,11 +1121,6 @@ def make_layer(key, mask, embed, embed_back):
11101121
nn_depth=nn_depth,
11111122
)
11121123

1113-
coupling = jax.tree_util.tree_map(
1114-
lambda x: x * 1e-3 if eqx.is_inexact_array(x) else x,
1115-
coupling,
1116-
)
1117-
11181124
if mvscale:
11191125
scale = make_mvscale(key4, dim)
11201126
return bijections.Chain([coupling, scale])

python/nutpie/transform_adapter.py

Lines changed: 5 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -112,24 +112,14 @@ def fit_to_data(
112112

113113
for i in loop:
114114
# Shuffle data
115-
start = time.time()
116115
key, *subkeys = jr.split(key, 3)
117116
train_data = [jr.permutation(subkeys[0], a) for a in train_data]
118117
val_data = [jr.permutation(subkeys[1], a) for a in val_data]
119-
if verbose and i == 0:
120-
print("shuffle timing:", time.time() - start)
121-
122-
start = time.time()
123118

124119
key, subkey = jr.split(key)
125120
batches = get_batches(train_data, batch_size)
126121
batch_losses = []
127122

128-
if verbose and i == 0:
129-
print("batch timing:", time.time() - start)
130-
131-
start = time.time()
132-
133123
if True:
134124
for batch in zip(*batches, strict=True):
135125
key, subkey = jr.split(key)
@@ -156,10 +146,6 @@ def fit_to_data(
156146

157147
losses["train"].append((sum(batch_losses) / len(batch_losses)).item())
158148

159-
if verbose and i == 0:
160-
print("step timing:", time.time() - start)
161-
162-
start = time.time()
163149
# Val epoch
164150
batch_losses = []
165151
for batch in zip(*get_batches(val_data, batch_size), strict=True):
@@ -168,9 +154,6 @@ def fit_to_data(
168154
batch_losses.append(loss_i)
169155
losses["val"].append(sum(batch_losses) / len(batch_losses))
170156

171-
if verbose and i == 0:
172-
print("val timing:", time.time() - start)
173-
174157
loop.set_postfix({k: v[-1] for k, v in losses.items()})
175158
if losses["val"][-1] == min(losses["val"]):
176159
best_params = params
@@ -228,7 +211,7 @@ def inverse_gradient_and_val(bijection, draw, grad, logp):
228211
)
229212
elif isinstance(bijection, bijections.Affine):
230213
draw, logdet = bijection.inverse_and_log_det(draw)
231-
grad = grad * bijection.scale
214+
grad = grad * unwrap(bijection.scale)
232215
return (draw, grad, logp - logdet)
233216
elif isinstance(bijection, bijections.Vmap):
234217

@@ -710,12 +693,9 @@ def update(self, seed, positions, gradients, logps):
710693
)
711694
params, static = eqx.partition(flow, eqx.is_inexact_array)
712695

713-
start = time.time()
714696
new_loss = self._loss_fn(
715697
params, static, positions[-128:], gradients[-128:], logps[-128:]
716698
)
717-
if self._verbose:
718-
print("new loss function time: ", time.time() - start)
719699

720700
if self._verbose:
721701
print(f"Chain {self._chain}: New loss {new_loss}, old loss {old_loss}")
@@ -903,8 +883,8 @@ def make_transform_adapter(
903883
make_optimizer=None,
904884
coupling_type="masked",
905885
mvscale_layer=False,
906-
n_embed=None,
907-
n_deembed=None,
886+
num_project=None,
887+
num_embed=None,
908888
):
909889
if extension_windows is None:
910890
extension_windows = []
@@ -918,8 +898,8 @@ def make_transform_adapter(
918898
dct_layer=dct_layer,
919899
nn_depth=nn_depth,
920900
nn_width=nn_width,
921-
n_embed=n_embed,
922-
n_deembed=n_deembed,
901+
n_embed=num_project,
902+
n_deembed=num_embed,
923903
mvscale=mvscale_layer,
924904
kind=coupling_type,
925905
),

tests/test_pymc.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -321,9 +321,7 @@ def test_normalizing_flow_1d(kind):
321321
seed=1,
322322
draws=2000,
323323
)
324-
draws = trace.posterior.x.isel(chain=0)
325-
kstest = stats.ks_1samp(draws, stats.halfnorm.cdf)
326-
assert kstest.pvalue > 0.01
324+
assert float(trace.sample_stats.fisher_distance.mean()) < 0.1
327325

328326

329327
@pytest.mark.pymc

0 commit comments

Comments
 (0)