Skip to content

Commit e6212a1

Browse files
committed
tree function naming convention
1 parent c6ff009 commit e6212a1

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

optax_swag/sample.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def sample_swag_diag(key: chex.PRNGKey, state: SWAGDiagState, eps: float = 1e-30
4242

4343
@partial(jax.jit, static_argnames=['rank'])
4444
def sample_swag(key: chex.PRNGKey, state: SWAGState, rank: int, scale: float = 1., eps: float = 1e-30) -> optax.Params:
45-
mean, unflatten_tree = jax.flatten_util.ravel_pytree(state.mean)
45+
mean, tree_unflatten_fn = jax.flatten_util.ravel_pytree(state.mean)
4646
p2, _ = jax.flatten_util.ravel_pytree(state.params2)
4747

4848
std = jnp.sqrt(jnp.clip(p2 - jnp.square(mean), a_min=eps))
@@ -57,4 +57,4 @@ def sample_swag(key: chex.PRNGKey, state: SWAGState, rank: int, scale: float = 1
5757
z2 = jax.random.normal(z2_key, (rank,))
5858
z2_scale = scale / jnp.sqrt(2 * (rank - 1))
5959

60-
return unflatten_tree(mean + z1_scale * std * z1 + z2_scale * jnp.matmul(dparams.T, z2))
60+
return tree_unflatten_fn(mean + z1_scale * std * z1 + z2_scale * jnp.matmul(dparams.T, z2))

0 commit comments

Comments
 (0)