Skip to content

Commit c6ff009

Browse files
committed
using jax.numpy.where for masking
1 parent cc0030c commit c6ff009

File tree

1 file changed, with 13 additions and 10 deletions.

1 file changed, with 13 additions and 10 deletions.

optax_swag/transform.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,9 @@ def update_fn(updates: optax.Updates, state: SWAState,
2020
update_mask = next_step == 0
2121
n = state.n + 1 * update_mask
2222

23-
next_params = jax.tree_util.tree_map(lambda u, p: u * update_mask + p, updates, params)
24-
next_mean = jax.tree_util.tree_map(lambda mu, np: (n * mu + np) / (n + 1),
23+
next_params = jax.tree_util.tree_map(lambda p, u: jnp.where(update_mask, p + u, p),
24+
params, updates)
25+
next_mean = jax.tree_util.tree_map(lambda mu, np: jnp.where(update_mask, (n * mu + np) / (n + 1), mu),
2526
state.mean, next_params)
2627

2728
return updates, SWAState(step=next_step, n=n, mean=next_mean)
@@ -44,10 +45,11 @@ def update_fn(updates: optax.Updates, state: SWAGDiagState,
4445
update_mask = next_step == 0
4546
n = state.n + 1 * update_mask
4647

47-
next_params = jax.tree_util.tree_map(lambda u, p: u * update_mask + p, updates, params)
48-
next_mean = jax.tree_util.tree_map(lambda mu, np: (n * mu + np) / (n + 1),
48+
next_params = jax.tree_util.tree_map(lambda p, u: jnp.where(update_mask, p + u, p),
49+
params, updates)
50+
next_mean = jax.tree_util.tree_map(lambda mu, np: jnp.where(update_mask, (n * mu + np) / (n + 1), mu),
4951
state.mean, next_params)
50-
next_params2 = jax.tree_util.tree_map(lambda v, np: (n * v + jnp.square(np)) / (n + 1),
52+
next_params2 = jax.tree_util.tree_map(lambda p2, np: jnp.where(update_mask, (n * p2 + jnp.square(np)) / (n + 1), p2),
5153
state.params2, next_params)
5254

5355
return updates, SWAGDiagState(step=next_step, n=n, mean=next_mean, params2=next_params2)
@@ -74,13 +76,14 @@ def update_fn(updates: optax.Updates, state: SWAGState,
7476
update_mask = next_step == 0
7577
n = state.n + 1 * update_mask
7678

77-
next_params = jax.tree_util.tree_map(lambda u, p: u * update_mask + p, updates, params)
78-
next_mean = jax.tree_util.tree_map(lambda mu, np: (n * mu + np) / (n + 1),
79+
next_params = jax.tree_util.tree_map(lambda p, u: jnp.where(update_mask, p + u, p),
80+
params, updates)
81+
next_mean = jax.tree_util.tree_map(lambda mu, np: jnp.where(update_mask, (n * mu + np) / (n + 1), mu),
7982
state.mean, next_params)
80-
next_params2 = jax.tree_util.tree_map(lambda v, np: (n * v + jnp.square(np)) / (n + 1),
83+
next_params2 = jax.tree_util.tree_map(lambda p2, np: jnp.where(update_mask, (n * p2 + jnp.square(np)) / (n + 1), p2),
8184
state.params2, next_params)
82-
next_dparams = jax.tree_util.tree_map(lambda np, nmu, dp: jnp.where(update_mask, dp.at[state.c].set(np - nmu), dp),
83-
next_params, next_mean, state.dparams)
85+
next_dparams = jax.tree_util.tree_map(lambda dp, np, nmu: jnp.where(update_mask, dp.at[state.c].set(np - nmu), dp),
86+
state.dparams, next_params, next_mean)
8487

8588
c = (state.c + 1 * update_mask) % rank
8689

0 commit comments

Comments
 (0)