@@ -602,60 +602,49 @@ def _create_optimizer(self) -> None:
         self._learning_rate_schedule != options.LearningRateScheduleType.NONE
     ) and (self._decay_steps == 0):
       raise ValueError(
-          'When a learning schedule is selected, `decay_steps` '
-          'must be great than zero.'
+          'When a learning rate schedule is selected, `gematria_decay_steps`'
+          ' must be greater than zero.'
       )
-    if self._learning_rate_schedule == options.LearningRateScheduleType.COSINE:
-      self._decayed_learning_rate = tf.compat.v1.train.cosine_decay(
-          **decay_args
-      )
-    elif (
-        self._learning_rate_schedule
-        == options.LearningRateScheduleType.EXPONENTIAL
-    ):
-      self._decayed_learning_rate = tf.compat.v1.train.exponential_decay(
-          **decay_args, **decay_rate_arg
-      )
-    elif (
-        self._learning_rate_schedule
-        == options.LearningRateScheduleType.INVERSE_TIME
-    ):
-      self._decayed_learning_rate = tf.compat.v1.train.inverse_time_decay(
-          **decay_args, **decay_rate_arg
-      )
-    elif (
-        self._learning_rate_schedule
-        == options.LearningRateScheduleType.LINEAR_COSINE
-    ):
-      self._decayed_learning_rate = tf.compat.v1.train.linear_cosine_decay(
-          **decay_args
-      )
-    elif (
-        self._learning_rate_schedule
-        == options.LearningRateScheduleType.NATURAL_EXP
-    ):
-      self._decayed_learning_rate = tf.compat.v1.train.natural_exp_decay(
-          **decay_args, **decay_rate_arg
-      )
-    elif (
-        self._learning_rate_schedule
-        == options.LearningRateScheduleType.NOISY_LINEAR_COSINE
-    ):
-      self._decayed_learning_rate = (
-          tf.compat.v1.train.noisy_linear_cosine_decay(**decay_args)
-      )
-    elif (
-        self._learning_rate_schedule
-        == options.LearningRateScheduleType.POLYNOMIAL
-    ):
-      self._decayed_learning_rate = tf.compat.v1.train.polynomial_decay(
-          **decay_args
-      )
-    else:
-      assert (
-          self._learning_rate_schedule == options.LearningRateScheduleType.NONE
-      )
-      self._decayed_learning_rate = self._learning_rate
+    match self._learning_rate_schedule:
+      case options.LearningRateScheduleType.COSINE:
+        self._decayed_learning_rate = tf.compat.v1.train.cosine_decay(
+            **decay_args
+        )
+      case options.LearningRateScheduleType.EXPONENTIAL:
+        self._decayed_learning_rate = tf.compat.v1.train.exponential_decay(
+            **decay_args, **decay_rate_arg
+        )
+      case options.LearningRateScheduleType.INVERSE_TIME:
+        self._decayed_learning_rate = tf.compat.v1.train.inverse_time_decay(
+            **decay_args, **decay_rate_arg
+        )
+      case options.LearningRateScheduleType.LINEAR_COSINE:
+        self._decayed_learning_rate = tf.compat.v1.train.linear_cosine_decay(
+            **decay_args
+        )
+      case options.LearningRateScheduleType.NATURAL_EXP:
+        self._decayed_learning_rate = tf.compat.v1.train.natural_exp_decay(
+            **decay_args, **decay_rate_arg
+        )
+      case options.LearningRateScheduleType.NOISY_LINEAR_COSINE:
+        self._decayed_learning_rate = (
+            tf.compat.v1.train.noisy_linear_cosine_decay(**decay_args)
+        )
+      case options.LearningRateScheduleType.POLYNOMIAL:
+        self._decayed_learning_rate = tf.compat.v1.train.polynomial_decay(
+            **decay_args
+        )
+      case options.LearningRateScheduleType.COSINE_RESTARTS:
+        decay_args['first_decay_steps'] = decay_args.pop('decay_steps')
+        self._decayed_learning_rate = tf.compat.v1.train.cosine_decay_restarts(
+            **decay_args
+        )
+      case _:
+        assert (
+            self._learning_rate_schedule
+            == options.LearningRateScheduleType.NONE
+        )
+        self._decayed_learning_rate = self._learning_rate

     if self._optimizer_type == options.OptimizerType.ADAM:
       self._optimizer = tf.compat.v1.train.AdamOptimizer(
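
For context on the new COSINE_RESTARTS arm above: `tf.compat.v1.train.cosine_decay_restarts` names its schedule length `first_decay_steps` where the other decay helpers take `decay_steps`, which is why the key is renamed before the call. A minimal sketch of that call, with hypothetical values standing in for the `decay_args` the model assembles from its training options:

import tensorflow as tf

# Hypothetical stand-ins for the values collected into `decay_args`;
# the real model wires these from its options and global step.
global_step = tf.compat.v1.train.get_or_create_global_step()
decay_args = {
    'learning_rate': 0.001,
    'global_step': global_step,
    'decay_steps': 10000,
}

# `cosine_decay_restarts` expects `first_decay_steps`, so the key is
# renamed before unpacking the arguments.
decay_args['first_decay_steps'] = decay_args.pop('decay_steps')
decayed_learning_rate = tf.compat.v1.train.cosine_decay_restarts(**decay_args)
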
@@ -1389,7 +1378,14 @@ def train_batch(
       grads_and_vars = zip(grads, variables)

       # TODO(vbshah): Compute and log the number of steps per second as well.
-      tf.summary.scalar('learning_rate', self._decayed_learning_rate)
+      # NOTE(vbshah): The learning rate schedules under `tf.compat.v1.train`
+      # return callables that return the decayed learning rate in eager mode.
+      tf.summary.scalar(
+          'learning_rate',
+          self._decayed_learning_rate()
+          if callable(self._decayed_learning_rate)
+          else self._decayed_learning_rate,
+      )
       tf.summary.scalar('overall_loss', loss_tensor)

       # TODO(vbshah): Consider writing delta loss summaries as well.
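
The `callable(...)` guard added above reflects how the `tf.compat.v1.train` decay helpers behave under TF2 eager execution: they return a zero-argument function rather than a tensor, while the NONE schedule leaves the attribute holding the plain learning rate. A small sketch of that behavior, using hypothetical values and assuming eager execution is enabled:

import tensorflow as tf

step = tf.compat.v1.train.get_or_create_global_step()
decayed = tf.compat.v1.train.cosine_decay(
    learning_rate=0.001, global_step=step, decay_steps=1000
)

# Under eager execution the helper returns a function of no arguments;
# calling it evaluates the schedule at the current global step.
print(callable(decayed))  # True
print(float(decayed()))   # ~0.001 at step 0, decaying as `step` advances

# With no schedule selected, the attribute would simply hold the constant
# learning rate, so the summary-writing code checks callable(...) before
# invoking it.
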