
Commit 0920e28

Add cosine_decay_restarts to available LR schedules. (#313)
* Add `cosine_decay_restarts` to available LR schedules.
* Move if-else ladder to match case.
1 parent a2d19d3 commit 0920e28

File tree

2 files changed: +52 −54 lines changed


gematria/model/python/model_base.py

Lines changed: 50 additions & 54 deletions
@@ -602,60 +602,49 @@ def _create_optimizer(self) -> None:
         self._learning_rate_schedule != options.LearningRateScheduleType.NONE
     ) and (self._decay_steps == 0):
       raise ValueError(
-          'When a learning schedule is selected, `decay_steps` '
-          'must be great than zero.'
+          'When a learning rate schedule is selected, `gematria_decay_steps`'
+          ' must be greater than zero.'
       )
-    if self._learning_rate_schedule == options.LearningRateScheduleType.COSINE:
-      self._decayed_learning_rate = tf.compat.v1.train.cosine_decay(
-          **decay_args
-      )
-    elif (
-        self._learning_rate_schedule
-        == options.LearningRateScheduleType.EXPONENTIAL
-    ):
-      self._decayed_learning_rate = tf.compat.v1.train.exponential_decay(
-          **decay_args, **decay_rate_arg
-      )
-    elif (
-        self._learning_rate_schedule
-        == options.LearningRateScheduleType.INVERSE_TIME
-    ):
-      self._decayed_learning_rate = tf.compat.v1.train.inverse_time_decay(
-          **decay_args, **decay_rate_arg
-      )
-    elif (
-        self._learning_rate_schedule
-        == options.LearningRateScheduleType.LINEAR_COSINE
-    ):
-      self._decayed_learning_rate = tf.compat.v1.train.linear_cosine_decay(
-          **decay_args
-      )
-    elif (
-        self._learning_rate_schedule
-        == options.LearningRateScheduleType.NATURAL_EXP
-    ):
-      self._decayed_learning_rate = tf.compat.v1.train.natural_exp_decay(
-          **decay_args, **decay_rate_arg
-      )
-    elif (
-        self._learning_rate_schedule
-        == options.LearningRateScheduleType.NOISY_LINEAR_COSINE
-    ):
-      self._decayed_learning_rate = (
-          tf.compat.v1.train.noisy_linear_cosine_decay(**decay_args)
-      )
-    elif (
-        self._learning_rate_schedule
-        == options.LearningRateScheduleType.POLYNOMIAL
-    ):
-      self._decayed_learning_rate = tf.compat.v1.train.polynomial_decay(
-          **decay_args
-      )
-    else:
-      assert (
-          self._learning_rate_schedule == options.LearningRateScheduleType.NONE
-      )
-      self._decayed_learning_rate = self._learning_rate
+    match self._learning_rate_schedule:
+      case options.LearningRateScheduleType.COSINE:
+        self._decayed_learning_rate = tf.compat.v1.train.cosine_decay(
+            **decay_args
+        )
+      case options.LearningRateScheduleType.EXPONENTIAL:
+        self._decayed_learning_rate = tf.compat.v1.train.exponential_decay(
+            **decay_args, **decay_rate_arg
+        )
+      case options.LearningRateScheduleType.INVERSE_TIME:
+        self._decayed_learning_rate = tf.compat.v1.train.inverse_time_decay(
+            **decay_args, **decay_rate_arg
+        )
+      case options.LearningRateScheduleType.LINEAR_COSINE:
+        self._decayed_learning_rate = tf.compat.v1.train.linear_cosine_decay(
+            **decay_args
+        )
+      case options.LearningRateScheduleType.NATURAL_EXP:
+        self._decayed_learning_rate = tf.compat.v1.train.natural_exp_decay(
+            **decay_args, **decay_rate_arg
+        )
+      case options.LearningRateScheduleType.NOISY_LINEAR_COSINE:
+        self._decayed_learning_rate = (
+            tf.compat.v1.train.noisy_linear_cosine_decay(**decay_args)
+        )
+      case options.LearningRateScheduleType.POLYNOMIAL:
+        self._decayed_learning_rate = tf.compat.v1.train.polynomial_decay(
+            **decay_args
+        )
+      case options.LearningRateScheduleType.COSINE_RESTARTS:
+        decay_args['first_decay_steps'] = decay_args.pop('decay_steps')
+        self._decayed_learning_rate = tf.compat.v1.train.cosine_decay_restarts(
+            **decay_args
+        )
+      case _:
+        assert (
+            self._learning_rate_schedule
+            == options.LearningRateScheduleType.NONE
+        )
+        self._decayed_learning_rate = self._learning_rate
 
     if self._optimizer_type == options.OptimizerType.ADAM:
       self._optimizer = tf.compat.v1.train.AdamOptimizer(
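The only schedule-specific handling in the new `COSINE_RESTARTS` branch is an argument rename: `tf.compat.v1.train.cosine_decay_restarts` calls its cycle length `first_decay_steps`, while the shared `decay_args` dict (built earlier in `_create_optimizer` and not shown in this hunk) carries the value as `decay_steps`. A minimal standalone sketch of that branch, with placeholder values standing in for the real `decay_args`:

import tensorflow as tf

# Placeholder stand-in for `decay_args`; the real dict is assembled earlier
# in `_create_optimizer` and also feeds the other schedules unchanged.
decay_args = {
    'learning_rate': 0.001,
    'global_step': tf.compat.v1.train.get_or_create_global_step(),
    'decay_steps': 10_000,
}

# `cosine_decay_restarts` names the length of the first cosine cycle
# `first_decay_steps`, so the key is renamed in place before the call.
decay_args['first_decay_steps'] = decay_args.pop('decay_steps')
decayed_learning_rate = tf.compat.v1.train.cosine_decay_restarts(**decay_args)

Only the first cycle length needs to be given: with the defaults `t_mul=2.0` and `m_mul=1.0`, each subsequent restart doubles the cycle length while keeping the original peak learning rate.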
@@ -1389,7 +1378,14 @@ def train_batch(
     grads_and_vars = zip(grads, variables)
 
     # TODO(vbshah): Compute and log the number of steps per second as well.
-    tf.summary.scalar('learning_rate', self._decayed_learning_rate)
+    # NOTE(vbshah): The learning rate schedules under `tf.compat.v1.train`
+    # return callables that return the decayed learning rate in eager mode.
+    tf.summary.scalar(
+        'learning_rate',
+        self._decayed_learning_rate()
+        if callable(self._decayed_learning_rate)
+        else self._decayed_learning_rate,
+    )
     tf.summary.scalar('overall_loss', loss_tensor)
 
     # TODO(vbshah): Consider writing delta loss summaries as well.
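The guard around the summary exists because the `tf.compat.v1.train` decay helpers behave differently depending on execution mode: under a v1 graph they return a tensor, while with eager execution enabled they return a no-argument callable that computes the current decayed rate. A small sketch of the same unwrapping pattern outside the class, assuming a schedule constructed the same way (the explicit `step=` argument is only there to make the standalone snippet self-contained):

import tensorflow as tf

global_step = tf.compat.v1.train.get_or_create_global_step()
decayed = tf.compat.v1.train.cosine_decay(
    learning_rate=0.001, global_step=global_step, decay_steps=1_000
)

# Eager mode: `decayed` is a callable and must be invoked to get a value.
# Graph mode: `decayed` is already a tensor and can be logged directly.
learning_rate_value = decayed() if callable(decayed) else decayed
tf.summary.scalar('learning_rate', learning_rate_value, step=global_step)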

gematria/model/python/options.py

Lines changed: 2 additions & 0 deletions
@@ -56,6 +56,7 @@ class LearningRateScheduleType(enum.Enum):
     NATURAL_EXP: Applies natural exponential decay to the initial learning rate.
     NOISY_LINEAR_COSINE: Applies noisy linear cosine decay to the learning rate.
     POLYNOMIAL: Applies a polynomial decay to the learning rate.
+    COSINE_RESTARTS: Applies a cosine decay with restarts to the learning rate.
   """
 
   NONE = 1
@@ -66,6 +67,7 @@ class LearningRateScheduleType(enum.Enum):
   NATURAL_EXP = 6
   NOISY_LINEAR_COSINE = 7
   POLYNOMIAL = 8
+  COSINE_RESTARTS = 9
 
 
 @enum.unique
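The new member simply extends the existing enum, so any code that resolves a schedule by name picks it up without further changes. A short illustration using only the members visible in this diff (the remaining values are elided here):

import enum


@enum.unique
class LearningRateScheduleType(enum.Enum):
  # Abbreviated copy for illustration; only members visible in the diff.
  NONE = 1
  NATURAL_EXP = 6
  NOISY_LINEAR_COSINE = 7
  POLYNOMIAL = 8
  COSINE_RESTARTS = 9


# Resolving the new schedule from its name, e.g. a command-line flag value.
schedule = LearningRateScheduleType['COSINE_RESTARTS']
assert schedule is LearningRateScheduleType.COSINE_RESTARTS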
