
Commit 3afd1df

fix: running yapf again with 0.32, earlier using 0.43
1 parent c65d93e commit 3afd1df

8 files changed: +46 -50 lines changed

prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py

Lines changed: 4 additions & 6 deletions

@@ -123,9 +123,8 @@ def update_fn(updates, state, params=None):
     mu_hat = _update_moment(updates, mu, b1, 1)
     mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count)
     nu_hat = nu if not debias else _bias_correction(nu, b2, count)
-    updates = jax.tree_map(lambda m, v: m / (raise_power(v + eps_root) + eps),
-                           mu_hat,
-                           nu_hat)
+    updates = jax.tree_map(
+        lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat)
     return updates, ScaleByAdamState(count=count, mu=mu, nu=nu)

   return optax.GradientTransformation(init_fn, update_fn)

@@ -140,9 +139,8 @@ class ScaleByAdamState(NamedTuple):

 def _update_moment(updates, moments, decay, order):
   """Compute the exponential moving average of the `order-th` moment."""
-  return jax.tree_map(lambda g, t: (1 - decay) * (g**order) + decay * t,
-                      updates,
-                      moments)
+  return jax.tree_map(
+      lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments)


 def _bias_correction(moment, decay, count):
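
Every hunk in this commit is a formatting-only rewrap; the computation is unchanged. As a quick sanity check, here is a minimal standalone sketch of the per-leaf update the rewrapped call computes (toy tensors; `raise_power` assumed to be `jnp.sqrt` as in Adam-style scaling; `jax.tree_util.tree_map` is used because newer JAX drops the `jax.tree_map` alias the repo calls):

import jax
import jax.numpy as jnp

eps, eps_root = 1e-8, 0.0
raise_power = jnp.sqrt  # assumption: square-root scaling, as in Adam/NAdamW

# Toy moment pytrees standing in for mu_hat / nu_hat.
mu_hat = {'w': jnp.array([0.1, -0.2]), 'b': jnp.array([0.05])}
nu_hat = {'w': jnp.array([0.01, 0.04]), 'b': jnp.array([0.0025])}

# Same expression as in the diff; only the line wrapping differs.
updates = jax.tree_util.tree_map(
    lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat)
print(updates)  # {'b': ~[1.0], 'w': ~[1.0, -1.0]}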

prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py

Lines changed: 4 additions & 6 deletions

@@ -123,9 +123,8 @@ def update_fn(updates, state, params=None):
     mu_hat = _update_moment(updates, mu, b1, 1)
     mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count)
     nu_hat = nu if not debias else _bias_correction(nu, b2, count)
-    updates = jax.tree_map(lambda m, v: m / (raise_power(v + eps_root) + eps),
-                           mu_hat,
-                           nu_hat)
+    updates = jax.tree_map(
+        lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat)
     return updates, ScaleByAdamState(count=count, mu=mu, nu=nu)

   return optax.GradientTransformation(init_fn, update_fn)

@@ -140,9 +139,8 @@ class ScaleByAdamState(NamedTuple):

 def _update_moment(updates, moments, decay, order):
   """Compute the exponential moving average of the `order-th` moment."""
-  return jax.tree_map(lambda g, t: (1 - decay) * (g**order) + decay * t,
-                      updates,
-                      moments)
+  return jax.tree_map(
+      lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments)


 def _bias_correction(moment, decay, count):

prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py

Lines changed: 4 additions & 6 deletions

@@ -132,9 +132,8 @@ def update_fn(updates, state, params=None):
     mu_hat = _update_moment(updates, mu, b1, 1)
     mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count)
     nu_hat = nu if not debias else _bias_correction(nu, b2, count)
-    updates = jax.tree_map(lambda m, v: m / (raise_power(v + eps_root) + eps),
-                           mu_hat,
-                           nu_hat)
+    updates = jax.tree_map(
+        lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat)
     return updates, ScaleByAdamState(count=count, mu=mu, nu=nu)

   return optax.GradientTransformation(init_fn, update_fn)

@@ -149,9 +148,8 @@ class ScaleByAdamState(NamedTuple):

 def _update_moment(updates, moments, decay, order):
   """Compute the exponential moving average of the `order-th` moment."""
-  return jax.tree_map(lambda g, t: (1 - decay) * (g**order) + decay * t,
-                      updates,
-                      moments)
+  return jax.tree_map(
+      lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments)


 def _bias_correction(moment, decay, count):

prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py

Lines changed: 4 additions & 6 deletions

@@ -132,9 +132,8 @@ def update_fn(updates, state, params=None):
     mu_hat = _update_moment(updates, mu, b1, 1)
     mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count)
     nu_hat = nu if not debias else _bias_correction(nu, b2, count)
-    updates = jax.tree_map(lambda m, v: m / (raise_power(v + eps_root) + eps),
-                           mu_hat,
-                           nu_hat)
+    updates = jax.tree_map(
+        lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat)
     return updates, ScaleByAdamState(count=count, mu=mu, nu=nu)

   return optax.GradientTransformation(init_fn, update_fn)

@@ -149,9 +148,8 @@ class ScaleByAdamState(NamedTuple):

 def _update_moment(updates, moments, decay, order):
   """Compute the exponential moving average of the `order-th` moment."""
-  return jax.tree_map(lambda g, t: (1 - decay) * (g**order) + decay * t,
-                      updates,
-                      moments)
+  return jax.tree_map(
+      lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments)


 def _bias_correction(moment, decay, count):

reference_algorithms/paper_baselines/nadamw/jax/submission.py

Lines changed: 4 additions & 6 deletions

@@ -123,9 +123,8 @@ def update_fn(updates, state, params=None):
     mu_hat = _update_moment(updates, mu, b1, 1)
     mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count)
     nu_hat = nu if not debias else _bias_correction(nu, b2, count)
-    updates = jax.tree_map(lambda m, v: m / (raise_power(v + eps_root) + eps),
-                           mu_hat,
-                           nu_hat)
+    updates = jax.tree_map(
+        lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat)
     return updates, ScaleByAdamState(count=count, mu=mu, nu=nu)

   return optax.GradientTransformation(init_fn, update_fn)

@@ -140,9 +139,8 @@ class ScaleByAdamState(NamedTuple):

 def _update_moment(updates, moments, decay, order):
   """Compute the exponential moving average of the `order-th` moment."""
-  return jax.tree_map(lambda g, t: (1 - decay) * (g**order) + decay * t,
-                      updates,
-                      moments)
+  return jax.tree_map(
+      lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments)


 def _bias_correction(moment, decay, count):
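
The second hunk in each of these NAdamW files rewraps `_update_moment`, which is just a per-leaf exponential moving average of `g**order`. A toy sketch (invented values, not the repo's code) of what that call produces for the second moment:

import jax
import jax.numpy as jnp

def update_moment(updates, moments, decay, order):
  # Per-leaf EMA: (1 - decay) * g**order + decay * t.
  return jax.tree_util.tree_map(
      lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments)

grads = {'w': jnp.array([0.5, -1.0])}
nu = {'w': jnp.array([0.0, 0.0])}
print(update_moment(grads, nu, decay=0.999, order=2))
# {'w': [0.00025, 0.001]} -- second-moment EMA after one step from zero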

reference_algorithms/paper_baselines/sam/jax/submission.py

Lines changed: 4 additions & 4 deletions

@@ -67,9 +67,8 @@ def update_fn(updates, state, grad_fn_params_tuple):
     # the noised parameters in the same order as on the original gradients and
     # with the same 1e-6 epsilon that is used when clipping the gradients.
     updates = dual_vector(updates)
-    noised_params = jax.tree_util.tree_map(lambda p, u: p + rho * u,
-                                           params,
-                                           updates)
+    noised_params = jax.tree_util.tree_map(
+        lambda p, u: p + rho * u, params, updates)
     (_, (n_valid_examples, _)), updates = grad_fn(noised_params)
     # Get correct global mean grad.
     (n_valid_examples, updates) = lax.psum((n_valid_examples, updates),

@@ -81,7 +80,8 @@ def update_fn(updates, state, grad_fn_params_tuple):
         sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(updates)))
     scaled_updates = jax.tree_map(
         lambda x: x / (updates_norm + _GRAD_CLIP_EPS) * grad_clip, updates)
-    updates = jax.lax.cond(updates_norm > grad_clip, lambda _: scaled_updates,
+    updates = jax.lax.cond(updates_norm > grad_clip,
+                           lambda _: scaled_updates,
                            lambda _: updates,
                            None)
     updates, state = base_opt_update_fn(updates, state, params)
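
For context on what the rewrapped call does: SAM evaluates the gradient a second time at parameters perturbed along the (already `dual_vector`-normalized) gradient direction, scaled by `rho`. A self-contained toy sketch with invented values, not the repo's `update_fn`:

import jax
import jax.numpy as jnp

rho = 0.05  # assumed perturbation radius for this toy example
params = {'w': jnp.array([1.0, 2.0])}
# Pretend dual_vector() has already normalized the gradient to unit L2 norm.
updates = {'w': jnp.array([0.6, 0.8])}

noised_params = jax.tree_util.tree_map(
    lambda p, u: p + rho * u, params, updates)
print(noised_params)  # {'w': [1.03, 2.04]}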

reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py

Lines changed: 18 additions & 10 deletions

@@ -595,8 +595,8 @@ def matrix_inverse_pth_root(

   if padding_start is not None:
     # Zero out padding in identity as well for convergence checks.
-    ix = (jnp.arange(matrix_size, dtype=jnp.int32)
-          < padding_start).astype(matrix.dtype)
+    ix = (jnp.arange(matrix_size, dtype=jnp.int32) < padding_start).astype(
+        matrix.dtype)
     matrix *= ix[jnp.newaxis, :]
     matrix *= ix[:, jnp.newaxis]
     identity *= ix

@@ -815,8 +815,8 @@ def matrix_inverse_pth_root_eigh(
   alpha = jnp.asarray(-1.0 / p, _MAT_INV_PTH_ROOT_DTYPE)
   identity = jnp.eye(matrix_size, dtype=_MAT_INV_PTH_ROOT_DTYPE)
   if padding_start is not None:
-    ix = (jnp.arange(matrix_size, dtype=jnp.int32)
-          < padding_start).astype(matrix.dtype)
+    ix = (jnp.arange(matrix_size, dtype=jnp.int32) < padding_start).astype(
+        matrix.dtype)
     matrix *= ix[jnp.newaxis, :]
     matrix *= ix[:, jnp.newaxis]
     identity *= ix

@@ -1809,13 +1809,17 @@ def sharded_update_fn(grads, state, params):
     ))

     new_stats_flat = jax.tree_map(
-        lambda g, s, p: _compute_stats(g, s, p, state.count),
+        lambda g,
+        s,
+        p: _compute_stats(g, s, p, state.count),
         grads_flat,
         stats_flat,
         params_flat)

     outputs = jax.tree_map(
-        lambda g, s, p: _transform_grad(g, s, p, state.count),
+        lambda g,
+        s,
+        p: _transform_grad(g, s, p, state.count),
         grads_flat,
         new_stats_flat,
         params_flat)

@@ -1919,8 +1923,8 @@ def _internal_inverse_pth_root_all():
     errors = metrics.inverse_pth_root_errors
     errors = errors.reshape((-1, 1, 1))
     predicate = jnp.logical_or(
-        jnp.isnan(errors), errors
-        >= inverse_failure_threshold).astype(new_preconditioners.dtype)
+        jnp.isnan(errors),
+        errors >= inverse_failure_threshold).astype(new_preconditioners.dtype)
     # TODO(rohananil): Check for numerical instabilities.
     new_conditional_preconditioners = (
         predicate * global_stats.preconditioners +

@@ -2438,7 +2442,9 @@ def update_fn(grads, state, params):
     stats_grads = treedef.flatten_up_to(grads_custom)

     new_stats_flat = jax.tree_map(
-        lambda g, s, p: _compute_stats(g, s, p, state.count),
+        lambda g,
+        s,
+        p: _compute_stats(g, s, p, state.count),
         stats_grads,
         stats_flat,
         params_flat)

@@ -2447,7 +2453,9 @@ def update_fn(grads, state, params):
         params_flat,
         state.count)
     outputs = jax.tree_map(
-        lambda g, s, p: _transform_grad(g, s, p, state.count),
+        lambda g,
+        s,
+        p: _transform_grad(g, s, p, state.count),
         grads_flat,
         new_stats_flat,
         params_flat)
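
The `ix` expression that yapf rewraps in the two `matrix_inverse_pth_root*` hunks builds a 0/1 mask over the unpadded rows and columns of a statistics matrix, so padded entries cannot affect the inverse-pth-root convergence check. A standalone sketch with assumed sizes (not the repo's function):

import jax.numpy as jnp

matrix_size, padding_start = 4, 2  # assume the last two rows/cols are padding
matrix = jnp.ones((matrix_size, matrix_size))
identity = jnp.eye(matrix_size)

# 1.0 for real entries, 0.0 for padded ones.
ix = (jnp.arange(matrix_size, dtype=jnp.int32) < padding_start).astype(
    matrix.dtype)
matrix = matrix * ix[jnp.newaxis, :]  # zero out padded columns
matrix = matrix * ix[:, jnp.newaxis]  # zero out padded rows
identity = identity * ix              # zero out padded diagonal entries
print(matrix)    # only the top-left 2x2 block is nonzero
print(identity)  # eye(4) with the last two diagonal entries zeroed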

reference_algorithms/target_setting_algorithms/jax_nadamw.py

Lines changed: 4 additions & 6 deletions

@@ -108,9 +108,8 @@ def update_fn(updates, state, params=None):
     mu_hat = _update_moment(updates, mu, b1, 1)
     mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count)
     nu_hat = nu if not debias else _bias_correction(nu, b2, count)
-    updates = jax.tree_map(lambda m, v: m / (raise_power(v + eps_root) + eps),
-                           mu_hat,
-                           nu_hat)
+    updates = jax.tree_map(
+        lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat)
     return updates, ScaleByAdamState(count=count, mu=mu, nu=nu)

   return optax.GradientTransformation(init_fn, update_fn)

@@ -125,9 +124,8 @@ class ScaleByAdamState(NamedTuple):

 def _update_moment(updates, moments, decay, order):
   """Compute the exponential moving average of the `order-th` moment."""
-  return jax.tree_map(lambda g, t: (1 - decay) * (g**order) + decay * t,
-                      updates,
-                      moments)
+  return jax.tree_map(
+      lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments)


 def _bias_correction(moment, decay, count):
