4343from gematria .utils .python import timer
4444import numpy as np
4545import scipy .stats
46- import tensorflow . compat . v1 as tf
46+ import tensorflow as tf
4747import tf_slim .evaluation
4848
4949# The type used for TensorFlow feed_dict objects. The type we use here is
@@ -75,7 +75,7 @@ class AddBasicBlockError(Exception):
7575 """The exception raised when adding a block to batch fails."""
7676
7777
78- class SaveBestCheckpoint (tf .train .SessionRunHook ):
78+ class SaveBestCheckpoint (tf .compat . v1 . train .SessionRunHook ):
7979 """A run hook that saves top N models based on error values."""
8080
8181 def __init__ (
@@ -100,8 +100,8 @@ def __init__(
100- self ._saver = tf .train .Saver (max_to_keep = max_to_keep , name = 'relative_mae' )
100+ self ._saver = tf .compat . v1 . train .Saver (max_to_keep = max_to_keep , name = 'relative_mae' )
101101 self ._last_eval = math .inf
102102
103- def before_run (self , run_context : ...) -> tf .train .SessionRunArgs :
104- return tf .train .SessionRunArgs (
103+ def before_run (self , run_context : ...) -> tf .compat . v1 . train .SessionRunArgs :
104+ return tf .compat . v1 . train .SessionRunArgs (
105105 {'loss' : self ._error_tensor , 'global_step' : self ._global_step }
106106 )
107107
@@ -328,7 +328,7 @@ def __init__(
328328 )
329329 self ._trained_variable_groups = tuple (trained_variable_groups or ())
330330
331- self ._global_step = tf .train .get_or_create_global_step ()
331+ self ._global_step = tf .compat . v1 . train .get_or_create_global_step ()
332332
333333 self ._decayed_learning_rate = None
334334 self ._loss : Optional [loss_utils .LossComputation ] = None
@@ -604,61 +604,67 @@ def _create_optimizer(self) -> None:
604604 'must be great than zero.'
605605 )
606606 if self ._learning_rate_schedule == options .LearningRateScheduleType .COSINE :
607- self ._decayed_learning_rate = tf .train .cosine_decay (** decay_args )
607+ self ._decayed_learning_rate = tf .compat .v1 .train .cosine_decay (
608+ ** decay_args
609+ )
608610 elif (
609611 self ._learning_rate_schedule
610612 == options .LearningRateScheduleType .EXPONENTIAL
611613 ):
612- self ._decayed_learning_rate = tf .train .exponential_decay (
614+ self ._decayed_learning_rate = tf .compat . v1 . train .exponential_decay (
613615 ** decay_args , ** decay_rate_arg
614616 )
615617 elif (
616618 self ._learning_rate_schedule
617619 == options .LearningRateScheduleType .INVERSE_TIME
618620 ):
619- self ._decayed_learning_rate = tf .train .inverse_time_decay (
621+ self ._decayed_learning_rate = tf .compat . v1 . train .inverse_time_decay (
620622 ** decay_args , ** decay_rate_arg
621623 )
622624 elif (
623625 self ._learning_rate_schedule
624626 == options .LearningRateScheduleType .LINEAR_COSINE
625627 ):
626- self ._decayed_learning_rate = tf .train .linear_cosine_decay (** decay_args )
628+ self ._decayed_learning_rate = tf .compat .v1 .train .linear_cosine_decay (
629+ ** decay_args
630+ )
627631 elif (
628632 self ._learning_rate_schedule
629633 == options .LearningRateScheduleType .NATURAL_EXP
630634 ):
631- self ._decayed_learning_rate = tf .train .natural_exp_decay (
635+ self ._decayed_learning_rate = tf .compat . v1 . train .natural_exp_decay (
632636 ** decay_args , ** decay_rate_arg
633637 )
634638 elif (
635639 self ._learning_rate_schedule
636640 == options .LearningRateScheduleType .NOISY_LINEAR_COSINE
637641 ):
638- self ._decayed_learning_rate = tf . train . noisy_linear_cosine_decay (
639- ** decay_args
642+ self ._decayed_learning_rate = (
643+ tf . compat . v1 . train . noisy_linear_cosine_decay ( ** decay_args )
640644 )
641645 elif (
642646 self ._learning_rate_schedule
643647 == options .LearningRateScheduleType .POLYNOMIAL
644648 ):
645- self ._decayed_learning_rate = tf .train .polynomial_decay (** decay_args )
649+ self ._decayed_learning_rate = tf .compat .v1 .train .polynomial_decay (
650+ ** decay_args
651+ )
646652 else :
647653 assert (
648654 self ._learning_rate_schedule == options .LearningRateScheduleType .NONE
649655 )
650656 self ._decayed_learning_rate = self ._learning_rate
651657
652658 if self ._optimizer_type == options .OptimizerType .ADAM :
653- self ._optimizer = tf .train .AdamOptimizer (
659+ self ._optimizer = tf .compat . v1 . train .AdamOptimizer (
654660 learning_rate = self ._decayed_learning_rate
655661 )
656662 elif self ._optimizer_type == options .OptimizerType .SGD :
657- self ._optimizer = tf .train .GradientDescentOptimizer (
663+ self ._optimizer = tf .compat . v1 . train .GradientDescentOptimizer (
658664 learning_rate = self ._decayed_learning_rate
659665 )
660666 elif self ._optimizer_type == options .OptimizerType .RMSPROP :
661- self ._optimizer = tf .train .RMSPropOptimizer (
667+ self ._optimizer = tf .compat . v1 . train .RMSPropOptimizer (
662668 learning_rate = self ._decayed_learning_rate
663669 )
664670 else :
@@ -681,7 +687,7 @@ def _create_optimizer(self) -> None:
681687
682688 def get_monitored_training_session_hooks (
683689 self ,
684- ) -> Sequence [tf .train .SessionRunHook ]:
690+ ) -> Sequence [tf .compat . v1 . train .SessionRunHook ]:
685691 """Returns hooks for a MonitoredTrainingSession required by the model."""
686692 hooks = []
687- if isinstance (self ._optimizer , tf .train .SyncReplicasOptimizer ):
693+ if isinstance (self ._optimizer , tf .compat . v1 . train .SyncReplicasOptimizer ):
@@ -1063,7 +1069,9 @@ def run_continuous_evaluation(
10631069 tf_master : str = '' ,
10641070 eval_interval_seconds : int = 60 ,
10651071 max_num_evaluations : Optional [int ] = None ,
1066- session_hooks : Optional [Sequence [tf .train .SessionRunHook ]] = None ,
1072+ session_hooks : Optional [
1073+ Sequence [tf .compat .v1 .train .SessionRunHook ]
1074+ ] = None ,
10671075 max_blocks_in_batch : Optional [int ] = None ,
10681076 max_instructions_in_batch : Optional [int ] = None ,
10691077 ) -> None :