Commit 4d725de (1 parent: db692cb)

🔧 Fix experimental_aggregate_gradients params missing on apply_gradient.

1 file changed (+2 −2)

tensorflow_tts/optimizers/adamweightdecay.py

Lines changed: 2 additions & 2 deletions
@@ -118,10 +118,10 @@ def _decay_weights_op(self, var, learning_rate, apply_state):
         )
         return tf.no_op()
 
-    def apply_gradients(self, grads_and_vars, clip_norm=0.5, name=None):
+    def apply_gradients(self, grads_and_vars, clip_norm=0.5, **kwargs):
         grads, tvars = list(zip(*grads_and_vars))
         (grads, _) = tf.clip_by_global_norm(grads, clip_norm=clip_norm)
-        return super(AdamWeightDecay, self).apply_gradients(zip(grads, tvars))
+        return super(AdamWeightDecay, self).apply_gradients(zip(grads, tvars), **kwargs)
 
     def _get_lr(self, var_device, var_dtype, apply_state):
         """Retrieves the learning rate with the given state."""
