@@ -20,8 +20,9 @@ def update_fn(updates: optax.Updates, state: SWAState,
2020 update_mask = next_step == 0
2121 n = state .n + 1 * update_mask
2222
23- next_params = jax .tree_util .tree_map (lambda u , p : u * update_mask + p , updates , params )
24- next_mean = jax .tree_util .tree_map (lambda mu , np : (n * mu + np ) / (n + 1 ),
23+ next_params = jax .tree_util .tree_map (lambda p , u : jnp .where (update_mask , p + u , p ),
24+ params , updates )
25+ next_mean = jax .tree_util .tree_map (lambda mu , np : jnp .where (update_mask , (n * mu + np ) / (n + 1 ), mu ),
2526 state .mean , next_params )
2627
2728 return updates , SWAState (step = next_step , n = n , mean = next_mean )
@@ -44,10 +45,11 @@ def update_fn(updates: optax.Updates, state: SWAGDiagState,
4445 update_mask = next_step == 0
4546 n = state .n + 1 * update_mask
4647
47- next_params = jax .tree_util .tree_map (lambda u , p : u * update_mask + p , updates , params )
48- next_mean = jax .tree_util .tree_map (lambda mu , np : (n * mu + np ) / (n + 1 ),
48+ next_params = jax .tree_util .tree_map (lambda p , u : jnp .where (update_mask , p + u , p ),
49+ params , updates )
50+ next_mean = jax .tree_util .tree_map (lambda mu , np : jnp .where (update_mask , (n * mu + np ) / (n + 1 ), mu ),
4951 state .mean , next_params )
50- next_params2 = jax .tree_util .tree_map (lambda v , np : ( n * v + jnp .square (np )) / (n + 1 ),
52+ next_params2 = jax .tree_util .tree_map (lambda p2 , np : jnp . where ( update_mask , ( n * p2 + jnp .square (np )) / (n + 1 ), p2 ),
5153 state .params2 , next_params )
5254
5355 return updates , SWAGDiagState (step = next_step , n = n , mean = next_mean , params2 = next_params2 )
@@ -74,13 +76,14 @@ def update_fn(updates: optax.Updates, state: SWAGState,
7476 update_mask = next_step == 0
7577 n = state .n + 1 * update_mask
7678
77- next_params = jax .tree_util .tree_map (lambda u , p : u * update_mask + p , updates , params )
78- next_mean = jax .tree_util .tree_map (lambda mu , np : (n * mu + np ) / (n + 1 ),
79+ next_params = jax .tree_util .tree_map (lambda p , u : jnp .where (update_mask , p + u , p ),
80+ params , updates )
81+ next_mean = jax .tree_util .tree_map (lambda mu , np : jnp .where (update_mask , (n * mu + np ) / (n + 1 ), mu ),
7982 state .mean , next_params )
80- next_params2 = jax .tree_util .tree_map (lambda v , np : ( n * v + jnp .square (np )) / (n + 1 ),
83+ next_params2 = jax .tree_util .tree_map (lambda p2 , np : jnp . where ( update_mask , ( n * p2 + jnp .square (np )) / (n + 1 ), p2 ),
8184 state .params2 , next_params )
82- next_dparams = jax .tree_util .tree_map (lambda np , nmu , dp : jnp .where (update_mask , dp .at [state .c ].set (np - nmu ), dp ),
83- next_params , next_mean , state . dparams )
85+ next_dparams = jax .tree_util .tree_map (lambda dp , np , nmu : jnp .where (update_mask , dp .at [state .c ].set (np - nmu ), dp ),
86+ state . dparams , next_params , next_mean )
8487
8588 c = (state .c + 1 * update_mask ) % rank
8689
0 commit comments