We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 42374ea commit ea15908Copy full SHA for ea15908
python/nutpie/transform_adapter.py
@@ -18,7 +18,6 @@ def make_transform_adapter(
18
import flowjax
19
import flowjax.train
20
import flowjax.flows
21
- from flowjax.bijections import mvscale
22
import optax
23
import traceback
24
from paramax import Parameterize, unwrap
@@ -164,6 +163,8 @@ def make_layer(key, is_last=False):
164
163
flow = flowjax.flows._add_default_permute(coupling, n_dim, key_permute)
165
166
if scale_layer:
+ from flowjax.bijections import mvscale
167
+
168
bijections = list(flow.bijections)
169
bijections.append(mvscale.MvScale4(jnp.ones(n_dim) * 1e-5))
170
# bijections.append(mvscale.MvScale3(jnp.ones(n_dim) * 1e-5))
0 commit comments