1717from paramax .wrappers import AbstractUnwrappable
1818
1919
20- _NN_ACTIVATION = jax .nn .leaky_relu
20+ _NN_ACTIVATION = jax .nn .gelu
21+
2122
2223def _generate_sequences (k , r_vals ):
2324 """
@@ -324,6 +325,7 @@ class AsymmetricAffine(bijections.AbstractBijection):
324325 scale: Scale parameter σ (positive)
325326 theta: Asymmetry parameter θ (positive)
326327 """
328+
327329 shape : tuple [int , ...] = ()
328330 cond_shape : ClassVar [None ] = None
329331 loc : Array
@@ -340,6 +342,7 @@ def __init__(
340342 * (arraylike_to_array (a , dtype = float ) for a in (loc , scale , theta )),
341343 )
342344 self .shape = scale .shape
345+ assert self .shape == ()
343346 self .scale = Parameterize (lambda x : x + jnp .sqrt (1 + x ** 2 ), jnp .zeros (()))
344347 self .theta = Parameterize (lambda x : x + jnp .sqrt (1 + x ** 2 ), jnp .zeros (()))
345348
@@ -348,17 +351,18 @@ def _log_derivative_f(self, x, mu, sigma, theta):
348351 theta = jnp .log (theta )
349352
350353 sinh_theta = jnp .sinh (theta )
351- #sinh_theta = (theta - 1 / theta) / 2
354+ # sinh_theta = (theta - 1 / theta) / 2
352355 cosh_theta = jnp .cosh (theta )
353- #cosh_theta = (theta + 1 / theta) / 2
356+ # cosh_theta = (theta + 1 / theta) / 2
354357 numerator = sinh_theta * x * (abs_x + 2.0 )
355- denominator = (abs_x + 1.0 )** 2
358+ denominator = (abs_x + 1.0 ) ** 2
356359 term = numerator / denominator
357360 dy_dx = sigma * (cosh_theta + term )
358361 return jnp .log (dy_dx )
359362
360- def transform_and_log_det (self , x : ArrayLike , condition : ArrayLike | None = None ) -> tuple [Array , Array ]:
361-
363+ def transform_and_log_det (
364+ self , x : ArrayLike , condition : ArrayLike | None = None
365+ ) -> tuple [Array , Array ]:
362366 def transform (x , mu , sigma , theta ):
363367 weight = (jax .nn .soft_sign (x ) + 1 ) / 2
364368 z = x * sigma
@@ -372,17 +376,22 @@ def transform(x, mu, sigma, theta):
372376 y = transform (x , mu , sigma , theta )
373377 logjac = self ._log_derivative_f (x , mu , sigma , theta )
374378 return y , logjac .sum ()
379+ # y, jac = jax.value_and_grad(transform, argnums=0)(x, mu, sigma, theta)
380+ # return y, jnp.log(jac)
375381
376- def inverse_and_log_det (self , y : ArrayLike , condition : ArrayLike | None = None ) -> tuple [Array , Array ]:
377-
382+ def inverse_and_log_det (
383+ self , y : ArrayLike , condition : ArrayLike | None = None
384+ ) -> tuple [Array , Array ]:
378385 def inverse (y , mu , sigma , theta ):
379386 delta = y - mu
380387 inv_theta = 1 / theta
381388
382389 # Case 1: y >= mu (delta >= 0)
383390 a = sigma * (theta + inv_theta )
384- discriminant_pos = jnp .square (a - 2.0 * delta ) + 16.0 * sigma * theta * delta
385- discriminant_pos = jnp .where (discriminant_pos < 0 , 1. , discriminant_pos )
391+ discriminant_pos = (
392+ jnp .square (a - 2.0 * delta ) + 16.0 * sigma * theta * delta
393+ )
394+ discriminant_pos = jnp .where (discriminant_pos < 0 , 1.0 , discriminant_pos )
386395 sqrt_pos = jnp .sqrt (discriminant_pos )
387396 numerator_pos = 2.0 * delta - a + sqrt_pos
388397 denominator_pos = 4.0 * sigma * theta
@@ -391,8 +400,10 @@ def inverse(y, mu, sigma, theta):
391400 # Case 2: y < mu (delta < 0)
392401 sigma_part = sigma * (1.0 + theta * theta )
393402 term2 = 2.0 * delta * theta
394- inside_sqrt_neg = jnp .square (sigma_part + term2 ) - 16.0 * sigma * delta * theta
395- inside_sqrt_neg = jnp .where (inside_sqrt_neg < 0 , 1. , inside_sqrt_neg )
403+ inside_sqrt_neg = (
404+ jnp .square (sigma_part + term2 ) - 16.0 * sigma * delta * theta
405+ )
406+ inside_sqrt_neg = jnp .where (inside_sqrt_neg < 0 , 1.0 , inside_sqrt_neg )
396407 sqrt_neg = jnp .sqrt (inside_sqrt_neg )
397408 numerator_neg = sigma_part + term2 - sqrt_neg
398409 denominator_neg = 4.0 * sigma
@@ -407,6 +418,8 @@ def inverse(y, mu, sigma, theta):
407418 x = inverse (y , mu , sigma , theta )
408419 logjac = self ._log_derivative_f (x , mu , sigma , theta )
409420 return x , - logjac .sum ()
421+ # x, jac = jax.value_and_grad(inverse, argnums=0)(y, mu, sigma, theta)
422+ # return x, jnp.log(jac)
410423
411424
412425class MvScale (bijections .AbstractBijection ):
@@ -499,7 +512,6 @@ def __init__(
499512 self .requires_vmap = False
500513 conditioner_output_size = num_params
501514
502-
503515 self .transformer_constructor = constructor
504516 self .untransformed_dim = untransformed_dim
505517 self .dim = dim
@@ -509,7 +521,9 @@ def __init__(
509521 if conditioner is None :
510522 conditioner = eqx .nn .MLP (
511523 in_size = (
512- untransformed_dim if cond_dim is None else untransformed_dim + cond_dim
524+ untransformed_dim
525+ if cond_dim is None
526+ else untransformed_dim + cond_dim
513527 ),
514528 out_size = conditioner_output_size ,
515529 width_size = nn_width ,
@@ -542,7 +556,9 @@ def _flat_params_to_transformer(self, params: Array):
542556 if self .requires_vmap :
543557 dim = self .dim - self .untransformed_dim
544558 transformer_params = jnp .reshape (params , (dim , - 1 ))
545- transformer = eqx .filter_vmap (self .transformer_constructor )(transformer_params )
559+ transformer = eqx .filter_vmap (self .transformer_constructor )(
560+ transformer_params
561+ )
546562 return bijections .Vmap (transformer , in_axes = eqx .if_array (0 ))
547563 else :
548564 transformer = self .transformer_constructor (params )
@@ -612,7 +628,7 @@ def make_elemwise(key, loc):
612628 replace = theta ,
613629 )
614630
615- return affine
631+ return bijections . Invert ( affine )
616632
617633 def make (key ):
618634 keys = jax .random .split (key , count + 1 )
@@ -626,11 +642,9 @@ def make(key):
626642 return bijections .Vmap (make_affine , in_axes = eqx .if_array (0 ))
627643
628644
629- def make_coupling (key , dim , n_untransformed , ** kwargs ):
645+ def make_coupling (key , dim , n_untransformed , * , inner_mvscale = False , * *kwargs ):
630646 n_transformed = dim - n_untransformed
631647
632- mvscale = make_mvscale (key , n_transformed , 1 , randomize_base = True )
633-
634648 nn_width = kwargs .get ("nn_width" , None )
635649 nn_depth = kwargs .get ("nn_depth" , None )
636650
@@ -646,12 +660,11 @@ def make_coupling(key, dim, n_untransformed, **kwargs):
646660 else :
647661 nn_depth = len (nn_width )
648662
649- transformer = bijections .Chain (
650- [
651- make_elemwise_trafo (key , n_transformed , count = 3 ),
652- #mvscale,
653- ]
654- )
663+ transformer = make_elemwise_trafo (key , n_transformed , count = 3 )
664+
665+ if inner_mvscale :
666+ mvscale = make_mvscale (key , n_transformed , 1 , randomize_base = True )
667+ transformer = bijections .Chain ([transformer , mvscale ])
655668
656669 def make_mlp (out_size ):
657670 if isinstance (nn_width , tuple ):
0 commit comments